author | shadchin <shadchin@yandex-team.ru> | 2022-02-10 16:44:39 +0300
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:44:39 +0300
commit | e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (patch)
tree | 64175d5cadab313b3e7039ebaa06c5bc3295e274 /contrib/python/s3transfer/py3
parent | 2598ef1d0aee359b4b6d5fdd1758916d5907d04f (diff)
download | ydb-e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0.tar.gz
Restoring authorship annotation for <shadchin@yandex-team.ru>. Commit 2 of 2.
Diffstat (limited to 'contrib/python/s3transfer/py3')
50 files changed, 13685 insertions(+), 13685 deletions(-)
diff --git a/contrib/python/s3transfer/py3/.dist-info/METADATA b/contrib/python/s3transfer/py3/.dist-info/METADATA index 2805466e80..7d635068d7 100644 --- a/contrib/python/s3transfer/py3/.dist-info/METADATA +++ b/contrib/python/s3transfer/py3/.dist-info/METADATA @@ -1,42 +1,42 @@ -Metadata-Version: 2.1 -Name: s3transfer -Version: 0.5.1 -Summary: An Amazon S3 Transfer Manager -Home-page: https://github.com/boto/s3transfer -Author: Amazon Web Services -Author-email: kyknapp1@gmail.com -License: Apache License 2.0 -Platform: UNKNOWN -Classifier: Development Status :: 3 - Alpha -Classifier: Intended Audience :: Developers -Classifier: Natural Language :: English -Classifier: License :: OSI Approved :: Apache Software License -Classifier: Programming Language :: Python -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.6 -Classifier: Programming Language :: Python :: 3.7 -Classifier: Programming Language :: Python :: 3.8 -Classifier: Programming Language :: Python :: 3.9 -Classifier: Programming Language :: Python :: 3.10 -Requires-Python: >= 3.6 -License-File: LICENSE.txt -License-File: NOTICE.txt -Requires-Dist: botocore (<2.0a.0,>=1.12.36) -Provides-Extra: crt -Requires-Dist: botocore[crt] (<2.0a.0,>=1.20.29) ; extra == 'crt' - -===================================================== -s3transfer - An Amazon S3 Transfer Manager for Python -===================================================== - -S3transfer is a Python library for managing Amazon S3 transfers. -This project is maintained and published by Amazon Web Services. - -.. note:: - - This project is not currently GA. If you are planning to use this code in - production, make sure to lock to a minor version as interfaces may break - from minor version to minor version. For a basic, stable interface of - s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__ - - +Metadata-Version: 2.1 +Name: s3transfer +Version: 0.5.1 +Summary: An Amazon S3 Transfer Manager +Home-page: https://github.com/boto/s3transfer +Author: Amazon Web Services +Author-email: kyknapp1@gmail.com +License: Apache License 2.0 +Platform: UNKNOWN +Classifier: Development Status :: 3 - Alpha +Classifier: Intended Audience :: Developers +Classifier: Natural Language :: English +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Requires-Python: >= 3.6 +License-File: LICENSE.txt +License-File: NOTICE.txt +Requires-Dist: botocore (<2.0a.0,>=1.12.36) +Provides-Extra: crt +Requires-Dist: botocore[crt] (<2.0a.0,>=1.20.29) ; extra == 'crt' + +===================================================== +s3transfer - An Amazon S3 Transfer Manager for Python +===================================================== + +S3transfer is a Python library for managing Amazon S3 transfers. +This project is maintained and published by Amazon Web Services. + +.. note:: + + This project is not currently GA. If you are planning to use this code in + production, make sure to lock to a minor version as interfaces may break + from minor version to minor version. 
For a basic, stable interface of + s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__ + + diff --git a/contrib/python/s3transfer/py3/.dist-info/top_level.txt b/contrib/python/s3transfer/py3/.dist-info/top_level.txt index 305b80c924..572c6a92fb 100644 --- a/contrib/python/s3transfer/py3/.dist-info/top_level.txt +++ b/contrib/python/s3transfer/py3/.dist-info/top_level.txt @@ -1 +1 @@ -s3transfer +s3transfer diff --git a/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml b/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml index d5ef0ef5e6..f2f140fb3c 100644 --- a/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml +++ b/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml @@ -1,2 +1,2 @@ -exclude: -- tests/integration/.+ +exclude: +- tests/integration/.+ diff --git a/contrib/python/s3transfer/py3/LICENSE.txt b/contrib/python/s3transfer/py3/LICENSE.txt index c0fd617439..d645695673 100644 --- a/contrib/python/s3transfer/py3/LICENSE.txt +++ b/contrib/python/s3transfer/py3/LICENSE.txt @@ -1,202 +1,202 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. 
- - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/contrib/python/s3transfer/py3/NOTICE.txt b/contrib/python/s3transfer/py3/NOTICE.txt index 96e3cc3530..3e616fdf0c 100644 --- a/contrib/python/s3transfer/py3/NOTICE.txt +++ b/contrib/python/s3transfer/py3/NOTICE.txt @@ -1,2 +1,2 @@ -s3transfer -Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +s3transfer +Copyright 2016 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. diff --git a/contrib/python/s3transfer/py3/README.rst b/contrib/python/s3transfer/py3/README.rst index 923be4f461..441029109e 100644 --- a/contrib/python/s3transfer/py3/README.rst +++ b/contrib/python/s3transfer/py3/README.rst @@ -1,13 +1,13 @@ -===================================================== -s3transfer - An Amazon S3 Transfer Manager for Python -===================================================== - -S3transfer is a Python library for managing Amazon S3 transfers. -This project is maintained and published by Amazon Web Services. - -.. note:: - - This project is not currently GA. If you are planning to use this code in - production, make sure to lock to a minor version as interfaces may break - from minor version to minor version. For a basic, stable interface of - s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__ +===================================================== +s3transfer - An Amazon S3 Transfer Manager for Python +===================================================== + +S3transfer is a Python library for managing Amazon S3 transfers. +This project is maintained and published by Amazon Web Services. + +.. note:: + + This project is not currently GA. If you are planning to use this code in + production, make sure to lock to a minor version as interfaces may break + from minor version to minor version. For a basic, stable interface of + s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__ diff --git a/contrib/python/s3transfer/py3/patches/01-fix-tests.patch b/contrib/python/s3transfer/py3/patches/01-fix-tests.patch index 4f20d019e9..aa8d3fab4e 100644 --- a/contrib/python/s3transfer/py3/patches/01-fix-tests.patch +++ b/contrib/python/s3transfer/py3/patches/01-fix-tests.patch @@ -1,242 +1,242 @@ ---- contrib/python/s3transfer/py3/tests/functional/test_copy.py (index) -+++ contrib/python/s3transfer/py3/tests/functional/test_copy.py (working tree) -@@ -15,7 +15,7 @@ from botocore.stub import Stubber - - from s3transfer.manager import TransferConfig, TransferManager - from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE --from tests import BaseGeneralInterfaceTest, FileSizeProvider -+from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider - - - class BaseCopyTest(BaseGeneralInterfaceTest): ---- contrib/python/s3transfer/py3/tests/functional/test_crt.py (index) -+++ contrib/python/s3transfer/py3/tests/functional/test_crt.py (working tree) -@@ -18,7 +18,7 @@ from concurrent.futures import Future - from botocore.session import Session - - from s3transfer.subscribers import BaseSubscriber --from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest -+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest - - if HAS_CRT: - import awscrt ---- contrib/python/s3transfer/py3/tests/functional/test_delete.py (index) -+++ contrib/python/s3transfer/py3/tests/functional/test_delete.py (working tree) -@@ -11,7 +11,7 @@ - # ANY KIND, either express or implied. See the License for the specific - # language governing permissions and limitations under the License. 
- from s3transfer.manager import TransferManager --from tests import BaseGeneralInterfaceTest -+from __tests__ import BaseGeneralInterfaceTest - - - class TestDeleteObject(BaseGeneralInterfaceTest): ---- contrib/python/s3transfer/py3/tests/functional/test_download.py (index) -+++ contrib/python/s3transfer/py3/tests/functional/test_download.py (working tree) -@@ -23,7 +23,7 @@ from botocore.exceptions import ClientError - from s3transfer.compat import SOCKET_ERROR - from s3transfer.exceptions import RetriesExceededError - from s3transfer.manager import TransferConfig, TransferManager --from tests import ( -+from __tests__ import ( - BaseGeneralInterfaceTest, - FileSizeProvider, - NonSeekableWriter, ---- contrib/python/s3transfer/py3/tests/functional/test_manager.py (index) -+++ contrib/python/s3transfer/py3/tests/functional/test_manager.py (working tree) -@@ -17,7 +17,7 @@ from botocore.awsrequest import create_request_object - from s3transfer.exceptions import CancelledError, FatalError - from s3transfer.futures import BaseExecutor - from s3transfer.manager import TransferConfig, TransferManager --from tests import StubbedClientTest, mock, skip_if_using_serial_implementation -+from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation - - - class ArbitraryException(Exception): ---- contrib/python/s3transfer/py3/tests/functional/test_processpool.py (index) -+++ contrib/python/s3transfer/py3/tests/functional/test_processpool.py (working tree) -@@ -21,7 +21,7 @@ from botocore.stub import Stubber - - from s3transfer.exceptions import CancelledError - from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig --from tests import FileCreator, mock, unittest -+from __tests__ import FileCreator, mock, unittest - - - class StubbedClient: ---- contrib/python/s3transfer/py3/tests/functional/test_upload.py (index) -+++ contrib/python/s3transfer/py3/tests/functional/test_upload.py (working tree) -@@ -23,7 +23,7 @@ from botocore.stub import ANY - - from s3transfer.manager import TransferConfig, TransferManager - from s3transfer.utils import ChunksizeAdjuster --from tests import ( -+from __tests__ import ( - BaseGeneralInterfaceTest, - NonSeekableReader, - RecordingOSUtils, ---- contrib/python/s3transfer/py3/tests/functional/test_utils.py (index) -+++ contrib/python/s3transfer/py3/tests/functional/test_utils.py (working tree) -@@ -16,7 +16,7 @@ import socket - import tempfile - - from s3transfer.utils import OSUtils --from tests import skip_if_windows, unittest -+from __tests__ import skip_if_windows, unittest - - - @skip_if_windows('Windows does not support UNIX special files') ---- contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (working tree) -@@ -25,7 +25,7 @@ from s3transfer.bandwidth import ( - TimeUtils, - ) - from s3transfer.futures import TransferCoordinator --from tests import mock, unittest -+from __tests__ import mock, unittest - - - class FixedIncrementalTickTimeUtils(TimeUtils): ---- contrib/python/s3transfer/py3/tests/unit/test_compat.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_compat.py (working tree) -@@ -17,7 +17,7 @@ import tempfile - from io import BytesIO - - from s3transfer.compat import BaseManager, readable, seekable --from tests import skip_if_windows, unittest -+from __tests__ import skip_if_windows, unittest - - - class ErrorRaisingSeekWrapper: ---- contrib/python/s3transfer/py3/tests/unit/test_copies.py (index) -+++ 
contrib/python/s3transfer/py3/tests/unit/test_copies.py (working tree) -@@ -11,7 +11,7 @@ - # ANY KIND, either express or implied. See the License for the specific - # language governing permissions and limitations under the License. - from s3transfer.copies import CopyObjectTask, CopyPartTask --from tests import BaseTaskTest, RecordingSubscriber -+from __tests__ import BaseTaskTest, RecordingSubscriber - - - class BaseCopyTaskTest(BaseTaskTest): ---- contrib/python/s3transfer/py3/tests/unit/test_crt.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_crt.py (working tree) -@@ -15,7 +15,7 @@ from botocore.session import Session - - from s3transfer.exceptions import TransferNotDoneError - from s3transfer.utils import CallArgs --from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest -+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest - - if HAS_CRT: - import awscrt.s3 ---- contrib/python/s3transfer/py3/tests/unit/test_delete.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_delete.py (working tree) -@@ -11,7 +11,7 @@ - # ANY KIND, either express or implied. See the License for the specific - # language governing permissions and limitations under the License. - from s3transfer.delete import DeleteObjectTask --from tests import BaseTaskTest -+from __tests__ import BaseTaskTest - - - class TestDeleteObjectTask(BaseTaskTest): ---- contrib/python/s3transfer/py3/tests/unit/test_download.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_download.py (working tree) -@@ -37,7 +37,7 @@ from s3transfer.download import ( - from s3transfer.exceptions import RetriesExceededError - from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor - from s3transfer.utils import CallArgs, OSUtils --from tests import ( -+from __tests__ import ( - BaseSubmissionTaskTest, - BaseTaskTest, - FileCreator, ---- contrib/python/s3transfer/py3/tests/unit/test_futures.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_futures.py (working tree) -@@ -37,7 +37,7 @@ from s3transfer.utils import ( - NoResourcesAvailable, - TaskSemaphore, - ) --from tests import ( -+from __tests__ import ( - RecordingExecutor, - TransferCoordinatorWithInterrupt, - mock, ---- contrib/python/s3transfer/py3/tests/unit/test_manager.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_manager.py (working tree) -@@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor - from s3transfer.exceptions import CancelledError, FatalError - from s3transfer.futures import TransferCoordinator - from s3transfer.manager import TransferConfig, TransferCoordinatorController --from tests import TransferCoordinatorWithInterrupt, unittest -+from __tests__ import TransferCoordinatorWithInterrupt, unittest - - - class FutureResultException(Exception): ---- contrib/python/s3transfer/py3/tests/unit/test_processpool.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_processpool.py (working tree) -@@ -39,7 +39,7 @@ from s3transfer.processpool import ( - ignore_ctrl_c, - ) - from s3transfer.utils import CallArgs, OSUtils --from tests import ( -+from __tests__ import ( - FileCreator, - StreamWithError, - StubbedClientTest, ---- contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (working tree) -@@ -33,7 +33,7 @@ from s3transfer import ( - random_file_extension, - ) - from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError --from tests import mock, unittest 
-+from __tests__ import mock, unittest - - - class InMemoryOSLayer(OSUtils): ---- contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (working tree) -@@ -12,7 +12,7 @@ - # language governing permissions and limitations under the License. - from s3transfer.exceptions import InvalidSubscriberMethodError - from s3transfer.subscribers import BaseSubscriber --from tests import unittest -+from __tests__ import unittest - - - class ExtraMethodsSubscriber(BaseSubscriber): ---- contrib/python/s3transfer/py3/tests/unit/test_tasks.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_tasks.py (working tree) -@@ -23,7 +23,7 @@ from s3transfer.tasks import ( - Task, - ) - from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks --from tests import ( -+from __tests__ import ( - BaseSubmissionTaskTest, - BaseTaskTest, - RecordingSubscriber, ---- contrib/python/s3transfer/py3/tests/unit/test_upload.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_upload.py (working tree) -@@ -32,7 +32,7 @@ from s3transfer.upload import ( - UploadSubmissionTask, - ) - from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils --from tests import ( -+from __tests__ import ( - BaseSubmissionTaskTest, - BaseTaskTest, - FileSizeProvider, ---- contrib/python/s3transfer/py3/tests/unit/test_utils.py (index) -+++ contrib/python/s3transfer/py3/tests/unit/test_utils.py (working tree) -@@ -43,7 +43,7 @@ from s3transfer.utils import ( - invoke_progress_callbacks, - random_file_extension, - ) --from tests import NonSeekableWriter, RecordingSubscriber, mock, unittest -+from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest - - - class TestGetCallbacks(unittest.TestCase): +--- contrib/python/s3transfer/py3/tests/functional/test_copy.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_copy.py (working tree) +@@ -15,7 +15,7 @@ from botocore.stub import Stubber + + from s3transfer.manager import TransferConfig, TransferManager + from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE +-from tests import BaseGeneralInterfaceTest, FileSizeProvider ++from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider + + + class BaseCopyTest(BaseGeneralInterfaceTest): +--- contrib/python/s3transfer/py3/tests/functional/test_crt.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_crt.py (working tree) +@@ -18,7 +18,7 @@ from concurrent.futures import Future + from botocore.session import Session + + from s3transfer.subscribers import BaseSubscriber +-from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest ++from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest + + if HAS_CRT: + import awscrt +--- contrib/python/s3transfer/py3/tests/functional/test_delete.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_delete.py (working tree) +@@ -11,7 +11,7 @@ + # ANY KIND, either express or implied. See the License for the specific + # language governing permissions and limitations under the License. 
+ from s3transfer.manager import TransferManager +-from tests import BaseGeneralInterfaceTest ++from __tests__ import BaseGeneralInterfaceTest + + + class TestDeleteObject(BaseGeneralInterfaceTest): +--- contrib/python/s3transfer/py3/tests/functional/test_download.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_download.py (working tree) +@@ -23,7 +23,7 @@ from botocore.exceptions import ClientError + from s3transfer.compat import SOCKET_ERROR + from s3transfer.exceptions import RetriesExceededError + from s3transfer.manager import TransferConfig, TransferManager +-from tests import ( ++from __tests__ import ( + BaseGeneralInterfaceTest, + FileSizeProvider, + NonSeekableWriter, +--- contrib/python/s3transfer/py3/tests/functional/test_manager.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_manager.py (working tree) +@@ -17,7 +17,7 @@ from botocore.awsrequest import create_request_object + from s3transfer.exceptions import CancelledError, FatalError + from s3transfer.futures import BaseExecutor + from s3transfer.manager import TransferConfig, TransferManager +-from tests import StubbedClientTest, mock, skip_if_using_serial_implementation ++from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation + + + class ArbitraryException(Exception): +--- contrib/python/s3transfer/py3/tests/functional/test_processpool.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_processpool.py (working tree) +@@ -21,7 +21,7 @@ from botocore.stub import Stubber + + from s3transfer.exceptions import CancelledError + from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig +-from tests import FileCreator, mock, unittest ++from __tests__ import FileCreator, mock, unittest + + + class StubbedClient: +--- contrib/python/s3transfer/py3/tests/functional/test_upload.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_upload.py (working tree) +@@ -23,7 +23,7 @@ from botocore.stub import ANY + + from s3transfer.manager import TransferConfig, TransferManager + from s3transfer.utils import ChunksizeAdjuster +-from tests import ( ++from __tests__ import ( + BaseGeneralInterfaceTest, + NonSeekableReader, + RecordingOSUtils, +--- contrib/python/s3transfer/py3/tests/functional/test_utils.py (index) ++++ contrib/python/s3transfer/py3/tests/functional/test_utils.py (working tree) +@@ -16,7 +16,7 @@ import socket + import tempfile + + from s3transfer.utils import OSUtils +-from tests import skip_if_windows, unittest ++from __tests__ import skip_if_windows, unittest + + + @skip_if_windows('Windows does not support UNIX special files') +--- contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (working tree) +@@ -25,7 +25,7 @@ from s3transfer.bandwidth import ( + TimeUtils, + ) + from s3transfer.futures import TransferCoordinator +-from tests import mock, unittest ++from __tests__ import mock, unittest + + + class FixedIncrementalTickTimeUtils(TimeUtils): +--- contrib/python/s3transfer/py3/tests/unit/test_compat.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_compat.py (working tree) +@@ -17,7 +17,7 @@ import tempfile + from io import BytesIO + + from s3transfer.compat import BaseManager, readable, seekable +-from tests import skip_if_windows, unittest ++from __tests__ import skip_if_windows, unittest + + + class ErrorRaisingSeekWrapper: +--- contrib/python/s3transfer/py3/tests/unit/test_copies.py (index) ++++ 
contrib/python/s3transfer/py3/tests/unit/test_copies.py (working tree) +@@ -11,7 +11,7 @@ + # ANY KIND, either express or implied. See the License for the specific + # language governing permissions and limitations under the License. + from s3transfer.copies import CopyObjectTask, CopyPartTask +-from tests import BaseTaskTest, RecordingSubscriber ++from __tests__ import BaseTaskTest, RecordingSubscriber + + + class BaseCopyTaskTest(BaseTaskTest): +--- contrib/python/s3transfer/py3/tests/unit/test_crt.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_crt.py (working tree) +@@ -15,7 +15,7 @@ from botocore.session import Session + + from s3transfer.exceptions import TransferNotDoneError + from s3transfer.utils import CallArgs +-from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest ++from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest + + if HAS_CRT: + import awscrt.s3 +--- contrib/python/s3transfer/py3/tests/unit/test_delete.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_delete.py (working tree) +@@ -11,7 +11,7 @@ + # ANY KIND, either express or implied. See the License for the specific + # language governing permissions and limitations under the License. + from s3transfer.delete import DeleteObjectTask +-from tests import BaseTaskTest ++from __tests__ import BaseTaskTest + + + class TestDeleteObjectTask(BaseTaskTest): +--- contrib/python/s3transfer/py3/tests/unit/test_download.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_download.py (working tree) +@@ -37,7 +37,7 @@ from s3transfer.download import ( + from s3transfer.exceptions import RetriesExceededError + from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor + from s3transfer.utils import CallArgs, OSUtils +-from tests import ( ++from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + FileCreator, +--- contrib/python/s3transfer/py3/tests/unit/test_futures.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_futures.py (working tree) +@@ -37,7 +37,7 @@ from s3transfer.utils import ( + NoResourcesAvailable, + TaskSemaphore, + ) +-from tests import ( ++from __tests__ import ( + RecordingExecutor, + TransferCoordinatorWithInterrupt, + mock, +--- contrib/python/s3transfer/py3/tests/unit/test_manager.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_manager.py (working tree) +@@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor + from s3transfer.exceptions import CancelledError, FatalError + from s3transfer.futures import TransferCoordinator + from s3transfer.manager import TransferConfig, TransferCoordinatorController +-from tests import TransferCoordinatorWithInterrupt, unittest ++from __tests__ import TransferCoordinatorWithInterrupt, unittest + + + class FutureResultException(Exception): +--- contrib/python/s3transfer/py3/tests/unit/test_processpool.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_processpool.py (working tree) +@@ -39,7 +39,7 @@ from s3transfer.processpool import ( + ignore_ctrl_c, + ) + from s3transfer.utils import CallArgs, OSUtils +-from tests import ( ++from __tests__ import ( + FileCreator, + StreamWithError, + StubbedClientTest, +--- contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (working tree) +@@ -33,7 +33,7 @@ from s3transfer import ( + random_file_extension, + ) + from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError +-from tests import mock, unittest 
++from __tests__ import mock, unittest + + + class InMemoryOSLayer(OSUtils): +--- contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (working tree) +@@ -12,7 +12,7 @@ + # language governing permissions and limitations under the License. + from s3transfer.exceptions import InvalidSubscriberMethodError + from s3transfer.subscribers import BaseSubscriber +-from tests import unittest ++from __tests__ import unittest + + + class ExtraMethodsSubscriber(BaseSubscriber): +--- contrib/python/s3transfer/py3/tests/unit/test_tasks.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_tasks.py (working tree) +@@ -23,7 +23,7 @@ from s3transfer.tasks import ( + Task, + ) + from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks +-from tests import ( ++from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + RecordingSubscriber, +--- contrib/python/s3transfer/py3/tests/unit/test_upload.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_upload.py (working tree) +@@ -32,7 +32,7 @@ from s3transfer.upload import ( + UploadSubmissionTask, + ) + from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils +-from tests import ( ++from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + FileSizeProvider, +--- contrib/python/s3transfer/py3/tests/unit/test_utils.py (index) ++++ contrib/python/s3transfer/py3/tests/unit/test_utils.py (working tree) +@@ -43,7 +43,7 @@ from s3transfer.utils import ( + invoke_progress_callbacks, + random_file_extension, + ) +-from tests import NonSeekableWriter, RecordingSubscriber, mock, unittest ++from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest + + + class TestGetCallbacks(unittest.TestCase): diff --git a/contrib/python/s3transfer/py3/s3transfer/__init__.py b/contrib/python/s3transfer/py3/s3transfer/__init__.py index cb2591e1cd..1a749c712e 100644 --- a/contrib/python/s3transfer/py3/s3transfer/__init__.py +++ b/contrib/python/s3transfer/py3/s3transfer/__init__.py @@ -72,7 +72,7 @@ client operation. Here are a few examples using ``upload_file``:: extra_args={'ContentType': "application/json"}) -The ``S3Transfer`` class also supports progress callbacks so you can +The ``S3Transfer`` class also supports progress callbacks so you can provide transfer progress to users. Both the ``upload_file`` and ``download_file`` methods take an optional ``callback`` parameter. Here's an example of how to print a simple progress percentage @@ -123,28 +123,28 @@ transfer. 
For example: """ -import concurrent.futures +import concurrent.futures import functools import logging -import math -import os -import queue -import random +import math +import os +import queue +import random import socket -import string +import string import threading -from botocore.compat import six # noqa: F401 +from botocore.compat import six # noqa: F401 from botocore.exceptions import IncompleteReadError -from botocore.vendored.requests.packages.urllib3.exceptions import ( - ReadTimeoutError, -) +from botocore.vendored.requests.packages.urllib3.exceptions import ( + ReadTimeoutError, +) import s3transfer.compat from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError __author__ = 'Amazon Web Services' -__version__ = '0.5.1' +__version__ = '0.5.1' class NullHandler(logging.Handler): @@ -164,16 +164,16 @@ def random_file_extension(num_digits=8): def disable_upload_callbacks(request, operation_name, **kwargs): - if operation_name in ['PutObject', 'UploadPart'] and hasattr( - request.body, 'disable_callback' - ): + if operation_name in ['PutObject', 'UploadPart'] and hasattr( + request.body, 'disable_callback' + ): request.body.disable_callback() def enable_upload_callbacks(request, operation_name, **kwargs): - if operation_name in ['PutObject', 'UploadPart'] and hasattr( - request.body, 'enable_callback' - ): + if operation_name in ['PutObject', 'UploadPart'] and hasattr( + request.body, 'enable_callback' + ): request.body.enable_callback() @@ -181,16 +181,16 @@ class QueueShutdownError(Exception): pass -class ReadFileChunk: - def __init__( - self, - fileobj, - start_byte, - chunk_size, - full_file_size, - callback=None, - enable_callback=True, - ): +class ReadFileChunk: + def __init__( + self, + fileobj, + start_byte, + chunk_size, + full_file_size, + callback=None, + enable_callback=True, + ): """ Given a file object shown below: @@ -222,25 +222,25 @@ class ReadFileChunk: self._fileobj = fileobj self._start_byte = start_byte self._size = self._calculate_file_size( - self._fileobj, - requested_size=chunk_size, - start_byte=start_byte, - actual_file_size=full_file_size, - ) + self._fileobj, + requested_size=chunk_size, + start_byte=start_byte, + actual_file_size=full_file_size, + ) self._fileobj.seek(self._start_byte) self._amount_read = 0 self._callback = callback self._callback_enabled = enable_callback @classmethod - def from_filename( - cls, - filename, - start_byte, - chunk_size, - callback=None, - enable_callback=True, - ): + def from_filename( + cls, + filename, + start_byte, + chunk_size, + callback=None, + enable_callback=True, + ): """Convenience factory function to create from a filename. 
:type start_byte: int @@ -268,13 +268,13 @@ class ReadFileChunk: """ f = open(filename, 'rb') file_size = os.fstat(f.fileno()).st_size - return cls( - f, start_byte, chunk_size, file_size, callback, enable_callback - ) + return cls( + f, start_byte, chunk_size, file_size, callback, enable_callback + ) - def _calculate_file_size( - self, fileobj, requested_size, start_byte, actual_file_size - ): + def _calculate_file_size( + self, fileobj, requested_size, start_byte, actual_file_size + ): max_chunk_size = actual_file_size - start_byte return min(max_chunk_size, requested_size) @@ -331,9 +331,9 @@ class ReadFileChunk: return iter([]) -class StreamReaderProgress: +class StreamReaderProgress: """Wrapper for a read only stream that adds progress callbacks.""" - + def __init__(self, stream, callback=None): self._stream = stream self._callback = callback @@ -345,14 +345,14 @@ class StreamReaderProgress: return value -class OSUtils: +class OSUtils: def get_file_size(self, filename): return os.path.getsize(filename) def open_file_chunk_reader(self, filename, start_byte, size, callback): - return ReadFileChunk.from_filename( - filename, start_byte, size, callback, enable_callback=False - ) + return ReadFileChunk.from_filename( + filename, start_byte, size, callback, enable_callback=False + ) def open(self, filename, mode): return open(filename, mode) @@ -370,7 +370,7 @@ class OSUtils: s3transfer.compat.rename_file(current_filename, new_filename) -class MultipartUploader: +class MultipartUploader: # These are the extra_args that need to be forwarded onto # subsequent upload_parts. UPLOAD_PART_ARGS = [ @@ -380,13 +380,13 @@ class MultipartUploader: 'RequestPayer', ] - def __init__( - self, - client, - config, - osutil, - executor_cls=concurrent.futures.ThreadPoolExecutor, - ): + def __init__( + self, + client, + config, + osutil, + executor_cls=concurrent.futures.ThreadPoolExecutor, + ): self._client = client self._config = config self._os = osutil @@ -402,83 +402,83 @@ class MultipartUploader: return upload_parts_args def upload_file(self, filename, bucket, key, callback, extra_args): - response = self._client.create_multipart_upload( - Bucket=bucket, Key=key, **extra_args - ) + response = self._client.create_multipart_upload( + Bucket=bucket, Key=key, **extra_args + ) upload_id = response['UploadId'] try: - parts = self._upload_parts( - upload_id, filename, bucket, key, callback, extra_args - ) + parts = self._upload_parts( + upload_id, filename, bucket, key, callback, extra_args + ) except Exception as e: - logger.debug( - "Exception raised while uploading parts, " - "aborting multipart upload.", - exc_info=True, - ) + logger.debug( + "Exception raised while uploading parts, " + "aborting multipart upload.", + exc_info=True, + ) self._client.abort_multipart_upload( - Bucket=bucket, Key=key, UploadId=upload_id - ) + Bucket=bucket, Key=key, UploadId=upload_id + ) raise S3UploadFailedError( - "Failed to upload {} to {}: {}".format( - filename, '/'.join([bucket, key]), e - ) - ) + "Failed to upload {} to {}: {}".format( + filename, '/'.join([bucket, key]), e + ) + ) self._client.complete_multipart_upload( - Bucket=bucket, - Key=key, - UploadId=upload_id, - MultipartUpload={'Parts': parts}, - ) - - def _upload_parts( - self, upload_id, filename, bucket, key, callback, extra_args - ): + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={'Parts': parts}, + ) + + def _upload_parts( + self, upload_id, filename, bucket, key, callback, extra_args + ): upload_parts_extra_args = 
self._extra_upload_part_args(extra_args) parts = [] part_size = self._config.multipart_chunksize num_parts = int( - math.ceil(self._os.get_file_size(filename) / float(part_size)) - ) + math.ceil(self._os.get_file_size(filename) / float(part_size)) + ) max_workers = self._config.max_concurrency with self._executor_cls(max_workers=max_workers) as executor: upload_partial = functools.partial( - self._upload_one_part, - filename, - bucket, - key, - upload_id, - part_size, - upload_parts_extra_args, - callback, - ) + self._upload_one_part, + filename, + bucket, + key, + upload_id, + part_size, + upload_parts_extra_args, + callback, + ) for part in executor.map(upload_partial, range(1, num_parts + 1)): parts.append(part) return parts - def _upload_one_part( - self, - filename, - bucket, - key, - upload_id, - part_size, - extra_args, - callback, - part_number, - ): + def _upload_one_part( + self, + filename, + bucket, + key, + upload_id, + part_size, + extra_args, + callback, + part_number, + ): open_chunk_reader = self._os.open_file_chunk_reader - with open_chunk_reader( - filename, part_size * (part_number - 1), part_size, callback - ) as body: + with open_chunk_reader( + filename, part_size * (part_number - 1), part_size, callback + ) as body: response = self._client.upload_part( - Bucket=bucket, - Key=key, - UploadId=upload_id, - PartNumber=part_number, - Body=body, - **extra_args, - ) + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=part_number, + Body=body, + **extra_args, + ) etag = response['ETag'] return {'ETag': etag, 'PartNumber': part_number} @@ -494,7 +494,7 @@ class ShutdownQueue(queue.Queue): to be a drop in replacement for ``queue.Queue``. """ - + def _init(self, maxsize): self._shutdown = False self._shutdown_lock = threading.Lock() @@ -511,50 +511,50 @@ class ShutdownQueue(queue.Queue): # Need to hook into the condition vars used by this class. with self._shutdown_lock: if self._shutdown: - raise QueueShutdownError( - "Cannot put item to queue when " "queue has been shutdown." - ) + raise QueueShutdownError( + "Cannot put item to queue when " "queue has been shutdown." + ) return queue.Queue.put(self, item) -class MultipartDownloader: - def __init__( - self, - client, - config, - osutil, - executor_cls=concurrent.futures.ThreadPoolExecutor, - ): +class MultipartDownloader: + def __init__( + self, + client, + config, + osutil, + executor_cls=concurrent.futures.ThreadPoolExecutor, + ): self._client = client self._config = config self._os = osutil self._executor_cls = executor_cls self._ioqueue = ShutdownQueue(self._config.max_io_queue) - def download_file( - self, bucket, key, filename, object_size, extra_args, callback=None - ): + def download_file( + self, bucket, key, filename, object_size, extra_args, callback=None + ): with self._executor_cls(max_workers=2) as controller: # 1 thread for the future that manages the uploading of files # 1 thread for the future that manages IO writes. 
download_parts_handler = functools.partial( self._download_file_as_future, - bucket, - key, - filename, - object_size, - callback, - ) + bucket, + key, + filename, + object_size, + callback, + ) parts_future = controller.submit(download_parts_handler) io_writes_handler = functools.partial( - self._perform_io_writes, filename - ) + self._perform_io_writes, filename + ) io_future = controller.submit(io_writes_handler) results = concurrent.futures.wait( [parts_future, io_future], - return_when=concurrent.futures.FIRST_EXCEPTION, - ) + return_when=concurrent.futures.FIRST_EXCEPTION, + ) self._process_future_results(results) def _process_future_results(self, futures): @@ -562,21 +562,21 @@ class MultipartDownloader: for future in finished: future.result() - def _download_file_as_future( - self, bucket, key, filename, object_size, callback - ): + def _download_file_as_future( + self, bucket, key, filename, object_size, callback + ): part_size = self._config.multipart_chunksize num_parts = int(math.ceil(object_size / float(part_size))) max_workers = self._config.max_concurrency download_partial = functools.partial( - self._download_range, - bucket, - key, - filename, - part_size, - num_parts, - callback, - ) + self._download_range, + bucket, + key, + filename, + part_size, + num_parts, + callback, + ) try: with self._executor_cls(max_workers=max_workers) as executor: list(executor.map(download_partial, range(num_parts))) @@ -589,16 +589,16 @@ class MultipartDownloader: end_range = '' else: end_range = start_range + part_size - 1 - range_param = f'bytes={start_range}-{end_range}' + range_param = f'bytes={start_range}-{end_range}' return range_param - def _download_range( - self, bucket, key, filename, part_size, num_parts, callback, part_index - ): + def _download_range( + self, bucket, key, filename, part_size, num_parts, callback, part_index + ): try: range_param = self._calculate_range_param( - part_size, part_index, num_parts - ) + part_size, part_index, num_parts + ) max_attempts = self._config.num_download_attempts last_exception = None @@ -606,33 +606,33 @@ class MultipartDownloader: try: logger.debug("Making get_object call.") response = self._client.get_object( - Bucket=bucket, Key=key, Range=range_param - ) + Bucket=bucket, Key=key, Range=range_param + ) streaming_body = StreamReaderProgress( - response['Body'], callback - ) + response['Body'], callback + ) buffer_size = 1024 * 16 current_index = part_size * part_index - for chunk in iter( - lambda: streaming_body.read(buffer_size), b'' - ): + for chunk in iter( + lambda: streaming_body.read(buffer_size), b'' + ): self._ioqueue.put((current_index, chunk)) current_index += len(chunk) return - except ( - socket.timeout, - OSError, - ReadTimeoutError, - IncompleteReadError, - ) as e: - logger.debug( - "Retrying exception caught (%s), " - "retrying request, (attempt %s / %s)", - e, - i, - max_attempts, - exc_info=True, - ) + except ( + socket.timeout, + OSError, + ReadTimeoutError, + IncompleteReadError, + ) as e: + logger.debug( + "Retrying exception caught (%s), " + "retrying request, (attempt %s / %s)", + e, + i, + max_attempts, + exc_info=True, + ) last_exception = e continue raise RetriesExceededError(last_exception) @@ -644,10 +644,10 @@ class MultipartDownloader: while True: task = self._ioqueue.get() if task is SHUTDOWN_SENTINEL: - logger.debug( - "Shutdown sentinel received in IO handler, " - "shutting down IO handler." - ) + logger.debug( + "Shutdown sentinel received in IO handler, " + "shutting down IO handler." 
+ ) return else: try: @@ -655,24 +655,24 @@ class MultipartDownloader: f.seek(offset) f.write(data) except Exception as e: - logger.debug( - "Caught exception in IO thread: %s", - e, - exc_info=True, - ) + logger.debug( + "Caught exception in IO thread: %s", + e, + exc_info=True, + ) self._ioqueue.trigger_shutdown() raise -class TransferConfig: - def __init__( - self, - multipart_threshold=8 * MB, - max_concurrency=10, - multipart_chunksize=8 * MB, - num_download_attempts=5, - max_io_queue=100, - ): +class TransferConfig: + def __init__( + self, + multipart_threshold=8 * MB, + max_concurrency=10, + multipart_chunksize=8 * MB, + num_download_attempts=5, + max_io_queue=100, + ): self.multipart_threshold = multipart_threshold self.max_concurrency = max_concurrency self.multipart_chunksize = multipart_chunksize @@ -680,7 +680,7 @@ class TransferConfig: self.max_io_queue = max_io_queue -class S3Transfer: +class S3Transfer: ALLOWED_DOWNLOAD_ARGS = [ 'VersionId', @@ -710,8 +710,8 @@ class S3Transfer: 'SSECustomerKey', 'SSECustomerKeyMD5', 'SSEKMSKeyId', - 'SSEKMSEncryptionContext', - 'Tagging', + 'SSEKMSEncryptionContext', + 'Tagging', ] def __init__(self, client, config=None, osutil=None): @@ -723,9 +723,9 @@ class S3Transfer: osutil = OSUtils() self._osutil = osutil - def upload_file( - self, filename, bucket, key, callback=None, extra_args=None - ): + def upload_file( + self, filename, bucket, key, callback=None, extra_args=None + ): """Upload a file to an S3 object. Variants have also been injected into S3 client, Bucket and Object. @@ -735,20 +735,20 @@ class S3Transfer: extra_args = {} self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS) events = self._client.meta.events - events.register_first( - 'request-created.s3', - disable_upload_callbacks, - unique_id='s3upload-callback-disable', - ) - events.register_last( - 'request-created.s3', - enable_upload_callbacks, - unique_id='s3upload-callback-enable', - ) - if ( - self._osutil.get_file_size(filename) - >= self._config.multipart_threshold - ): + events.register_first( + 'request-created.s3', + disable_upload_callbacks, + unique_id='s3upload-callback-disable', + ) + events.register_last( + 'request-created.s3', + enable_upload_callbacks, + unique_id='s3upload-callback-enable', + ) + if ( + self._osutil.get_file_size(filename) + >= self._config.multipart_threshold + ): self._multipart_upload(filename, bucket, key, callback, extra_args) else: self._put_object(filename, bucket, key, callback, extra_args) @@ -757,19 +757,19 @@ class S3Transfer: # We're using open_file_chunk_reader so we can take advantage of the # progress callback functionality. open_chunk_reader = self._osutil.open_file_chunk_reader - with open_chunk_reader( - filename, - 0, - self._osutil.get_file_size(filename), - callback=callback, - ) as body: - self._client.put_object( - Bucket=bucket, Key=key, Body=body, **extra_args - ) - - def download_file( - self, bucket, key, filename, extra_args=None, callback=None - ): + with open_chunk_reader( + filename, + 0, + self._osutil.get_file_size(filename), + callback=callback, + ) as body: + self._client.put_object( + Bucket=bucket, Key=key, Body=body, **extra_args + ) + + def download_file( + self, bucket, key, filename, extra_args=None, callback=None + ): """Download an S3 object to a file. Variants have also been injected into S3 client, Bucket and Object. 
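At the call-site level, the upload_file and download_file entry points restored here are used roughly as in the sketch below. It assumes boto3 is available (any botocore S3 client works), that S3Transfer and TransferConfig are importable from the package root as in this legacy interface, and that the progress callback receives the number of bytes moved since its last invocation; the bucket, key, and file names are placeholders.

# Hypothetical usage sketch of the legacy S3Transfer interface.
import boto3
import s3transfer

client = boto3.client('s3')
config = s3transfer.TransferConfig(
    multipart_threshold=8 * 1024 * 1024,  # same default as the code above
    max_concurrency=10,
)
transfer = s3transfer.S3Transfer(client, config)

def progress(bytes_transferred):
    # Assumed callback shape: called with the bytes moved per read/write.
    print(f'transferred {bytes_transferred} bytes')

transfer.upload_file('local-file.txt', 'my-bucket', 'remote-key.txt',
                     callback=progress,
                     extra_args={'Tagging': 'project=demo'})
transfer.download_file('my-bucket', 'remote-key.txt', 'copy-of-file.txt',
                       callback=progress)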
@@ -784,28 +784,28 @@ class S3Transfer: object_size = self._object_size(bucket, key, extra_args) temp_filename = filename + os.extsep + random_file_extension() try: - self._download_file( - bucket, key, temp_filename, object_size, extra_args, callback - ) + self._download_file( + bucket, key, temp_filename, object_size, extra_args, callback + ) except Exception: - logger.debug( - "Exception caught in download_file, removing partial " - "file: %s", - temp_filename, - exc_info=True, - ) + logger.debug( + "Exception caught in download_file, removing partial " + "file: %s", + temp_filename, + exc_info=True, + ) self._osutil.remove_file(temp_filename) raise else: self._osutil.rename_file(temp_filename, filename) - def _download_file( - self, bucket, key, filename, object_size, extra_args, callback - ): + def _download_file( + self, bucket, key, filename, object_size, extra_args, callback + ): if object_size >= self._config.multipart_threshold: - self._ranged_download( - bucket, key, filename, object_size, extra_args, callback - ) + self._ranged_download( + bucket, key, filename, object_size, extra_args, callback + ) else: self._get_object(bucket, key, filename, extra_args, callback) @@ -814,18 +814,18 @@ class S3Transfer: if kwarg not in allowed: raise ValueError( "Invalid extra_args key '%s', " - "must be one of: %s" % (kwarg, ', '.join(allowed)) - ) - - def _ranged_download( - self, bucket, key, filename, object_size, extra_args, callback - ): - downloader = MultipartDownloader( - self._client, self._config, self._osutil - ) - downloader.download_file( - bucket, key, filename, object_size, extra_args, callback - ) + "must be one of: %s" % (kwarg, ', '.join(allowed)) + ) + + def _ranged_download( + self, bucket, key, filename, object_size, extra_args, callback + ): + downloader = MultipartDownloader( + self._client, self._config, self._osutil + ) + downloader.download_file( + bucket, key, filename, object_size, extra_args, callback + ) def _get_object(self, bucket, key, filename, extra_args, callback): # precondition: num_download_attempts > 0 @@ -833,42 +833,42 @@ class S3Transfer: last_exception = None for i in range(max_attempts): try: - return self._do_get_object( - bucket, key, filename, extra_args, callback - ) - except ( - socket.timeout, - OSError, - ReadTimeoutError, - IncompleteReadError, - ) as e: + return self._do_get_object( + bucket, key, filename, extra_args, callback + ) + except ( + socket.timeout, + OSError, + ReadTimeoutError, + IncompleteReadError, + ) as e: # TODO: we need a way to reset the callback if the # download failed. 
- logger.debug( - "Retrying exception caught (%s), " - "retrying request, (attempt %s / %s)", - e, - i, - max_attempts, - exc_info=True, - ) + logger.debug( + "Retrying exception caught (%s), " + "retrying request, (attempt %s / %s)", + e, + i, + max_attempts, + exc_info=True, + ) last_exception = e continue raise RetriesExceededError(last_exception) def _do_get_object(self, bucket, key, filename, extra_args, callback): - response = self._client.get_object( - Bucket=bucket, Key=key, **extra_args - ) - streaming_body = StreamReaderProgress(response['Body'], callback) + response = self._client.get_object( + Bucket=bucket, Key=key, **extra_args + ) + streaming_body = StreamReaderProgress(response['Body'], callback) with self._osutil.open(filename, 'wb') as f: for chunk in iter(lambda: streaming_body.read(8192), b''): f.write(chunk) def _object_size(self, bucket, key, extra_args): - return self._client.head_object(Bucket=bucket, Key=key, **extra_args)[ - 'ContentLength' - ] + return self._client.head_object(Bucket=bucket, Key=key, **extra_args)[ + 'ContentLength' + ] def _multipart_upload(self, filename, bucket, key, callback, extra_args): uploader = MultipartUploader(self._client, self._config, self._osutil) diff --git a/contrib/python/s3transfer/py3/s3transfer/bandwidth.py b/contrib/python/s3transfer/py3/s3transfer/bandwidth.py index 957049dffc..9bac5885e1 100644 --- a/contrib/python/s3transfer/py3/s3transfer/bandwidth.py +++ b/contrib/python/s3transfer/py3/s3transfer/bandwidth.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -import threading +import threading import time @@ -30,19 +30,19 @@ class RequestExceededException(Exception): """ self.requested_amt = requested_amt self.retry_time = retry_time - msg = 'Request amount {} exceeded the amount available. Retry in {}'.format( - requested_amt, retry_time + msg = 'Request amount {} exceeded the amount available. 
Retry in {}'.format( + requested_amt, retry_time ) - super().__init__(msg) + super().__init__(msg) -class RequestToken: +class RequestToken: """A token to pass as an identifier when consuming from the LeakyBucket""" - + pass -class TimeUtils: +class TimeUtils: def time(self): """Get the current time back @@ -60,7 +60,7 @@ class TimeUtils: return time.sleep(value) -class BandwidthLimiter: +class BandwidthLimiter: def __init__(self, leaky_bucket, time_utils=None): """Limits bandwidth for shared S3 transfers @@ -75,9 +75,9 @@ class BandwidthLimiter: if time_utils is None: self._time_utils = TimeUtils() - def get_bandwith_limited_stream( - self, fileobj, transfer_coordinator, enabled=True - ): + def get_bandwith_limited_stream( + self, fileobj, transfer_coordinator, enabled=True + ): """Wraps a fileobj in a bandwidth limited stream wrapper :type fileobj: file-like obj @@ -91,22 +91,22 @@ class BandwidthLimiter: :param enabled: Whether bandwidth limiting should be enabled to start """ stream = BandwidthLimitedStream( - fileobj, self._leaky_bucket, transfer_coordinator, self._time_utils - ) + fileobj, self._leaky_bucket, transfer_coordinator, self._time_utils + ) if not enabled: stream.disable_bandwidth_limiting() return stream -class BandwidthLimitedStream: - def __init__( - self, - fileobj, - leaky_bucket, - transfer_coordinator, - time_utils=None, - bytes_threshold=256 * 1024, - ): +class BandwidthLimitedStream: + def __init__( + self, + fileobj, + leaky_bucket, + transfer_coordinator, + time_utils=None, + bytes_threshold=256 * 1024, + ): """Limits bandwidth for reads on a wrapped stream :type fileobj: file-like object @@ -163,7 +163,7 @@ class BandwidthLimitedStream: return self._fileobj.read(amount) def _consume_through_leaky_bucket(self): - # NOTE: If the read amount on the stream are high, it will result + # NOTE: If the read amount on the stream are high, it will result # in large bursty behavior as there is not an interface for partial # reads. 
However given the read's on this abstraction are at most 256KB # (via downloads), it reduces the burstiness to be small KB bursts at @@ -171,8 +171,8 @@ class BandwidthLimitedStream: while not self._transfer_coordinator.exception: try: self._leaky_bucket.consume( - self._bytes_seen, self._request_token - ) + self._bytes_seen, self._request_token + ) self._bytes_seen = 0 return except RequestExceededException as e: @@ -188,8 +188,8 @@ class BandwidthLimitedStream: """Signal that data being read is not being transferred to S3""" self.disable_bandwidth_limiting() - def seek(self, where, whence=0): - self._fileobj.seek(where, whence) + def seek(self, where, whence=0): + self._fileobj.seek(where, whence) def tell(self): return self._fileobj.tell() @@ -211,14 +211,14 @@ class BandwidthLimitedStream: self.close() -class LeakyBucket: - def __init__( - self, - max_rate, - time_utils=None, - rate_tracker=None, - consumption_scheduler=None, - ): +class LeakyBucket: + def __init__( + self, + max_rate, + time_utils=None, + rate_tracker=None, + consumption_scheduler=None, + ): """A leaky bucket abstraction to limit bandwidth consumption :type rate: int @@ -269,12 +269,12 @@ class LeakyBucket: time_now = self._time_utils.time() if self._consumption_scheduler.is_scheduled(request_token): return self._release_requested_amt_for_scheduled_request( - amt, request_token, time_now - ) + amt, request_token, time_now + ) elif self._projected_to_exceed_max_rate(amt, time_now): self._raise_request_exceeded_exception( - amt, request_token, time_now - ) + amt, request_token, time_now + ) else: return self._release_requested_amt(amt, time_now) @@ -282,29 +282,29 @@ class LeakyBucket: projected_rate = self._rate_tracker.get_projected_rate(amt, time_now) return projected_rate > self._max_rate - def _release_requested_amt_for_scheduled_request( - self, amt, request_token, time_now - ): + def _release_requested_amt_for_scheduled_request( + self, amt, request_token, time_now + ): self._consumption_scheduler.process_scheduled_consumption( - request_token - ) + request_token + ) return self._release_requested_amt(amt, time_now) def _raise_request_exceeded_exception(self, amt, request_token, time_now): - allocated_time = amt / float(self._max_rate) + allocated_time = amt / float(self._max_rate) retry_time = self._consumption_scheduler.schedule_consumption( - amt, request_token, allocated_time - ) + amt, request_token, allocated_time + ) raise RequestExceededException( - requested_amt=amt, retry_time=retry_time - ) + requested_amt=amt, retry_time=retry_time + ) def _release_requested_amt(self, amt, time_now): self._rate_tracker.record_consumption_rate(amt, time_now) return amt -class ConsumptionScheduler: +class ConsumptionScheduler: def __init__(self): """Schedules when to consume a desired amount""" self._tokens_to_scheduled_consumption = {} @@ -354,11 +354,11 @@ class ConsumptionScheduler: """ scheduled_retry = self._tokens_to_scheduled_consumption.pop(token) self._total_wait = max( - self._total_wait - scheduled_retry['time_to_consume'], 0 - ) + self._total_wait - scheduled_retry['time_to_consume'], 0 + ) -class BandwidthRateTracker: +class BandwidthRateTracker: def __init__(self, alpha=0.8): """Tracks the rate of bandwidth consumption @@ -401,8 +401,8 @@ class BandwidthRateTracker: if self._last_time is None: return 0.0 return self._calculate_exponential_moving_average_rate( - amt, time_at_consumption - ) + amt, time_at_consumption + ) def record_consumption_rate(self, amt, time_at_consumption): """Record the consumption 
rate based off amount and time point @@ -418,22 +418,22 @@ class BandwidthRateTracker: self._current_rate = 0.0 return self._current_rate = self._calculate_exponential_moving_average_rate( - amt, time_at_consumption - ) + amt, time_at_consumption + ) self._last_time = time_at_consumption def _calculate_rate(self, amt, time_at_consumption): time_delta = time_at_consumption - self._last_time if time_delta <= 0: - # While it is really unlikely to see this in an actual transfer, + # While it is really unlikely to see this in an actual transfer, # we do not want to be returning back a negative rate or try to # divide the amount by zero. So instead return back an infinite # rate as the time delta is infinitesimally small. return float('inf') return amt / (time_delta) - def _calculate_exponential_moving_average_rate( - self, amt, time_at_consumption - ): + def _calculate_exponential_moving_average_rate( + self, amt, time_at_consumption + ): new_rate = self._calculate_rate(amt, time_at_consumption) return self._alpha * new_rate + (1 - self._alpha) * self._current_rate diff --git a/contrib/python/s3transfer/py3/s3transfer/compat.py b/contrib/python/s3transfer/py3/s3transfer/compat.py index 5746025983..68267ad0e2 100644 --- a/contrib/python/s3transfer/py3/s3transfer/compat.py +++ b/contrib/python/s3transfer/py3/s3transfer/compat.py @@ -10,11 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -import errno +import errno import inspect import os import socket -import sys +import sys from botocore.compat import six @@ -34,18 +34,18 @@ else: rename_file = os.rename -def accepts_kwargs(func): - return inspect.getfullargspec(func)[2] +def accepts_kwargs(func): + return inspect.getfullargspec(func)[2] -# In python 3, socket.error is OSError, which is too general -# for what we want (i.e FileNotFoundError is a subclass of OSError). -# In python 3, all the socket related errors are in a newly created -# ConnectionError. -SOCKET_ERROR = ConnectionError -MAXINT = None +# In python 3, socket.error is OSError, which is too general +# for what we want (i.e FileNotFoundError is a subclass of OSError). +# In python 3, all the socket related errors are in a newly created +# ConnectionError. +SOCKET_ERROR = ConnectionError +MAXINT = None + - def seekable(fileobj): """Backwards compat function to determine if a fileobj is seekable @@ -63,7 +63,7 @@ def seekable(fileobj): try: fileobj.seek(0, 1) return True - except OSError: + except OSError: # If an io related error was thrown then it is not seekable. 
return False # Else, the fileobj is not seekable @@ -81,14 +81,14 @@ def readable(fileobj): return fileobj.readable() return hasattr(fileobj, 'read') - - -def fallocate(fileobj, size): - if hasattr(os, 'posix_fallocate'): - os.posix_fallocate(fileobj.fileno(), 0, size) - else: - fileobj.truncate(size) - - -# Import at end of file to avoid circular dependencies -from multiprocessing.managers import BaseManager # noqa: F401,E402 + + +def fallocate(fileobj, size): + if hasattr(os, 'posix_fallocate'): + os.posix_fallocate(fileobj.fileno(), 0, size) + else: + fileobj.truncate(size) + + +# Import at end of file to avoid circular dependencies +from multiprocessing.managers import BaseManager # noqa: F401,E402 diff --git a/contrib/python/s3transfer/py3/s3transfer/constants.py b/contrib/python/s3transfer/py3/s3transfer/constants.py index 29da842982..ba35bc72e9 100644 --- a/contrib/python/s3transfer/py3/s3transfer/constants.py +++ b/contrib/python/s3transfer/py3/s3transfer/constants.py @@ -1,29 +1,29 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -import s3transfer - -KB = 1024 -MB = KB * KB -GB = MB * KB - -ALLOWED_DOWNLOAD_ARGS = [ - 'VersionId', - 'SSECustomerAlgorithm', - 'SSECustomerKey', - 'SSECustomerKeyMD5', - 'RequestPayer', - 'ExpectedBucketOwner', -] - -USER_AGENT = 's3transfer/%s' % s3transfer.__version__ -PROCESS_USER_AGENT = '%s processpool' % USER_AGENT +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
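The size constants restored in this file are the units the package's chunking math is expressed in. As a quick worked example of that math (the ceil-based part count used by MultipartUploader and MultipartDownloader above), here is how the default 8 MiB chunk size splits a 100 MiB object; the object size is an arbitrary example.

import math

KB = 1024
MB = KB * KB            # 1,048,576 bytes, as defined below

part_size = 8 * MB      # default multipart_chunksize in TransferConfig
object_size = 100 * MB  # arbitrary example size

num_parts = int(math.ceil(object_size / float(part_size)))
print(num_parts)                                   # 13: twelve full parts plus a tail
print(object_size - (num_parts - 1) * part_size)   # 4194304 bytes (4 MiB) in the last part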
+import s3transfer + +KB = 1024 +MB = KB * KB +GB = MB * KB + +ALLOWED_DOWNLOAD_ARGS = [ + 'VersionId', + 'SSECustomerAlgorithm', + 'SSECustomerKey', + 'SSECustomerKeyMD5', + 'RequestPayer', + 'ExpectedBucketOwner', +] + +USER_AGENT = 's3transfer/%s' % s3transfer.__version__ +PROCESS_USER_AGENT = '%s processpool' % USER_AGENT diff --git a/contrib/python/s3transfer/py3/s3transfer/copies.py b/contrib/python/s3transfer/py3/s3transfer/copies.py index 8daa710281..a1dfdc8ba3 100644 --- a/contrib/python/s3transfer/py3/s3transfer/copies.py +++ b/contrib/python/s3transfer/py3/s3transfer/copies.py @@ -13,18 +13,18 @@ import copy import math -from s3transfer.tasks import ( - CompleteMultipartUploadTask, - CreateMultipartUploadTask, - SubmissionTask, - Task, -) -from s3transfer.utils import ( - ChunksizeAdjuster, - calculate_range_parameter, - get_callbacks, - get_filtered_dict, -) +from s3transfer.tasks import ( + CompleteMultipartUploadTask, + CreateMultipartUploadTask, + SubmissionTask, + Task, +) +from s3transfer.utils import ( + ChunksizeAdjuster, + calculate_range_parameter, + get_callbacks, + get_filtered_dict, +) class CopySubmissionTask(SubmissionTask): @@ -38,8 +38,8 @@ class CopySubmissionTask(SubmissionTask): 'CopySourceSSECustomerKey': 'SSECustomerKey', 'CopySourceSSECustomerAlgorithm': 'SSECustomerAlgorithm', 'CopySourceSSECustomerKeyMD5': 'SSECustomerKeyMD5', - 'RequestPayer': 'RequestPayer', - 'ExpectedBucketOwner': 'ExpectedBucketOwner', + 'RequestPayer': 'RequestPayer', + 'ExpectedBucketOwner': 'ExpectedBucketOwner', } UPLOAD_PART_COPY_ARGS = [ @@ -54,7 +54,7 @@ class CopySubmissionTask(SubmissionTask): 'SSECustomerAlgorithm', 'SSECustomerKeyMD5', 'RequestPayer', - 'ExpectedBucketOwner', + 'ExpectedBucketOwner', ] CREATE_MULTIPART_ARGS_BLACKLIST = [ @@ -65,15 +65,15 @@ class CopySubmissionTask(SubmissionTask): 'CopySourceSSECustomerKey', 'CopySourceSSECustomerAlgorithm', 'CopySourceSSECustomerKeyMD5', - 'MetadataDirective', - 'TaggingDirective', + 'MetadataDirective', + 'TaggingDirective', ] - COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner'] + COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner'] - def _submit( - self, client, config, osutil, request_executor, transfer_future - ): + def _submit( + self, client, config, osutil, request_executor, transfer_future + ): """ :param client: The client associated with the transfer manager @@ -100,11 +100,11 @@ class CopySubmissionTask(SubmissionTask): # of the client, they may have to provide the file size themselves # with a completely new client. call_args = transfer_future.meta.call_args - head_object_request = ( + head_object_request = ( self._get_head_object_request_from_copy_source( - call_args.copy_source - ) - ) + call_args.copy_source + ) + ) extra_args = call_args.extra_args # Map any values that may be used in the head object that is @@ -112,30 +112,30 @@ class CopySubmissionTask(SubmissionTask): for param, value in extra_args.items(): if param in self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING: head_object_request[ - self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING[param] - ] = value + self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING[param] + ] = value response = call_args.source_client.head_object( - **head_object_request - ) + **head_object_request + ) transfer_future.meta.provide_transfer_size( - response['ContentLength'] - ) + response['ContentLength'] + ) # If it is greater than threshold do a multipart copy, otherwise # do a regular copy object. 
if transfer_future.meta.size < config.multipart_threshold: self._submit_copy_request( - client, config, osutil, request_executor, transfer_future - ) + client, config, osutil, request_executor, transfer_future + ) else: self._submit_multipart_request( - client, config, osutil, request_executor, transfer_future - ) + client, config, osutil, request_executor, transfer_future + ) - def _submit_copy_request( - self, client, config, osutil, request_executor, transfer_future - ): + def _submit_copy_request( + self, client, config, osutil, request_executor, transfer_future + ): call_args = transfer_future.meta.call_args # Get the needed progress callbacks for the task @@ -153,15 +153,15 @@ class CopySubmissionTask(SubmissionTask): 'key': call_args.key, 'extra_args': call_args.extra_args, 'callbacks': progress_callbacks, - 'size': transfer_future.meta.size, + 'size': transfer_future.meta.size, }, - is_final=True, - ), + is_final=True, + ), ) - def _submit_multipart_request( - self, client, config, osutil, request_executor, transfer_future - ): + def _submit_multipart_request( + self, client, config, osutil, request_executor, transfer_future + ): call_args = transfer_future.meta.call_args # Submit the request to create a multipart upload and make sure it @@ -180,8 +180,8 @@ class CopySubmissionTask(SubmissionTask): 'bucket': call_args.bucket, 'key': call_args.key, 'extra_args': create_multipart_extra_args, - }, - ), + }, + ), ) # Determine how many parts are needed based on filesize and @@ -189,11 +189,11 @@ class CopySubmissionTask(SubmissionTask): part_size = config.multipart_chunksize adjuster = ChunksizeAdjuster() part_size = adjuster.adjust_chunksize( - part_size, transfer_future.meta.size - ) + part_size, transfer_future.meta.size + ) num_parts = int( - math.ceil(transfer_future.meta.size / float(part_size)) - ) + math.ceil(transfer_future.meta.size / float(part_size)) + ) # Submit requests to upload the parts of the file. part_futures = [] @@ -201,24 +201,24 @@ class CopySubmissionTask(SubmissionTask): for part_number in range(1, num_parts + 1): extra_part_args = self._extra_upload_part_args( - call_args.extra_args - ) + call_args.extra_args + ) # The part number for upload part starts at 1 while the # range parameter starts at zero, so just subtract 1 off of # the part number extra_part_args['CopySourceRange'] = calculate_range_parameter( - part_size, - part_number - 1, - num_parts, - transfer_future.meta.size, - ) + part_size, + part_number - 1, + num_parts, + transfer_future.meta.size, + ) # Get the size of the part copy as well for the progress # callbacks. size = self._get_transfer_size( - part_size, - part_number - 1, - num_parts, - transfer_future.meta.size, + part_size, + part_number - 1, + num_parts, + transfer_future.meta.size, ) part_futures.append( self._transfer_coordinator.submit( @@ -233,18 +233,18 @@ class CopySubmissionTask(SubmissionTask): 'part_number': part_number, 'extra_args': extra_part_args, 'callbacks': progress_callbacks, - 'size': size, + 'size': size, }, pending_main_kwargs={ 'upload_id': create_multipart_future - }, - ), + }, + ), ) ) complete_multipart_extra_args = self._extra_complete_multipart_args( - call_args.extra_args - ) + call_args.extra_args + ) # Submit the request to complete the multipart upload. 
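Before that final complete_multipart_upload submission, each upload_part_copy request above is given a CopySourceRange computed from its part index via s3transfer.utils.calculate_range_parameter (imported at the top of this file). The sketch below is a hedged reimplementation of the kind of value that helper produces, assuming the usual 'bytes=first-last' form with the final part clamped to the object size; it is illustrative only, not the library's helper.

def sketch_range_parameter(part_size, part_index, num_parts, total_size):
    # Illustrative reimplementation; assumes inclusive byte ranges.
    start = part_index * part_size
    if part_index == num_parts - 1:
        end = total_size - 1           # last part is clamped to the object size
    else:
        end = start + part_size - 1
    return f'bytes={start}-{end}'

# 20 MiB object copied in 8 MiB parts -> 3 parts
print(sketch_range_parameter(8 * 1024 * 1024, 0, 3, 20 * 1024 * 1024))
# bytes=0-8388607
print(sketch_range_parameter(8 * 1024 * 1024, 2, 3, 20 * 1024 * 1024))
# bytes=16777216-20971519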
self._transfer_coordinator.submit( request_executor, @@ -258,10 +258,10 @@ class CopySubmissionTask(SubmissionTask): }, pending_main_kwargs={ 'upload_id': create_multipart_future, - 'parts': part_futures, + 'parts': part_futures, }, - is_final=True, - ), + is_final=True, + ), ) def _get_head_object_request_from_copy_source(self, copy_source): @@ -271,7 +271,7 @@ class CopySubmissionTask(SubmissionTask): raise TypeError( 'Expecting dictionary formatted: ' '{"Bucket": bucket_name, "Key": key} ' - 'but got %s or type %s.' % (copy_source, type(copy_source)) + 'but got %s or type %s.' % (copy_source, type(copy_source)) ) def _extra_upload_part_args(self, extra_args): @@ -282,9 +282,9 @@ class CopySubmissionTask(SubmissionTask): def _extra_complete_multipart_args(self, extra_args): return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS) - def _get_transfer_size( - self, part_size, part_index, num_parts, total_transfer_size - ): + def _get_transfer_size( + self, part_size, part_index, num_parts, total_transfer_size + ): if part_index == num_parts - 1: # The last part may be different in size then the rest of the # parts. @@ -294,10 +294,10 @@ class CopySubmissionTask(SubmissionTask): class CopyObjectTask(Task): """Task to do a nonmultipart copy""" - - def _main( - self, client, copy_source, bucket, key, extra_args, callbacks, size - ): + + def _main( + self, client, copy_source, bucket, key, extra_args, callbacks, size + ): """ :param client: The client to use when calling PutObject :param copy_source: The CopySource parameter to use @@ -311,27 +311,27 @@ class CopyObjectTask(Task): """ client.copy_object( - CopySource=copy_source, Bucket=bucket, Key=key, **extra_args - ) + CopySource=copy_source, Bucket=bucket, Key=key, **extra_args + ) for callback in callbacks: callback(bytes_transferred=size) class CopyPartTask(Task): """Task to upload a part in a multipart copy""" - - def _main( - self, - client, - copy_source, - bucket, - key, - upload_id, - part_number, - extra_args, - callbacks, - size, - ): + + def _main( + self, + client, + copy_source, + bucket, + key, + upload_id, + part_number, + extra_args, + callbacks, + size, + ): """ :param client: The client to use when calling PutObject :param copy_source: The CopySource parameter to use @@ -355,13 +355,13 @@ class CopyPartTask(Task): the multipart upload. """ response = client.upload_part_copy( - CopySource=copy_source, - Bucket=bucket, - Key=key, - UploadId=upload_id, - PartNumber=part_number, - **extra_args - ) + CopySource=copy_source, + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=part_number, + **extra_args + ) for callback in callbacks: callback(bytes_transferred=size) etag = response['CopyPartResult']['ETag'] diff --git a/contrib/python/s3transfer/py3/s3transfer/crt.py b/contrib/python/s3transfer/py3/s3transfer/crt.py index b1d573c1b5..7b5d130136 100644 --- a/contrib/python/s3transfer/py3/s3transfer/crt.py +++ b/contrib/python/s3transfer/py3/s3transfer/crt.py @@ -1,644 +1,644 @@ -# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. 
See the License for the specific -# language governing permissions and limitations under the License. -import logging -import threading -from io import BytesIO - -import awscrt.http -import botocore.awsrequest -import botocore.session -from awscrt.auth import AwsCredentials, AwsCredentialsProvider -from awscrt.io import ( - ClientBootstrap, - ClientTlsContext, - DefaultHostResolver, - EventLoopGroup, - TlsContextOptions, -) -from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType -from botocore import UNSIGNED -from botocore.compat import urlsplit -from botocore.config import Config -from botocore.exceptions import NoCredentialsError - -from s3transfer.constants import GB, MB -from s3transfer.exceptions import TransferNotDoneError -from s3transfer.futures import BaseTransferFuture, BaseTransferMeta -from s3transfer.utils import CallArgs, OSUtils, get_callbacks - -logger = logging.getLogger(__name__) - - -class CRTCredentialProviderAdapter: - def __init__(self, botocore_credential_provider): - self._botocore_credential_provider = botocore_credential_provider - self._loaded_credentials = None - self._lock = threading.Lock() - - def __call__(self): - credentials = self._get_credentials().get_frozen_credentials() - return AwsCredentials( - credentials.access_key, credentials.secret_key, credentials.token - ) - - def _get_credentials(self): - with self._lock: - if self._loaded_credentials is None: - loaded_creds = ( - self._botocore_credential_provider.load_credentials() - ) - if loaded_creds is None: - raise NoCredentialsError() - self._loaded_credentials = loaded_creds - return self._loaded_credentials - - -def create_s3_crt_client( - region, - botocore_credential_provider=None, - num_threads=None, - target_throughput=5 * GB / 8, - part_size=8 * MB, - use_ssl=True, - verify=None, -): - """ - :type region: str - :param region: The region used for signing - - :type botocore_credential_provider: - Optional[botocore.credentials.CredentialResolver] - :param botocore_credential_provider: Provide credentials for CRT - to sign the request if not set, the request will not be signed - - :type num_threads: Optional[int] - :param num_threads: Number of worker threads generated. Default - is the number of processors in the machine. - - :type target_throughput: Optional[int] - :param target_throughput: Throughput target in Bytes. - Default is 0.625 GB/s (which translates to 5 Gb/s). - - :type part_size: Optional[int] - :param part_size: Size, in Bytes, of parts that files will be downloaded - or uploaded in. - - :type use_ssl: boolean - :param use_ssl: Whether or not to use SSL. By default, SSL is used. - Note that not all services support non-ssl connections. - - :type verify: Optional[boolean/string] - :param verify: Whether or not to verify SSL certificates. - By default SSL certificates are verified. You can provide the - following values: - - * False - do not validate SSL certificates. SSL will still be - used (unless use_ssl is False), but SSL certificates - will not be verified. - * path/to/cert/bundle.pem - A filename of the CA cert bundle to - use. Specify this argument if you want to use a custom CA cert - bundle instead of the default one on your system. 
- """ - - event_loop_group = EventLoopGroup(num_threads) - host_resolver = DefaultHostResolver(event_loop_group) - bootstrap = ClientBootstrap(event_loop_group, host_resolver) - provider = None - tls_connection_options = None - - tls_mode = ( - S3RequestTlsMode.ENABLED if use_ssl else S3RequestTlsMode.DISABLED - ) - if verify is not None: - tls_ctx_options = TlsContextOptions() - if verify: - tls_ctx_options.override_default_trust_store_from_path( - ca_filepath=verify - ) - else: - tls_ctx_options.verify_peer = False - client_tls_option = ClientTlsContext(tls_ctx_options) - tls_connection_options = client_tls_option.new_connection_options() - if botocore_credential_provider: - credentails_provider_adapter = CRTCredentialProviderAdapter( - botocore_credential_provider - ) - provider = AwsCredentialsProvider.new_delegate( - credentails_provider_adapter - ) - - target_gbps = target_throughput * 8 / GB - return S3Client( - bootstrap=bootstrap, - region=region, - credential_provider=provider, - part_size=part_size, - tls_mode=tls_mode, - tls_connection_options=tls_connection_options, - throughput_target_gbps=target_gbps, - ) - - -class CRTTransferManager: - def __init__(self, crt_s3_client, crt_request_serializer, osutil=None): - """A transfer manager interface for Amazon S3 on CRT s3 client. - - :type crt_s3_client: awscrt.s3.S3Client - :param crt_s3_client: The CRT s3 client, handling all the - HTTP requests and functions under then hood - - :type crt_request_serializer: s3transfer.crt.BaseCRTRequestSerializer - :param crt_request_serializer: Serializer, generates unsigned crt HTTP - request. - - :type osutil: s3transfer.utils.OSUtils - :param osutil: OSUtils object to use for os-related behavior when - using with transfer manager. - """ - if osutil is None: - self._osutil = OSUtils() - self._crt_s3_client = crt_s3_client - self._s3_args_creator = S3ClientArgsCreator( - crt_request_serializer, self._osutil - ) - self._future_coordinators = [] - self._semaphore = threading.Semaphore(128) # not configurable - # A counter to create unique id's for each transfer submitted. 
- self._id_counter = 0 - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, *args): - cancel = False - if exc_type: - cancel = True - self._shutdown(cancel) - - def download( - self, bucket, key, fileobj, extra_args=None, subscribers=None - ): - if extra_args is None: - extra_args = {} - if subscribers is None: - subscribers = {} - callargs = CallArgs( - bucket=bucket, - key=key, - fileobj=fileobj, - extra_args=extra_args, - subscribers=subscribers, - ) - return self._submit_transfer("get_object", callargs) - - def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): - if extra_args is None: - extra_args = {} - if subscribers is None: - subscribers = {} - callargs = CallArgs( - bucket=bucket, - key=key, - fileobj=fileobj, - extra_args=extra_args, - subscribers=subscribers, - ) - return self._submit_transfer("put_object", callargs) - - def delete(self, bucket, key, extra_args=None, subscribers=None): - if extra_args is None: - extra_args = {} - if subscribers is None: - subscribers = {} - callargs = CallArgs( - bucket=bucket, - key=key, - extra_args=extra_args, - subscribers=subscribers, - ) - return self._submit_transfer("delete_object", callargs) - - def shutdown(self, cancel=False): - self._shutdown(cancel) - - def _cancel_transfers(self): - for coordinator in self._future_coordinators: - if not coordinator.done(): - coordinator.cancel() - - def _finish_transfers(self): - for coordinator in self._future_coordinators: - coordinator.result() - - def _wait_transfers_done(self): - for coordinator in self._future_coordinators: - coordinator.wait_until_on_done_callbacks_complete() - - def _shutdown(self, cancel=False): - if cancel: - self._cancel_transfers() - try: - self._finish_transfers() - - except KeyboardInterrupt: - self._cancel_transfers() - except Exception: - pass - finally: - self._wait_transfers_done() - - def _release_semaphore(self, **kwargs): - self._semaphore.release() - - def _submit_transfer(self, request_type, call_args): - on_done_after_calls = [self._release_semaphore] - coordinator = CRTTransferCoordinator(transfer_id=self._id_counter) - components = { - 'meta': CRTTransferMeta(self._id_counter, call_args), - 'coordinator': coordinator, - } - future = CRTTransferFuture(**components) - afterdone = AfterDoneHandler(coordinator) - on_done_after_calls.append(afterdone) - - try: - self._semaphore.acquire() - on_queued = self._s3_args_creator.get_crt_callback( - future, 'queued' - ) - on_queued() - crt_callargs = self._s3_args_creator.get_make_request_args( - request_type, - call_args, - coordinator, - future, - on_done_after_calls, - ) - crt_s3_request = self._crt_s3_client.make_request(**crt_callargs) - except Exception as e: - coordinator.set_exception(e, True) - on_done = self._s3_args_creator.get_crt_callback( - future, 'done', after_subscribers=on_done_after_calls - ) - on_done(error=e) - else: - coordinator.set_s3_request(crt_s3_request) - self._future_coordinators.append(coordinator) - - self._id_counter += 1 - return future - - -class CRTTransferMeta(BaseTransferMeta): - """Holds metadata about the CRTTransferFuture""" - - def __init__(self, transfer_id=None, call_args=None): - self._transfer_id = transfer_id - self._call_args = call_args - self._user_context = {} - - @property - def call_args(self): - return self._call_args - - @property - def transfer_id(self): - return self._transfer_id - - @property - def user_context(self): - return self._user_context - - -class CRTTransferFuture(BaseTransferFuture): - def 
__init__(self, meta=None, coordinator=None): - """The future associated to a submitted transfer request via CRT S3 client - - :type meta: s3transfer.crt.CRTTransferMeta - :param meta: The metadata associated to the transfer future. - - :type coordinator: s3transfer.crt.CRTTransferCoordinator - :param coordinator: The coordinator associated to the transfer future. - """ - self._meta = meta - if meta is None: - self._meta = CRTTransferMeta() - self._coordinator = coordinator - - @property - def meta(self): - return self._meta - - def done(self): - return self._coordinator.done() - - def result(self, timeout=None): - self._coordinator.result(timeout) - - def cancel(self): - self._coordinator.cancel() - - def set_exception(self, exception): - """Sets the exception on the future.""" - if not self.done(): - raise TransferNotDoneError( - 'set_exception can only be called once the transfer is ' - 'complete.' - ) - self._coordinator.set_exception(exception, override=True) - - -class BaseCRTRequestSerializer: - def serialize_http_request(self, transfer_type, future): - """Serialize CRT HTTP requests. - - :type transfer_type: string - :param transfer_type: the type of transfer made, - e.g 'put_object', 'get_object', 'delete_object' - - :type future: s3transfer.crt.CRTTransferFuture - - :rtype: awscrt.http.HttpRequest - :returns: An unsigned HTTP request to be used for the CRT S3 client - """ - raise NotImplementedError('serialize_http_request()') - - -class BotocoreCRTRequestSerializer(BaseCRTRequestSerializer): - def __init__(self, session, client_kwargs=None): - """Serialize CRT HTTP request using botocore logic - It also takes into account configuration from both the session - and any keyword arguments that could be passed to - `Session.create_client()` when serializing the request. - - :type session: botocore.session.Session - - :type client_kwargs: Optional[Dict[str, str]]) - :param client_kwargs: The kwargs for the botocore - s3 client initialization. - """ - self._session = session - if client_kwargs is None: - client_kwargs = {} - self._resolve_client_config(session, client_kwargs) - self._client = session.create_client(**client_kwargs) - self._client.meta.events.register( - 'request-created.s3.*', self._capture_http_request - ) - self._client.meta.events.register( - 'after-call.s3.*', self._change_response_to_serialized_http_request - ) - self._client.meta.events.register( - 'before-send.s3.*', self._make_fake_http_response - ) - - def _resolve_client_config(self, session, client_kwargs): - user_provided_config = None - if session.get_default_client_config(): - user_provided_config = session.get_default_client_config() - if 'config' in client_kwargs: - user_provided_config = client_kwargs['config'] - - client_config = Config(signature_version=UNSIGNED) - if user_provided_config: - client_config = user_provided_config.merge(client_config) - client_kwargs['config'] = client_config - client_kwargs["service_name"] = "s3" - - def _crt_request_from_aws_request(self, aws_request): - url_parts = urlsplit(aws_request.url) - crt_path = url_parts.path - if url_parts.query: - crt_path = f'{crt_path}?{url_parts.query}' - headers_list = [] - for name, value in aws_request.headers.items(): - if isinstance(value, str): - headers_list.append((name, value)) - else: - headers_list.append((name, str(value, 'utf-8'))) - - crt_headers = awscrt.http.HttpHeaders(headers_list) - # CRT requires body (if it exists) to be an I/O stream. 
- crt_body_stream = None - if aws_request.body: - if hasattr(aws_request.body, 'seek'): - crt_body_stream = aws_request.body - else: - crt_body_stream = BytesIO(aws_request.body) - - crt_request = awscrt.http.HttpRequest( - method=aws_request.method, - path=crt_path, - headers=crt_headers, - body_stream=crt_body_stream, - ) - return crt_request - - def _convert_to_crt_http_request(self, botocore_http_request): - # Logic that does CRTUtils.crt_request_from_aws_request - crt_request = self._crt_request_from_aws_request(botocore_http_request) - if crt_request.headers.get("host") is None: - # If host is not set, set it for the request before using CRT s3 - url_parts = urlsplit(botocore_http_request.url) - crt_request.headers.set("host", url_parts.netloc) - if crt_request.headers.get('Content-MD5') is not None: - crt_request.headers.remove("Content-MD5") - return crt_request - - def _capture_http_request(self, request, **kwargs): - request.context['http_request'] = request - - def _change_response_to_serialized_http_request( - self, context, parsed, **kwargs - ): - request = context['http_request'] - parsed['HTTPRequest'] = request.prepare() - - def _make_fake_http_response(self, request, **kwargs): - return botocore.awsrequest.AWSResponse( - None, - 200, - {}, - FakeRawResponse(b""), - ) - - def _get_botocore_http_request(self, client_method, call_args): - return getattr(self._client, client_method)( - Bucket=call_args.bucket, Key=call_args.key, **call_args.extra_args - )['HTTPRequest'] - - def serialize_http_request(self, transfer_type, future): - botocore_http_request = self._get_botocore_http_request( - transfer_type, future.meta.call_args - ) - crt_request = self._convert_to_crt_http_request(botocore_http_request) - return crt_request - - -class FakeRawResponse(BytesIO): - def stream(self, amt=1024, decode_content=None): - while True: - chunk = self.read(amt) - if not chunk: - break - yield chunk - - -class CRTTransferCoordinator: - """A helper class for managing CRTTransferFuture""" - - def __init__(self, transfer_id=None, s3_request=None): - self.transfer_id = transfer_id - self._s3_request = s3_request - self._lock = threading.Lock() - self._exception = None - self._crt_future = None - self._done_event = threading.Event() - - @property - def s3_request(self): - return self._s3_request - - def set_done_callbacks_complete(self): - self._done_event.set() - - def wait_until_on_done_callbacks_complete(self, timeout=None): - self._done_event.wait(timeout) - - def set_exception(self, exception, override=False): - with self._lock: - if not self.done() or override: - self._exception = exception - - def cancel(self): - if self._s3_request: - self._s3_request.cancel() - - def result(self, timeout=None): - if self._exception: - raise self._exception - try: - self._crt_future.result(timeout) - except KeyboardInterrupt: - self.cancel() - raise - finally: - if self._s3_request: - self._s3_request = None - self._crt_future.result(timeout) - - def done(self): - if self._crt_future is None: - return False - return self._crt_future.done() - - def set_s3_request(self, s3_request): - self._s3_request = s3_request - self._crt_future = self._s3_request.finished_future - - -class S3ClientArgsCreator: - def __init__(self, crt_request_serializer, os_utils): - self._request_serializer = crt_request_serializer - self._os_utils = os_utils - - def get_make_request_args( - self, request_type, call_args, coordinator, future, on_done_after_calls - ): - recv_filepath = None - send_filepath = None - s3_meta_request_type 
= getattr( - S3RequestType, request_type.upper(), S3RequestType.DEFAULT - ) - on_done_before_calls = [] - if s3_meta_request_type == S3RequestType.GET_OBJECT: - final_filepath = call_args.fileobj - recv_filepath = self._os_utils.get_temp_filename(final_filepath) - file_ondone_call = RenameTempFileHandler( - coordinator, final_filepath, recv_filepath, self._os_utils - ) - on_done_before_calls.append(file_ondone_call) - elif s3_meta_request_type == S3RequestType.PUT_OBJECT: - send_filepath = call_args.fileobj - data_len = self._os_utils.get_file_size(send_filepath) - call_args.extra_args["ContentLength"] = data_len - - crt_request = self._request_serializer.serialize_http_request( - request_type, future - ) - - return { - 'request': crt_request, - 'type': s3_meta_request_type, - 'recv_filepath': recv_filepath, - 'send_filepath': send_filepath, - 'on_done': self.get_crt_callback( - future, 'done', on_done_before_calls, on_done_after_calls - ), - 'on_progress': self.get_crt_callback(future, 'progress'), - } - - def get_crt_callback( - self, - future, - callback_type, - before_subscribers=None, - after_subscribers=None, - ): - def invoke_all_callbacks(*args, **kwargs): - callbacks_list = [] - if before_subscribers is not None: - callbacks_list += before_subscribers - callbacks_list += get_callbacks(future, callback_type) - if after_subscribers is not None: - callbacks_list += after_subscribers - for callback in callbacks_list: - # The get_callbacks helper will set the first augment - # by keyword, the other augments need to be set by keyword - # as well - if callback_type == "progress": - callback(bytes_transferred=args[0]) - else: - callback(*args, **kwargs) - - return invoke_all_callbacks - - -class RenameTempFileHandler: - def __init__(self, coordinator, final_filename, temp_filename, osutil): - self._coordinator = coordinator - self._final_filename = final_filename - self._temp_filename = temp_filename - self._osutil = osutil - - def __call__(self, **kwargs): - error = kwargs['error'] - if error: - self._osutil.remove_file(self._temp_filename) - else: - try: - self._osutil.rename_file( - self._temp_filename, self._final_filename - ) - except Exception as e: - self._osutil.remove_file(self._temp_filename) - # the CRT future has done already at this point - self._coordinator.set_exception(e) - - -class AfterDoneHandler: - def __init__(self, coordinator): - self._coordinator = coordinator - - def __call__(self, **kwargs): - self._coordinator.set_done_callbacks_complete() +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import logging +import threading +from io import BytesIO + +import awscrt.http +import botocore.awsrequest +import botocore.session +from awscrt.auth import AwsCredentials, AwsCredentialsProvider +from awscrt.io import ( + ClientBootstrap, + ClientTlsContext, + DefaultHostResolver, + EventLoopGroup, + TlsContextOptions, +) +from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType +from botocore import UNSIGNED +from botocore.compat import urlsplit +from botocore.config import Config +from botocore.exceptions import NoCredentialsError + +from s3transfer.constants import GB, MB +from s3transfer.exceptions import TransferNotDoneError +from s3transfer.futures import BaseTransferFuture, BaseTransferMeta +from s3transfer.utils import CallArgs, OSUtils, get_callbacks + +logger = logging.getLogger(__name__) + + +class CRTCredentialProviderAdapter: + def __init__(self, botocore_credential_provider): + self._botocore_credential_provider = botocore_credential_provider + self._loaded_credentials = None + self._lock = threading.Lock() + + def __call__(self): + credentials = self._get_credentials().get_frozen_credentials() + return AwsCredentials( + credentials.access_key, credentials.secret_key, credentials.token + ) + + def _get_credentials(self): + with self._lock: + if self._loaded_credentials is None: + loaded_creds = ( + self._botocore_credential_provider.load_credentials() + ) + if loaded_creds is None: + raise NoCredentialsError() + self._loaded_credentials = loaded_creds + return self._loaded_credentials + + +def create_s3_crt_client( + region, + botocore_credential_provider=None, + num_threads=None, + target_throughput=5 * GB / 8, + part_size=8 * MB, + use_ssl=True, + verify=None, +): + """ + :type region: str + :param region: The region used for signing + + :type botocore_credential_provider: + Optional[botocore.credentials.CredentialResolver] + :param botocore_credential_provider: Provide credentials for CRT + to sign the request if not set, the request will not be signed + + :type num_threads: Optional[int] + :param num_threads: Number of worker threads generated. Default + is the number of processors in the machine. + + :type target_throughput: Optional[int] + :param target_throughput: Throughput target in Bytes. + Default is 0.625 GB/s (which translates to 5 Gb/s). + + :type part_size: Optional[int] + :param part_size: Size, in Bytes, of parts that files will be downloaded + or uploaded in. + + :type use_ssl: boolean + :param use_ssl: Whether or not to use SSL. By default, SSL is used. + Note that not all services support non-ssl connections. + + :type verify: Optional[boolean/string] + :param verify: Whether or not to verify SSL certificates. + By default SSL certificates are verified. You can provide the + following values: + + * False - do not validate SSL certificates. SSL will still be + used (unless use_ssl is False), but SSL certificates + will not be verified. + * path/to/cert/bundle.pem - A filename of the CA cert bundle to + use. Specify this argument if you want to use a custom CA cert + bundle instead of the default one on your system. 
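Putting this factory together with the CRTTransferManager defined later in this file, a typical wiring might look like the sketch below. The use of session.get_component('credential_provider') to obtain a botocore CredentialResolver is an assumption on my part (any resolver satisfying the parameter documented above will do), and the region, bucket, key, and file names are placeholders.

# Hypothetical wiring of the CRT-based transfer path; not a documented recipe.
import botocore.session
from s3transfer.crt import (
    BotocoreCRTRequestSerializer,
    CRTTransferManager,
    create_s3_crt_client,
)

session = botocore.session.Session()
# Assumption: fetch the session's CredentialResolver for CRT request signing.
credential_provider = session.get_component('credential_provider')

crt_client = create_s3_crt_client(
    region='us-east-1',
    botocore_credential_provider=credential_provider,
    target_throughput=5 * 1024 ** 3 / 8,   # the documented 0.625 GB/s default
)
serializer = BotocoreCRTRequestSerializer(session)

with CRTTransferManager(crt_client, serializer) as manager:
    future = manager.upload('local-file.bin', 'my-bucket', 'remote-key.bin')
    future.result()   # raises if the transfer failed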
+ """ + + event_loop_group = EventLoopGroup(num_threads) + host_resolver = DefaultHostResolver(event_loop_group) + bootstrap = ClientBootstrap(event_loop_group, host_resolver) + provider = None + tls_connection_options = None + + tls_mode = ( + S3RequestTlsMode.ENABLED if use_ssl else S3RequestTlsMode.DISABLED + ) + if verify is not None: + tls_ctx_options = TlsContextOptions() + if verify: + tls_ctx_options.override_default_trust_store_from_path( + ca_filepath=verify + ) + else: + tls_ctx_options.verify_peer = False + client_tls_option = ClientTlsContext(tls_ctx_options) + tls_connection_options = client_tls_option.new_connection_options() + if botocore_credential_provider: + credentails_provider_adapter = CRTCredentialProviderAdapter( + botocore_credential_provider + ) + provider = AwsCredentialsProvider.new_delegate( + credentails_provider_adapter + ) + + target_gbps = target_throughput * 8 / GB + return S3Client( + bootstrap=bootstrap, + region=region, + credential_provider=provider, + part_size=part_size, + tls_mode=tls_mode, + tls_connection_options=tls_connection_options, + throughput_target_gbps=target_gbps, + ) + + +class CRTTransferManager: + def __init__(self, crt_s3_client, crt_request_serializer, osutil=None): + """A transfer manager interface for Amazon S3 on CRT s3 client. + + :type crt_s3_client: awscrt.s3.S3Client + :param crt_s3_client: The CRT s3 client, handling all the + HTTP requests and functions under then hood + + :type crt_request_serializer: s3transfer.crt.BaseCRTRequestSerializer + :param crt_request_serializer: Serializer, generates unsigned crt HTTP + request. + + :type osutil: s3transfer.utils.OSUtils + :param osutil: OSUtils object to use for os-related behavior when + using with transfer manager. + """ + if osutil is None: + self._osutil = OSUtils() + self._crt_s3_client = crt_s3_client + self._s3_args_creator = S3ClientArgsCreator( + crt_request_serializer, self._osutil + ) + self._future_coordinators = [] + self._semaphore = threading.Semaphore(128) # not configurable + # A counter to create unique id's for each transfer submitted. 
+ self._id_counter = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, *args): + cancel = False + if exc_type: + cancel = True + self._shutdown(cancel) + + def download( + self, bucket, key, fileobj, extra_args=None, subscribers=None + ): + if extra_args is None: + extra_args = {} + if subscribers is None: + subscribers = {} + callargs = CallArgs( + bucket=bucket, + key=key, + fileobj=fileobj, + extra_args=extra_args, + subscribers=subscribers, + ) + return self._submit_transfer("get_object", callargs) + + def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): + if extra_args is None: + extra_args = {} + if subscribers is None: + subscribers = {} + callargs = CallArgs( + bucket=bucket, + key=key, + fileobj=fileobj, + extra_args=extra_args, + subscribers=subscribers, + ) + return self._submit_transfer("put_object", callargs) + + def delete(self, bucket, key, extra_args=None, subscribers=None): + if extra_args is None: + extra_args = {} + if subscribers is None: + subscribers = {} + callargs = CallArgs( + bucket=bucket, + key=key, + extra_args=extra_args, + subscribers=subscribers, + ) + return self._submit_transfer("delete_object", callargs) + + def shutdown(self, cancel=False): + self._shutdown(cancel) + + def _cancel_transfers(self): + for coordinator in self._future_coordinators: + if not coordinator.done(): + coordinator.cancel() + + def _finish_transfers(self): + for coordinator in self._future_coordinators: + coordinator.result() + + def _wait_transfers_done(self): + for coordinator in self._future_coordinators: + coordinator.wait_until_on_done_callbacks_complete() + + def _shutdown(self, cancel=False): + if cancel: + self._cancel_transfers() + try: + self._finish_transfers() + + except KeyboardInterrupt: + self._cancel_transfers() + except Exception: + pass + finally: + self._wait_transfers_done() + + def _release_semaphore(self, **kwargs): + self._semaphore.release() + + def _submit_transfer(self, request_type, call_args): + on_done_after_calls = [self._release_semaphore] + coordinator = CRTTransferCoordinator(transfer_id=self._id_counter) + components = { + 'meta': CRTTransferMeta(self._id_counter, call_args), + 'coordinator': coordinator, + } + future = CRTTransferFuture(**components) + afterdone = AfterDoneHandler(coordinator) + on_done_after_calls.append(afterdone) + + try: + self._semaphore.acquire() + on_queued = self._s3_args_creator.get_crt_callback( + future, 'queued' + ) + on_queued() + crt_callargs = self._s3_args_creator.get_make_request_args( + request_type, + call_args, + coordinator, + future, + on_done_after_calls, + ) + crt_s3_request = self._crt_s3_client.make_request(**crt_callargs) + except Exception as e: + coordinator.set_exception(e, True) + on_done = self._s3_args_creator.get_crt_callback( + future, 'done', after_subscribers=on_done_after_calls + ) + on_done(error=e) + else: + coordinator.set_s3_request(crt_s3_request) + self._future_coordinators.append(coordinator) + + self._id_counter += 1 + return future + + +class CRTTransferMeta(BaseTransferMeta): + """Holds metadata about the CRTTransferFuture""" + + def __init__(self, transfer_id=None, call_args=None): + self._transfer_id = transfer_id + self._call_args = call_args + self._user_context = {} + + @property + def call_args(self): + return self._call_args + + @property + def transfer_id(self): + return self._transfer_id + + @property + def user_context(self): + return self._user_context + + +class CRTTransferFuture(BaseTransferFuture): + def 
__init__(self, meta=None, coordinator=None): + """The future associated to a submitted transfer request via CRT S3 client + + :type meta: s3transfer.crt.CRTTransferMeta + :param meta: The metadata associated to the transfer future. + + :type coordinator: s3transfer.crt.CRTTransferCoordinator + :param coordinator: The coordinator associated to the transfer future. + """ + self._meta = meta + if meta is None: + self._meta = CRTTransferMeta() + self._coordinator = coordinator + + @property + def meta(self): + return self._meta + + def done(self): + return self._coordinator.done() + + def result(self, timeout=None): + self._coordinator.result(timeout) + + def cancel(self): + self._coordinator.cancel() + + def set_exception(self, exception): + """Sets the exception on the future.""" + if not self.done(): + raise TransferNotDoneError( + 'set_exception can only be called once the transfer is ' + 'complete.' + ) + self._coordinator.set_exception(exception, override=True) + + +class BaseCRTRequestSerializer: + def serialize_http_request(self, transfer_type, future): + """Serialize CRT HTTP requests. + + :type transfer_type: string + :param transfer_type: the type of transfer made, + e.g 'put_object', 'get_object', 'delete_object' + + :type future: s3transfer.crt.CRTTransferFuture + + :rtype: awscrt.http.HttpRequest + :returns: An unsigned HTTP request to be used for the CRT S3 client + """ + raise NotImplementedError('serialize_http_request()') + + +class BotocoreCRTRequestSerializer(BaseCRTRequestSerializer): + def __init__(self, session, client_kwargs=None): + """Serialize CRT HTTP request using botocore logic + It also takes into account configuration from both the session + and any keyword arguments that could be passed to + `Session.create_client()` when serializing the request. + + :type session: botocore.session.Session + + :type client_kwargs: Optional[Dict[str, str]]) + :param client_kwargs: The kwargs for the botocore + s3 client initialization. + """ + self._session = session + if client_kwargs is None: + client_kwargs = {} + self._resolve_client_config(session, client_kwargs) + self._client = session.create_client(**client_kwargs) + self._client.meta.events.register( + 'request-created.s3.*', self._capture_http_request + ) + self._client.meta.events.register( + 'after-call.s3.*', self._change_response_to_serialized_http_request + ) + self._client.meta.events.register( + 'before-send.s3.*', self._make_fake_http_response + ) + + def _resolve_client_config(self, session, client_kwargs): + user_provided_config = None + if session.get_default_client_config(): + user_provided_config = session.get_default_client_config() + if 'config' in client_kwargs: + user_provided_config = client_kwargs['config'] + + client_config = Config(signature_version=UNSIGNED) + if user_provided_config: + client_config = user_provided_config.merge(client_config) + client_kwargs['config'] = client_config + client_kwargs["service_name"] = "s3" + + def _crt_request_from_aws_request(self, aws_request): + url_parts = urlsplit(aws_request.url) + crt_path = url_parts.path + if url_parts.query: + crt_path = f'{crt_path}?{url_parts.query}' + headers_list = [] + for name, value in aws_request.headers.items(): + if isinstance(value, str): + headers_list.append((name, value)) + else: + headers_list.append((name, str(value, 'utf-8'))) + + crt_headers = awscrt.http.HttpHeaders(headers_list) + # CRT requires body (if it exists) to be an I/O stream. 
+ crt_body_stream = None + if aws_request.body: + if hasattr(aws_request.body, 'seek'): + crt_body_stream = aws_request.body + else: + crt_body_stream = BytesIO(aws_request.body) + + crt_request = awscrt.http.HttpRequest( + method=aws_request.method, + path=crt_path, + headers=crt_headers, + body_stream=crt_body_stream, + ) + return crt_request + + def _convert_to_crt_http_request(self, botocore_http_request): + # Logic that does CRTUtils.crt_request_from_aws_request + crt_request = self._crt_request_from_aws_request(botocore_http_request) + if crt_request.headers.get("host") is None: + # If host is not set, set it for the request before using CRT s3 + url_parts = urlsplit(botocore_http_request.url) + crt_request.headers.set("host", url_parts.netloc) + if crt_request.headers.get('Content-MD5') is not None: + crt_request.headers.remove("Content-MD5") + return crt_request + + def _capture_http_request(self, request, **kwargs): + request.context['http_request'] = request + + def _change_response_to_serialized_http_request( + self, context, parsed, **kwargs + ): + request = context['http_request'] + parsed['HTTPRequest'] = request.prepare() + + def _make_fake_http_response(self, request, **kwargs): + return botocore.awsrequest.AWSResponse( + None, + 200, + {}, + FakeRawResponse(b""), + ) + + def _get_botocore_http_request(self, client_method, call_args): + return getattr(self._client, client_method)( + Bucket=call_args.bucket, Key=call_args.key, **call_args.extra_args + )['HTTPRequest'] + + def serialize_http_request(self, transfer_type, future): + botocore_http_request = self._get_botocore_http_request( + transfer_type, future.meta.call_args + ) + crt_request = self._convert_to_crt_http_request(botocore_http_request) + return crt_request + + +class FakeRawResponse(BytesIO): + def stream(self, amt=1024, decode_content=None): + while True: + chunk = self.read(amt) + if not chunk: + break + yield chunk + + +class CRTTransferCoordinator: + """A helper class for managing CRTTransferFuture""" + + def __init__(self, transfer_id=None, s3_request=None): + self.transfer_id = transfer_id + self._s3_request = s3_request + self._lock = threading.Lock() + self._exception = None + self._crt_future = None + self._done_event = threading.Event() + + @property + def s3_request(self): + return self._s3_request + + def set_done_callbacks_complete(self): + self._done_event.set() + + def wait_until_on_done_callbacks_complete(self, timeout=None): + self._done_event.wait(timeout) + + def set_exception(self, exception, override=False): + with self._lock: + if not self.done() or override: + self._exception = exception + + def cancel(self): + if self._s3_request: + self._s3_request.cancel() + + def result(self, timeout=None): + if self._exception: + raise self._exception + try: + self._crt_future.result(timeout) + except KeyboardInterrupt: + self.cancel() + raise + finally: + if self._s3_request: + self._s3_request = None + self._crt_future.result(timeout) + + def done(self): + if self._crt_future is None: + return False + return self._crt_future.done() + + def set_s3_request(self, s3_request): + self._s3_request = s3_request + self._crt_future = self._s3_request.finished_future + + +class S3ClientArgsCreator: + def __init__(self, crt_request_serializer, os_utils): + self._request_serializer = crt_request_serializer + self._os_utils = os_utils + + def get_make_request_args( + self, request_type, call_args, coordinator, future, on_done_after_calls + ): + recv_filepath = None + send_filepath = None + s3_meta_request_type 
= getattr( + S3RequestType, request_type.upper(), S3RequestType.DEFAULT + ) + on_done_before_calls = [] + if s3_meta_request_type == S3RequestType.GET_OBJECT: + final_filepath = call_args.fileobj + recv_filepath = self._os_utils.get_temp_filename(final_filepath) + file_ondone_call = RenameTempFileHandler( + coordinator, final_filepath, recv_filepath, self._os_utils + ) + on_done_before_calls.append(file_ondone_call) + elif s3_meta_request_type == S3RequestType.PUT_OBJECT: + send_filepath = call_args.fileobj + data_len = self._os_utils.get_file_size(send_filepath) + call_args.extra_args["ContentLength"] = data_len + + crt_request = self._request_serializer.serialize_http_request( + request_type, future + ) + + return { + 'request': crt_request, + 'type': s3_meta_request_type, + 'recv_filepath': recv_filepath, + 'send_filepath': send_filepath, + 'on_done': self.get_crt_callback( + future, 'done', on_done_before_calls, on_done_after_calls + ), + 'on_progress': self.get_crt_callback(future, 'progress'), + } + + def get_crt_callback( + self, + future, + callback_type, + before_subscribers=None, + after_subscribers=None, + ): + def invoke_all_callbacks(*args, **kwargs): + callbacks_list = [] + if before_subscribers is not None: + callbacks_list += before_subscribers + callbacks_list += get_callbacks(future, callback_type) + if after_subscribers is not None: + callbacks_list += after_subscribers + for callback in callbacks_list: + # The get_callbacks helper will set the first augment + # by keyword, the other augments need to be set by keyword + # as well + if callback_type == "progress": + callback(bytes_transferred=args[0]) + else: + callback(*args, **kwargs) + + return invoke_all_callbacks + + +class RenameTempFileHandler: + def __init__(self, coordinator, final_filename, temp_filename, osutil): + self._coordinator = coordinator + self._final_filename = final_filename + self._temp_filename = temp_filename + self._osutil = osutil + + def __call__(self, **kwargs): + error = kwargs['error'] + if error: + self._osutil.remove_file(self._temp_filename) + else: + try: + self._osutil.rename_file( + self._temp_filename, self._final_filename + ) + except Exception as e: + self._osutil.remove_file(self._temp_filename) + # the CRT future has done already at this point + self._coordinator.set_exception(e) + + +class AfterDoneHandler: + def __init__(self, coordinator): + self._coordinator = coordinator + + def __call__(self, **kwargs): + self._coordinator.set_done_callbacks_complete() diff --git a/contrib/python/s3transfer/py3/s3transfer/delete.py b/contrib/python/s3transfer/py3/s3transfer/delete.py index e179d45d4c..74084d312a 100644 --- a/contrib/python/s3transfer/py3/s3transfer/delete.py +++ b/contrib/python/s3transfer/py3/s3transfer/delete.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-from s3transfer.tasks import SubmissionTask, Task +from s3transfer.tasks import SubmissionTask, Task class DeleteSubmissionTask(SubmissionTask): @@ -47,8 +47,8 @@ class DeleteSubmissionTask(SubmissionTask): 'key': call_args.key, 'extra_args': call_args.extra_args, }, - is_final=True, - ), + is_final=True, + ), ) diff --git a/contrib/python/s3transfer/py3/s3transfer/download.py b/contrib/python/s3transfer/py3/s3transfer/download.py index 911912cab5..dc8980d4ed 100644 --- a/contrib/python/s3transfer/py3/s3transfer/download.py +++ b/contrib/python/s3transfer/py3/s3transfer/download.py @@ -10,30 +10,30 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -import heapq +import heapq import logging import threading from s3transfer.compat import seekable from s3transfer.exceptions import RetriesExceededError from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG -from s3transfer.tasks import SubmissionTask, Task -from s3transfer.utils import ( - S3_RETRYABLE_DOWNLOAD_ERRORS, - CountCallbackInvoker, - DeferredOpenFile, - FunctionContainer, - StreamReaderProgress, - calculate_num_parts, - calculate_range_parameter, - get_callbacks, - invoke_progress_callbacks, -) +from s3transfer.tasks import SubmissionTask, Task +from s3transfer.utils import ( + S3_RETRYABLE_DOWNLOAD_ERRORS, + CountCallbackInvoker, + DeferredOpenFile, + FunctionContainer, + StreamReaderProgress, + calculate_num_parts, + calculate_range_parameter, + get_callbacks, + invoke_progress_callbacks, +) logger = logging.getLogger(__name__) -class DownloadOutputManager: +class DownloadOutputManager: """Base manager class for handling various types of files for downloads This class is typically used for the DownloadSubmissionTask class to help @@ -46,7 +46,7 @@ class DownloadOutputManager: that may be accepted. All implementations must subclass and override public methods from this class. """ - + def __init__(self, osutil, transfer_coordinator, io_executor): self._osutil = osutil self._transfer_coordinator = transfer_coordinator @@ -94,8 +94,8 @@ class DownloadOutputManager: """ self._transfer_coordinator.submit( - self._io_executor, self.get_io_write_task(fileobj, data, offset) - ) + self._io_executor, self.get_io_write_task(fileobj, data, offset) + ) def get_io_write_task(self, fileobj, data, offset): """Get an IO write task for the requested set of data @@ -120,7 +120,7 @@ class DownloadOutputManager: 'fileobj': fileobj, 'data': data, 'offset': offset, - }, + }, ) def get_final_io_task(self): @@ -134,12 +134,12 @@ class DownloadOutputManager: :rtype: s3transfer.tasks.Task :returns: A final task to completed in the io executor """ - raise NotImplementedError('must implement get_final_io_task()') + raise NotImplementedError('must implement get_final_io_task()') def _get_fileobj_from_filename(self, filename): f = DeferredOpenFile( - filename, mode='wb', open_function=self._osutil.open - ) + filename, mode='wb', open_function=self._osutil.open + ) # Make sure the file gets closed and we remove the temporary file # if anything goes wrong during the process. 
self._transfer_coordinator.add_failure_cleanup(f.close) @@ -148,19 +148,19 @@ class DownloadOutputManager: class DownloadFilenameOutputManager(DownloadOutputManager): def __init__(self, osutil, transfer_coordinator, io_executor): - super().__init__(osutil, transfer_coordinator, io_executor) + super().__init__(osutil, transfer_coordinator, io_executor) self._final_filename = None self._temp_filename = None self._temp_fileobj = None @classmethod def is_compatible(cls, download_target, osutil): - return isinstance(download_target, str) + return isinstance(download_target, str) def get_fileobj_for_io_writes(self, transfer_future): fileobj = transfer_future.meta.call_args.fileobj self._final_filename = fileobj - self._temp_filename = self._osutil.get_temp_filename(fileobj) + self._temp_filename = self._osutil.get_temp_filename(fileobj) self._temp_fileobj = self._get_temp_fileobj() return self._temp_fileobj @@ -173,16 +173,16 @@ class DownloadFilenameOutputManager(DownloadOutputManager): main_kwargs={ 'fileobj': self._temp_fileobj, 'final_filename': self._final_filename, - 'osutil': self._osutil, + 'osutil': self._osutil, }, - is_final=True, + is_final=True, ) def _get_temp_fileobj(self): f = self._get_fileobj_from_filename(self._temp_filename) self._transfer_coordinator.add_failure_cleanup( - self._osutil.remove_file, self._temp_filename - ) + self._osutil.remove_file, self._temp_filename + ) return f @@ -199,15 +199,15 @@ class DownloadSeekableOutputManager(DownloadOutputManager): # This task will serve the purpose of signaling when all of the io # writes have finished so done callbacks can be called. return CompleteDownloadNOOPTask( - transfer_coordinator=self._transfer_coordinator - ) + transfer_coordinator=self._transfer_coordinator + ) class DownloadNonSeekableOutputManager(DownloadOutputManager): - def __init__( - self, osutil, transfer_coordinator, io_executor, defer_queue=None - ): - super().__init__(osutil, transfer_coordinator, io_executor) + def __init__( + self, osutil, transfer_coordinator, io_executor, defer_queue=None + ): + super().__init__(osutil, transfer_coordinator, io_executor) if defer_queue is None: defer_queue = DeferQueue() self._defer_queue = defer_queue @@ -225,20 +225,20 @@ class DownloadNonSeekableOutputManager(DownloadOutputManager): def get_final_io_task(self): return CompleteDownloadNOOPTask( - transfer_coordinator=self._transfer_coordinator - ) + transfer_coordinator=self._transfer_coordinator + ) def queue_file_io_task(self, fileobj, data, offset): with self._io_submit_lock: writes = self._defer_queue.request_writes(offset, data) for write in writes: data = write['data'] - logger.debug( - "Queueing IO offset %s for fileobj: %s", - write['offset'], - fileobj, - ) - super().queue_file_io_task(fileobj, data, offset) + logger.debug( + "Queueing IO offset %s for fileobj: %s", + write['offset'], + fileobj, + ) + super().queue_file_io_task(fileobj, data, offset) def get_io_write_task(self, fileobj, data, offset): return IOStreamingWriteTask( @@ -246,24 +246,24 @@ class DownloadNonSeekableOutputManager(DownloadOutputManager): main_kwargs={ 'fileobj': fileobj, 'data': data, - }, + }, ) class DownloadSpecialFilenameOutputManager(DownloadNonSeekableOutputManager): - def __init__( - self, osutil, transfer_coordinator, io_executor, defer_queue=None - ): - super().__init__( - osutil, transfer_coordinator, io_executor, defer_queue - ) + def __init__( + self, osutil, transfer_coordinator, io_executor, defer_queue=None + ): + super().__init__( + osutil, transfer_coordinator, 
io_executor, defer_queue + ) self._fileobj = None @classmethod def is_compatible(cls, download_target, osutil): - return isinstance(download_target, str) and osutil.is_special_file( - download_target - ) + return isinstance(download_target, str) and osutil.is_special_file( + download_target + ) def get_fileobj_for_io_writes(self, transfer_future): filename = transfer_future.meta.call_args.fileobj @@ -275,8 +275,8 @@ class DownloadSpecialFilenameOutputManager(DownloadNonSeekableOutputManager): return IOCloseTask( transfer_coordinator=self._transfer_coordinator, is_final=True, - main_kwargs={'fileobj': self._fileobj}, - ) + main_kwargs={'fileobj': self._fileobj}, + ) class DownloadSubmissionTask(SubmissionTask): @@ -307,21 +307,21 @@ class DownloadSubmissionTask(SubmissionTask): if download_manager_cls.is_compatible(fileobj, osutil): return download_manager_cls raise RuntimeError( - 'Output {} of type: {} is not supported.'.format( - fileobj, type(fileobj) - ) - ) - - def _submit( - self, - client, - config, - osutil, - request_executor, - io_executor, - transfer_future, - bandwidth_limiter=None, - ): + 'Output {} of type: {} is not supported.'.format( + fileobj, type(fileobj) + ) + ) + + def _submit( + self, + client, + config, + osutil, + request_executor, + io_executor, + transfer_future, + bandwidth_limiter=None, + ): """ :param client: The client associated with the transfer manager @@ -354,59 +354,59 @@ class DownloadSubmissionTask(SubmissionTask): response = client.head_object( Bucket=transfer_future.meta.call_args.bucket, Key=transfer_future.meta.call_args.key, - **transfer_future.meta.call_args.extra_args, + **transfer_future.meta.call_args.extra_args, ) transfer_future.meta.provide_transfer_size( - response['ContentLength'] - ) + response['ContentLength'] + ) download_output_manager = self._get_download_output_manager_cls( - transfer_future, osutil - )(osutil, self._transfer_coordinator, io_executor) + transfer_future, osutil + )(osutil, self._transfer_coordinator, io_executor) # If it is greater than threshold do a ranged download, otherwise # do a regular GetObject download. 
if transfer_future.meta.size < config.multipart_threshold: self._submit_download_request( - client, - config, - osutil, - request_executor, - io_executor, - download_output_manager, - transfer_future, - bandwidth_limiter, - ) + client, + config, + osutil, + request_executor, + io_executor, + download_output_manager, + transfer_future, + bandwidth_limiter, + ) else: self._submit_ranged_download_request( - client, - config, - osutil, - request_executor, - io_executor, - download_output_manager, - transfer_future, - bandwidth_limiter, - ) - - def _submit_download_request( - self, - client, - config, - osutil, - request_executor, - io_executor, - download_output_manager, - transfer_future, - bandwidth_limiter, - ): + client, + config, + osutil, + request_executor, + io_executor, + download_output_manager, + transfer_future, + bandwidth_limiter, + ) + + def _submit_download_request( + self, + client, + config, + osutil, + request_executor, + io_executor, + download_output_manager, + transfer_future, + bandwidth_limiter, + ): call_args = transfer_future.meta.call_args # Get a handle to the file that will be used for writing downloaded # contents fileobj = download_output_manager.get_fileobj_for_io_writes( - transfer_future - ) + transfer_future + ) # Get the needed callbacks for the task progress_callbacks = get_callbacks(transfer_future, 'progress') @@ -432,24 +432,24 @@ class DownloadSubmissionTask(SubmissionTask): 'max_attempts': config.num_download_attempts, 'download_output_manager': download_output_manager, 'io_chunksize': config.io_chunksize, - 'bandwidth_limiter': bandwidth_limiter, + 'bandwidth_limiter': bandwidth_limiter, }, - done_callbacks=[final_task], + done_callbacks=[final_task], ), - tag=get_object_tag, + tag=get_object_tag, ) - def _submit_ranged_download_request( - self, - client, - config, - osutil, - request_executor, - io_executor, - download_output_manager, - transfer_future, - bandwidth_limiter, - ): + def _submit_ranged_download_request( + self, + client, + config, + osutil, + request_executor, + io_executor, + download_output_manager, + transfer_future, + bandwidth_limiter, + ): call_args = transfer_future.meta.call_args # Get the needed progress callbacks for the task @@ -458,12 +458,12 @@ class DownloadSubmissionTask(SubmissionTask): # Get a handle to the file that will be used for writing downloaded # contents fileobj = download_output_manager.get_fileobj_for_io_writes( - transfer_future - ) + transfer_future + ) # Determine the number of parts part_size = config.multipart_chunksize - num_parts = calculate_num_parts(transfer_future.meta.size, part_size) + num_parts = calculate_num_parts(transfer_future.meta.size, part_size) # Get any associated tags for the get object task. 
get_object_tag = download_output_manager.get_download_task_tag() @@ -478,8 +478,8 @@ class DownloadSubmissionTask(SubmissionTask): for i in range(num_parts): # Calculate the range parameter range_parameter = calculate_range_parameter( - part_size, i, num_parts - ) + part_size, i, num_parts + ) # Inject the Range parameter to the parameters to be passed in # as extra args @@ -502,21 +502,21 @@ class DownloadSubmissionTask(SubmissionTask): 'start_index': i * part_size, 'download_output_manager': download_output_manager, 'io_chunksize': config.io_chunksize, - 'bandwidth_limiter': bandwidth_limiter, + 'bandwidth_limiter': bandwidth_limiter, }, - done_callbacks=[finalize_download_invoker.decrement], + done_callbacks=[finalize_download_invoker.decrement], ), - tag=get_object_tag, + tag=get_object_tag, ) finalize_download_invoker.finalize() - def _get_final_io_task_submission_callback( - self, download_manager, io_executor - ): + def _get_final_io_task_submission_callback( + self, download_manager, io_executor + ): final_task = download_manager.get_final_io_task() return FunctionContainer( - self._transfer_coordinator.submit, io_executor, final_task - ) + self._transfer_coordinator.submit, io_executor, final_task + ) def _calculate_range_param(self, part_size, part_index, num_parts): # Used to calculate the Range parameter @@ -525,32 +525,32 @@ class DownloadSubmissionTask(SubmissionTask): end_range = '' else: end_range = start_range + part_size - 1 - range_param = f'bytes={start_range}-{end_range}' + range_param = f'bytes={start_range}-{end_range}' return range_param class GetObjectTask(Task): - def _main( - self, - client, - bucket, - key, - fileobj, - extra_args, - callbacks, - max_attempts, - download_output_manager, - io_chunksize, - start_index=0, - bandwidth_limiter=None, - ): + def _main( + self, + client, + bucket, + key, + fileobj, + extra_args, + callbacks, + max_attempts, + download_output_manager, + io_chunksize, + start_index=0, + bandwidth_limiter=None, + ): """Downloads an object and places content into io queue :param client: The client to use when calling GetObject :param bucket: The bucket to download from :param key: The key to download from :param fileobj: The file handle to write content to - :param exta_args: Any extra arguments to include in GetObject request + :param exta_args: Any extra arguments to include in GetObject request :param callbacks: List of progress callbacks to invoke on download :param max_attempts: The number of retries to do when downloading :param download_output_manager: The download output manager associated @@ -565,19 +565,19 @@ class GetObjectTask(Task): last_exception = None for i in range(max_attempts): try: - current_index = start_index + current_index = start_index response = client.get_object( - Bucket=bucket, Key=key, **extra_args - ) + Bucket=bucket, Key=key, **extra_args + ) streaming_body = StreamReaderProgress( - response['Body'], callbacks - ) + response['Body'], callbacks + ) if bandwidth_limiter: - streaming_body = ( + streaming_body = ( bandwidth_limiter.get_bandwith_limited_stream( - streaming_body, self._transfer_coordinator - ) - ) + streaming_body, self._transfer_coordinator + ) + ) chunks = DownloadChunkIterator(streaming_body, io_chunksize) for chunk in chunks: @@ -586,31 +586,31 @@ class GetObjectTask(Task): # data to be written and break out of the download. 
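    # A worked example of the range math in _calculate_range_param above,
    # using the TransferConfig defaults shown later in manager.py
    # (part_size = 8 MB = 8388608 bytes) and num_parts = 3:
    #   part 0 -> 'bytes=0-8388607'
    #   part 1 -> 'bytes=8388608-16777215'
    #   part 2 -> 'bytes=16777216-'   (the last part is left open-ended)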
if not self._transfer_coordinator.done(): self._handle_io( - download_output_manager, - fileobj, - chunk, - current_index, + download_output_manager, + fileobj, + chunk, + current_index, ) current_index += len(chunk) else: return return - except S3_RETRYABLE_DOWNLOAD_ERRORS as e: - logger.debug( - "Retrying exception caught (%s), " - "retrying request, (attempt %s / %s)", - e, - i, - max_attempts, - exc_info=True, - ) + except S3_RETRYABLE_DOWNLOAD_ERRORS as e: + logger.debug( + "Retrying exception caught (%s), " + "retrying request, (attempt %s / %s)", + e, + i, + max_attempts, + exc_info=True, + ) last_exception = e # Also invoke the progress callbacks to indicate that we # are trying to download the stream again and all progress # for this GetObject has been lost. invoke_progress_callbacks( - callbacks, start_index - current_index - ) + callbacks, start_index - current_index + ) continue raise RetriesExceededError(last_exception) @@ -625,7 +625,7 @@ class ImmediatelyWriteIOGetObjectTask(GetObjectTask): downloading the object so there is no reason to go through the overhead of using an IO queue and executor. """ - + def _handle_io(self, download_output_manager, fileobj, chunk, index): task = download_output_manager.get_io_write_task(fileobj, chunk, index) task() @@ -649,7 +649,7 @@ class IOStreamingWriteTask(Task): def _main(self, fileobj, data): """Write data to a fileobj. - Data will be written directly to the fileobj without + Data will be written directly to the fileobj without any prior seeking. :param fileobj: The fileobj to write content to @@ -667,7 +667,7 @@ class IORenameFileTask(Task): upon completion of writing the contents. :param osutil: OS utility """ - + def _main(self, fileobj, final_filename, osutil): fileobj.close() osutil.rename_file(fileobj.name, final_filename) @@ -678,7 +678,7 @@ class IOCloseTask(Task): :param fileobj: The fileobj to close. """ - + def _main(self, fileobj): fileobj.close() @@ -689,28 +689,28 @@ class CompleteDownloadNOOPTask(Task): Note that the default for is_final is set to True because this should always be the last task. """ - - def __init__( - self, - transfer_coordinator, - main_kwargs=None, - pending_main_kwargs=None, - done_callbacks=None, - is_final=True, - ): - super().__init__( + + def __init__( + self, + transfer_coordinator, + main_kwargs=None, + pending_main_kwargs=None, + done_callbacks=None, + is_final=True, + ): + super().__init__( transfer_coordinator=transfer_coordinator, main_kwargs=main_kwargs, pending_main_kwargs=pending_main_kwargs, done_callbacks=done_callbacks, - is_final=is_final, + is_final=is_final, ) def _main(self): pass -class DownloadChunkIterator: +class DownloadChunkIterator: def __init__(self, body, chunksize): """Iterator to chunk out a downloaded S3 stream @@ -732,7 +732,7 @@ class DownloadChunkIterator: elif self._num_reads == 1: # Even though the response may have not had any # content, we still want to account for an empty object's - # existence so return the empty chunk for that initial + # existence so return the empty chunk for that initial # read. return chunk raise StopIteration() @@ -740,7 +740,7 @@ class DownloadChunkIterator: next = __next__ -class DeferQueue: +class DeferQueue: """IO queue that defers write requests until they are queued sequentially. This class is used to track IO data for a *single* fileobj. @@ -749,7 +749,7 @@ class DeferQueue: until it has the next contiguous block available (starting at 0). 
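For example (a sketch of the behaviour described above; the dict layout matches the ``request_writes`` loop in ``DownloadNonSeekableOutputManager`` earlier in this file):

.. code:: python

    q = DeferQueue()
    q.request_writes(offset=8, data=b'second')  # -> [] (offset 0 not seen yet)
    q.request_writes(offset=0, data=b'first')
    # -> [{'offset': 0, 'data': b'first'}, {'offset': 8, 'data': b'second'}]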
""" - + def __init__(self): self._writes = [] self._pending_offsets = set() diff --git a/contrib/python/s3transfer/py3/s3transfer/exceptions.py b/contrib/python/s3transfer/py3/s3transfer/exceptions.py index 2b4ebea66c..6150fe650d 100644 --- a/contrib/python/s3transfer/py3/s3transfer/exceptions.py +++ b/contrib/python/s3transfer/py3/s3transfer/exceptions.py @@ -15,7 +15,7 @@ from concurrent.futures import CancelledError class RetriesExceededError(Exception): def __init__(self, last_exception, msg='Max Retries Exceeded'): - super().__init__(msg) + super().__init__(msg) self.last_exception = last_exception @@ -33,5 +33,5 @@ class TransferNotDoneError(Exception): class FatalError(CancelledError): """A CancelledError raised from an error in the TransferManager""" - + pass diff --git a/contrib/python/s3transfer/py3/s3transfer/futures.py b/contrib/python/s3transfer/py3/s3transfer/futures.py index 5fe4a7c624..39e071fb60 100644 --- a/contrib/python/s3transfer/py3/s3transfer/futures.py +++ b/contrib/python/s3transfer/py3/s3transfer/futures.py @@ -14,61 +14,61 @@ import copy import logging import sys import threading -from collections import namedtuple -from concurrent import futures +from collections import namedtuple +from concurrent import futures from s3transfer.compat import MAXINT from s3transfer.exceptions import CancelledError, TransferNotDoneError -from s3transfer.utils import FunctionContainer, TaskSemaphore +from s3transfer.utils import FunctionContainer, TaskSemaphore logger = logging.getLogger(__name__) -class BaseTransferFuture: - @property - def meta(self): - """The metadata associated to the TransferFuture""" - raise NotImplementedError('meta') - - def done(self): - """Determines if a TransferFuture has completed - - :returns: True if completed. False, otherwise. - """ - raise NotImplementedError('done()') - - def result(self): - """Waits until TransferFuture is done and returns the result - - If the TransferFuture succeeded, it will return the result. If the - TransferFuture failed, it will raise the exception associated to the - failure. - """ - raise NotImplementedError('result()') - - def cancel(self): - """Cancels the request associated with the TransferFuture""" - raise NotImplementedError('cancel()') - - -class BaseTransferMeta: - @property - def call_args(self): - """The call args used in the transfer request""" - raise NotImplementedError('call_args') - - @property - def transfer_id(self): - """The unique id of the transfer""" - raise NotImplementedError('transfer_id') - - @property - def user_context(self): - """A dictionary that requesters can store data in""" - raise NotImplementedError('user_context') - - -class TransferFuture(BaseTransferFuture): +class BaseTransferFuture: + @property + def meta(self): + """The metadata associated to the TransferFuture""" + raise NotImplementedError('meta') + + def done(self): + """Determines if a TransferFuture has completed + + :returns: True if completed. False, otherwise. + """ + raise NotImplementedError('done()') + + def result(self): + """Waits until TransferFuture is done and returns the result + + If the TransferFuture succeeded, it will return the result. If the + TransferFuture failed, it will raise the exception associated to the + failure. 
+ """ + raise NotImplementedError('result()') + + def cancel(self): + """Cancels the request associated with the TransferFuture""" + raise NotImplementedError('cancel()') + + +class BaseTransferMeta: + @property + def call_args(self): + """The call args used in the transfer request""" + raise NotImplementedError('call_args') + + @property + def transfer_id(self): + """The unique id of the transfer""" + raise NotImplementedError('transfer_id') + + @property + def user_context(self): + """A dictionary that requesters can store data in""" + raise NotImplementedError('user_context') + + +class TransferFuture(BaseTransferFuture): def __init__(self, meta=None, coordinator=None): """The future associated to a submitted transfer request @@ -99,7 +99,7 @@ class TransferFuture(BaseTransferFuture): try: # Usually the result() method blocks until the transfer is done, # however if a KeyboardInterrupt is raised we want want to exit - # out of this and propagate the exception. + # out of this and propagate the exception. return self._coordinator.result() except KeyboardInterrupt as e: self.cancel() @@ -113,14 +113,14 @@ class TransferFuture(BaseTransferFuture): if not self.done(): raise TransferNotDoneError( 'set_exception can only be called once the transfer is ' - 'complete.' - ) + 'complete.' + ) self._coordinator.set_exception(exception, override=True) -class TransferMeta(BaseTransferMeta): +class TransferMeta(BaseTransferMeta): """Holds metadata about the TransferFuture""" - + def __init__(self, call_args=None, transfer_id=None): self._call_args = call_args self._transfer_id = transfer_id @@ -157,9 +157,9 @@ class TransferMeta(BaseTransferMeta): self._size = size -class TransferCoordinator: +class TransferCoordinator: """A helper class for managing TransferFuture""" - + def __init__(self, transfer_id=None): self.transfer_id = transfer_id self._status = 'not-started' @@ -175,9 +175,9 @@ class TransferCoordinator: self._failure_cleanups_lock = threading.Lock() def __repr__(self): - return '{}(transfer_id={})'.format( - self.__class__.__name__, self.transfer_id - ) + return '{}(transfer_id={})'.format( + self.__class__.__name__, self.transfer_id + ) @property def exception(self): @@ -296,8 +296,8 @@ class TransferCoordinator: if self.done(): raise RuntimeError( 'Unable to transition from done state %s to non-done ' - 'state %s.' % (self.status, desired_state) - ) + 'state %s.' % (self.status, desired_state) + ) self._status = desired_state def submit(self, executor, task, tag=None): @@ -316,17 +316,17 @@ class TransferCoordinator: :returns: A future representing the submitted task """ logger.debug( - "Submitting task {} to executor {} for transfer request: {}.".format( - task, executor, self.transfer_id - ) + "Submitting task {} to executor {} for transfer request: {}.".format( + task, executor, self.transfer_id + ) ) future = executor.submit(task, tag=tag) # Add this created future to the list of associated future just # in case it is needed during cleanups. 
self.add_associated_future(future) future.add_done_callback( - FunctionContainer(self.remove_associated_future, future) - ) + FunctionContainer(self.remove_associated_future, future) + ) return future def done(self): @@ -358,8 +358,8 @@ class TransferCoordinator: """Adds a callback to call upon failure""" with self._failure_cleanups_lock: self._failure_cleanups.append( - FunctionContainer(function, *args, **kwargs) - ) + FunctionContainer(function, *args, **kwargs) + ) def announce_done(self): """Announce that future is done running and run associated callbacks @@ -398,18 +398,18 @@ class TransferCoordinator: try: callback() # We do not want a callback interrupting the process, especially - # in the failure cleanups. So log and catch, the exception. + # in the failure cleanups. So log and catch, the exception. except Exception: logger.debug("Exception raised in %s." % callback, exc_info=True) -class BoundedExecutor: +class BoundedExecutor: EXECUTOR_CLS = futures.ThreadPoolExecutor - def __init__( - self, max_size, max_num_threads, tag_semaphores=None, executor_cls=None - ): - """An executor implementation that has a maximum queued up tasks + def __init__( + self, max_size, max_num_threads, tag_semaphores=None, executor_cls=None + ): + """An executor implementation that has a maximum queued up tasks The executor will block if the number of tasks that have been submitted and is currently working on is past its maximum. @@ -455,7 +455,7 @@ class BoundedExecutor: False, if not to wait and raise an error if not able to submit a task. - :returns: The future associated to the submitted task + :returns: The future associated to the submitted task """ semaphore = self._semaphore # If a tag was provided, use the semaphore associated to that @@ -468,8 +468,8 @@ class BoundedExecutor: # Create a callback to invoke when task is done in order to call # release on the semaphore. release_callback = FunctionContainer( - semaphore.release, task.transfer_id, acquire_token - ) + semaphore.release, task.transfer_id, acquire_token + ) # Submit the task to the underlying executor. future = ExecutorFuture(self._executor.submit(task)) # Add the Semaphore.release() callback to the future such that @@ -481,7 +481,7 @@ class BoundedExecutor: self._executor.shutdown(wait) -class ExecutorFuture: +class ExecutorFuture: def __init__(self, future): """A future returned from the executor @@ -501,7 +501,7 @@ class ExecutorFuture: def add_done_callback(self, fn): """Adds a callback to be completed once future is done - :param fn: A callable that takes no arguments. Note that is different + :param fn: A callable that takes no arguments. Note that is different than concurrent.futures.Future.add_done_callback that requires a single argument for the future. """ @@ -510,16 +510,16 @@ class ExecutorFuture: # proper signature wrapper that will invoke the callback provided. 
def done_callback(future_passed_to_callback): return fn() - + self._future.add_done_callback(done_callback) def done(self): return self._future.done() -class BaseExecutor: +class BaseExecutor: """Base Executor class implementation needed to work with s3transfer""" - + def __init__(self, max_workers=None): pass @@ -532,7 +532,7 @@ class BaseExecutor: class NonThreadedExecutor(BaseExecutor): """A drop-in replacement non-threaded version of ThreadPoolExecutor""" - + def submit(self, fn, *args, **kwargs): future = NonThreadedExecutorFuture() try: @@ -542,9 +542,9 @@ class NonThreadedExecutor(BaseExecutor): e, tb = sys.exc_info()[1:] logger.debug( 'Setting exception for %s to %s with traceback %s', - future, - e, - tb, + future, + e, + tb, ) future.set_exception_info(e, tb) return future @@ -553,13 +553,13 @@ class NonThreadedExecutor(BaseExecutor): pass -class NonThreadedExecutorFuture: +class NonThreadedExecutorFuture: """The Future returned from NonThreadedExecutor Note that this future is **not** thread-safe as it is being used from the context of a non-threaded environment. """ - + def __init__(self): self._result = None self._exception = None @@ -578,7 +578,7 @@ class NonThreadedExecutorFuture: def result(self, timeout=None): if self._exception: - raise self._exception.with_traceback(self._traceback) + raise self._exception.with_traceback(self._traceback) return self._result def _set_done(self): diff --git a/contrib/python/s3transfer/py3/s3transfer/manager.py b/contrib/python/s3transfer/py3/s3transfer/manager.py index 095917c426..ff6afa12c1 100644 --- a/contrib/python/s3transfer/py3/s3transfer/manager.py +++ b/contrib/python/s3transfer/py3/s3transfer/manager.py @@ -12,54 +12,54 @@ # language governing permissions and limitations under the License. 
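``NonThreadedExecutor`` above can be handed to the ``TransferManager`` defined in manager.py below as its ``executor_cls``, so that every task runs synchronously on the calling thread. A hedged sketch, assuming ``client`` is an existing botocore S3 client:

.. code:: python

    from s3transfer.futures import NonThreadedExecutor
    from s3transfer.manager import TransferManager

    manager = TransferManager(client, executor_cls=NonThreadedExecutor)
    future = manager.upload('/tmp/report.csv', 'mybucket', 'reports/report.csv')
    future.result()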
import copy import logging -import re +import re import threading -from s3transfer.bandwidth import BandwidthLimiter, LeakyBucket -from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, KB, MB -from s3transfer.copies import CopySubmissionTask -from s3transfer.delete import DeleteSubmissionTask +from s3transfer.bandwidth import BandwidthLimiter, LeakyBucket +from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, KB, MB +from s3transfer.copies import CopySubmissionTask +from s3transfer.delete import DeleteSubmissionTask from s3transfer.download import DownloadSubmissionTask -from s3transfer.exceptions import CancelledError, FatalError -from s3transfer.futures import ( - IN_MEMORY_DOWNLOAD_TAG, - IN_MEMORY_UPLOAD_TAG, - BoundedExecutor, - TransferCoordinator, - TransferFuture, - TransferMeta, -) +from s3transfer.exceptions import CancelledError, FatalError +from s3transfer.futures import ( + IN_MEMORY_DOWNLOAD_TAG, + IN_MEMORY_UPLOAD_TAG, + BoundedExecutor, + TransferCoordinator, + TransferFuture, + TransferMeta, +) from s3transfer.upload import UploadSubmissionTask -from s3transfer.utils import ( - CallArgs, - OSUtils, - SlidingWindowSemaphore, - TaskSemaphore, - get_callbacks, - signal_not_transferring, - signal_transferring, -) +from s3transfer.utils import ( + CallArgs, + OSUtils, + SlidingWindowSemaphore, + TaskSemaphore, + get_callbacks, + signal_not_transferring, + signal_transferring, +) logger = logging.getLogger(__name__) -class TransferConfig: - def __init__( - self, - multipart_threshold=8 * MB, - multipart_chunksize=8 * MB, - max_request_concurrency=10, - max_submission_concurrency=5, - max_request_queue_size=1000, - max_submission_queue_size=1000, - max_io_queue_size=1000, - io_chunksize=256 * KB, - num_download_attempts=5, - max_in_memory_upload_chunks=10, - max_in_memory_download_chunks=10, - max_bandwidth=None, - ): - """Configurations for the transfer manager +class TransferConfig: + def __init__( + self, + multipart_threshold=8 * MB, + multipart_chunksize=8 * MB, + max_request_concurrency=10, + max_submission_concurrency=5, + max_request_queue_size=1000, + max_submission_queue_size=1000, + max_io_queue_size=1000, + io_chunksize=256 * KB, + num_download_attempts=5, + max_in_memory_upload_chunks=10, + max_in_memory_download_chunks=10, + max_bandwidth=None, + ): + """Configurations for the transfer manager :param multipart_threshold: The threshold for which multipart transfers occur. @@ -71,7 +71,7 @@ class TransferConfig: processing a call to a TransferManager method. Processing a call usually entails determining which S3 API requests that need to be enqueued, but does **not** entail making any of the - S3 API data transferring requests needed to perform the transfer. + S3 API data transferring requests needed to perform the transfer. The threads controlled by ``max_request_concurrency`` is responsible for that. @@ -79,14 +79,14 @@ class TransferConfig: becomes a multipart transfer. :param max_request_queue_size: The maximum amount of S3 API requests - that can be queued at a time. + that can be queued at a time. :param max_submission_queue_size: The maximum amount of - TransferManager method calls that can be queued at a time. + TransferManager method calls that can be queued at a time. :param max_io_queue_size: The maximum amount of read parts that - can be queued to be written to disk per download. The default - size for each elementin this queue is 8 KB. + can be queued to be written to disk per download. The default + size for each elementin this queue is 8 KB. 
:param io_chunksize: The max size of each chunk in the io queue. Currently, this is size used when reading from the downloaded @@ -94,9 +94,9 @@ class TransferConfig: :param num_download_attempts: The number of download attempts that will be tried upon errors with downloading an object in S3. Note - that these retries account for errors that occur when streaming + that these retries account for errors that occur when streaming down the data from s3 (i.e. socket errors and read timeouts that - occur after receiving an OK response from s3). + occur after receiving an OK response from s3). Other retryable exceptions such as throttling errors and 5xx errors are already retried by botocore (this default is 5). The ``num_download_attempts`` does not take into account the @@ -120,7 +120,7 @@ class TransferConfig: :param max_in_memory_download_chunks: The number of chunks that can be buffered in memory and **not** in the io queue at a time for all - ongoing download requests. This pertains specifically to file-like + ongoing download requests. This pertains specifically to file-like objects that cannot be seeked. The total maximum memory footprint due to a in-memory download chunks is roughly equal to: @@ -145,16 +145,16 @@ class TransferConfig: self._validate_attrs_are_nonzero() def _validate_attrs_are_nonzero(self): - for attr, attr_val in self.__dict__.items(): + for attr, attr_val in self.__dict__.items(): if attr_val is not None and attr_val <= 0: raise ValueError( 'Provided parameter %s of value %s must be greater than ' - '0.' % (attr, attr_val) - ) + '0.' % (attr, attr_val) + ) -class TransferManager: - ALLOWED_DOWNLOAD_ARGS = ALLOWED_DOWNLOAD_ARGS +class TransferManager: + ALLOWED_DOWNLOAD_ARGS = ALLOWED_DOWNLOAD_ARGS ALLOWED_UPLOAD_ARGS = [ 'ACL', @@ -163,7 +163,7 @@ class TransferManager: 'ContentEncoding', 'ContentLanguage', 'ContentType', - 'ExpectedBucketOwner', + 'ExpectedBucketOwner', 'Expires', 'GrantFullControl', 'GrantRead', @@ -177,9 +177,9 @@ class TransferManager: 'SSECustomerKey', 'SSECustomerKeyMD5', 'SSEKMSKeyId', - 'SSEKMSEncryptionContext', - 'Tagging', - 'WebsiteRedirectLocation', + 'SSEKMSEncryptionContext', + 'Tagging', + 'WebsiteRedirectLocation', ] ALLOWED_COPY_ARGS = ALLOWED_UPLOAD_ARGS + [ @@ -190,26 +190,26 @@ class TransferManager: 'CopySourceSSECustomerAlgorithm', 'CopySourceSSECustomerKey', 'CopySourceSSECustomerKeyMD5', - 'MetadataDirective', - 'TaggingDirective', + 'MetadataDirective', + 'TaggingDirective', ] ALLOWED_DELETE_ARGS = [ 'MFA', 'VersionId', 'RequestPayer', - 'ExpectedBucketOwner', + 'ExpectedBucketOwner', ] - VALIDATE_SUPPORTED_BUCKET_VALUES = True - - _UNSUPPORTED_BUCKET_PATTERNS = { - 'S3 Object Lambda': re.compile( - r'^arn:(aws).*:s3-object-lambda:[a-z\-0-9]+:[0-9]{12}:' - r'accesspoint[/:][a-zA-Z0-9\-]{1,63}' - ), - } - + VALIDATE_SUPPORTED_BUCKET_VALUES = True + + _UNSUPPORTED_BUCKET_PATTERNS = { + 'S3 Object Lambda': re.compile( + r'^arn:(aws).*:s3-object-lambda:[a-z\-0-9]+:[0-9]{12}:' + r'accesspoint[/:][a-zA-Z0-9\-]{1,63}' + ), + } + def __init__(self, client, config=None, osutil=None, executor_cls=None): """A transfer manager interface for Amazon S3 @@ -239,13 +239,13 @@ class TransferManager: max_num_threads=self._config.max_request_concurrency, tag_semaphores={ IN_MEMORY_UPLOAD_TAG: TaskSemaphore( - self._config.max_in_memory_upload_chunks - ), + self._config.max_in_memory_upload_chunks + ), IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore( - self._config.max_in_memory_download_chunks - ), + self._config.max_in_memory_download_chunks + 
), }, - executor_cls=executor_cls, + executor_cls=executor_cls, ) # The executor responsible for submitting the necessary tasks to @@ -253,7 +253,7 @@ class TransferManager: self._submission_executor = BoundedExecutor( max_size=self._config.max_submission_queue_size, max_num_threads=self._config.max_submission_concurrency, - executor_cls=executor_cls, + executor_cls=executor_cls, ) # There is one thread available for writing to disk. It will handle @@ -261,7 +261,7 @@ class TransferManager: self._io_executor = BoundedExecutor( max_size=self._config.max_io_queue_size, max_num_threads=1, - executor_cls=executor_cls, + executor_cls=executor_cls, ) # The component responsible for limiting bandwidth usage if it @@ -269,21 +269,21 @@ class TransferManager: self._bandwidth_limiter = None if self._config.max_bandwidth is not None: logger.debug( - 'Setting max_bandwidth to %s', self._config.max_bandwidth - ) + 'Setting max_bandwidth to %s', self._config.max_bandwidth + ) leaky_bucket = LeakyBucket(self._config.max_bandwidth) self._bandwidth_limiter = BandwidthLimiter(leaky_bucket) self._register_handlers() - @property - def client(self): - return self._client - - @property - def config(self): - return self._config - + @property + def client(self): + return self._client + + @property + def config(self): + return self._config + def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): """Uploads a file to S3 @@ -315,24 +315,24 @@ class TransferManager: if subscribers is None: subscribers = [] self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS) - self._validate_if_bucket_supported(bucket) + self._validate_if_bucket_supported(bucket) call_args = CallArgs( - fileobj=fileobj, - bucket=bucket, - key=key, - extra_args=extra_args, - subscribers=subscribers, + fileobj=fileobj, + bucket=bucket, + key=key, + extra_args=extra_args, + subscribers=subscribers, ) extra_main_kwargs = {} if self._bandwidth_limiter: extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter return self._submit_transfer( - call_args, UploadSubmissionTask, extra_main_kwargs - ) + call_args, UploadSubmissionTask, extra_main_kwargs + ) - def download( - self, bucket, key, fileobj, extra_args=None, subscribers=None - ): + def download( + self, bucket, key, fileobj, extra_args=None, subscribers=None + ): """Downloads a file from S3 :type bucket: str @@ -363,30 +363,30 @@ class TransferManager: if subscribers is None: subscribers = [] self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS) - self._validate_if_bucket_supported(bucket) + self._validate_if_bucket_supported(bucket) call_args = CallArgs( - bucket=bucket, - key=key, - fileobj=fileobj, - extra_args=extra_args, - subscribers=subscribers, + bucket=bucket, + key=key, + fileobj=fileobj, + extra_args=extra_args, + subscribers=subscribers, ) extra_main_kwargs = {'io_executor': self._io_executor} if self._bandwidth_limiter: extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter return self._submit_transfer( - call_args, DownloadSubmissionTask, extra_main_kwargs - ) - - def copy( - self, - copy_source, - bucket, - key, - extra_args=None, - subscribers=None, - source_client=None, - ): + call_args, DownloadSubmissionTask, extra_main_kwargs + ) + + def copy( + self, + copy_source, + bucket, + key, + extra_args=None, + subscribers=None, + source_client=None, + ): """Copies a file in S3 :type copy_source: dict @@ -428,16 +428,16 @@ class TransferManager: if source_client is None: source_client = self._client 
self._validate_all_known_args(extra_args, self.ALLOWED_COPY_ARGS) - if isinstance(copy_source, dict): - self._validate_if_bucket_supported(copy_source.get('Bucket')) - self._validate_if_bucket_supported(bucket) + if isinstance(copy_source, dict): + self._validate_if_bucket_supported(copy_source.get('Bucket')) + self._validate_if_bucket_supported(bucket) call_args = CallArgs( - copy_source=copy_source, - bucket=bucket, - key=key, - extra_args=extra_args, - subscribers=subscribers, - source_client=source_client, + copy_source=copy_source, + bucket=bucket, + key=key, + extra_args=extra_args, + subscribers=subscribers, + source_client=source_client, ) return self._submit_transfer(call_args, CopySubmissionTask) @@ -468,46 +468,46 @@ class TransferManager: if subscribers is None: subscribers = [] self._validate_all_known_args(extra_args, self.ALLOWED_DELETE_ARGS) - self._validate_if_bucket_supported(bucket) + self._validate_if_bucket_supported(bucket) call_args = CallArgs( - bucket=bucket, - key=key, - extra_args=extra_args, - subscribers=subscribers, + bucket=bucket, + key=key, + extra_args=extra_args, + subscribers=subscribers, ) return self._submit_transfer(call_args, DeleteSubmissionTask) - def _validate_if_bucket_supported(self, bucket): - # s3 high level operations don't support some resources - # (eg. S3 Object Lambda) only direct API calls are available - # for such resources - if self.VALIDATE_SUPPORTED_BUCKET_VALUES: - for resource, pattern in self._UNSUPPORTED_BUCKET_PATTERNS.items(): - match = pattern.match(bucket) - if match: - raise ValueError( - 'TransferManager methods do not support %s ' - 'resource. Use direct client calls instead.' % resource - ) - + def _validate_if_bucket_supported(self, bucket): + # s3 high level operations don't support some resources + # (eg. S3 Object Lambda) only direct API calls are available + # for such resources + if self.VALIDATE_SUPPORTED_BUCKET_VALUES: + for resource, pattern in self._UNSUPPORTED_BUCKET_PATTERNS.items(): + match = pattern.match(bucket) + if match: + raise ValueError( + 'TransferManager methods do not support %s ' + 'resource. Use direct client calls instead.' % resource + ) + def _validate_all_known_args(self, actual, allowed): for kwarg in actual: if kwarg not in allowed: raise ValueError( "Invalid extra_args key '%s', " - "must be one of: %s" % (kwarg, ', '.join(allowed)) - ) + "must be one of: %s" % (kwarg, ', '.join(allowed)) + ) - def _submit_transfer( - self, call_args, submission_task_cls, extra_main_kwargs=None - ): + def _submit_transfer( + self, call_args, submission_task_cls, extra_main_kwargs=None + ): if not extra_main_kwargs: extra_main_kwargs = {} # Create a TransferFuture to return back to the user transfer_future, components = self._get_future_with_components( - call_args - ) + call_args + ) # Add any provided done callbacks to the created transfer future # to be invoked on the transfer future being complete. @@ -516,15 +516,15 @@ class TransferManager: # Get the main kwargs needed to instantiate the submission task main_kwargs = self._get_submission_task_main_kwargs( - transfer_future, extra_main_kwargs - ) + transfer_future, extra_main_kwargs + ) # Submit a SubmissionTask that will submit all of the necessary # tasks needed to complete the S3 transfer. 
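Taken together, the configuration and manager classes above are typically used like this (a sketch; ``client`` is assumed to be an existing botocore S3 client, and the bucket, key and file path are placeholders):

.. code:: python

    from s3transfer.manager import TransferConfig, TransferManager

    config = TransferConfig(
        multipart_threshold=16 * 1024 * 1024,  # switch to ranged/multipart at 16 MB
        max_request_concurrency=20,
    )
    with TransferManager(client, config=config) as manager:
        future = manager.download('mybucket', 'mykey', '/tmp/mykey.bin')
        future.result()  # re-raises any error from the transfer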
self._submission_executor.submit( submission_task_cls( transfer_coordinator=components['coordinator'], - main_kwargs=main_kwargs, + main_kwargs=main_kwargs, ) ) @@ -539,30 +539,30 @@ class TransferManager: transfer_coordinator = TransferCoordinator(transfer_id=transfer_id) # Track the transfer coordinator for transfers to manage. self._coordinator_controller.add_transfer_coordinator( - transfer_coordinator - ) + transfer_coordinator + ) # Also make sure that the transfer coordinator is removed once # the transfer completes so it does not stick around in memory. transfer_coordinator.add_done_callback( self._coordinator_controller.remove_transfer_coordinator, - transfer_coordinator, - ) + transfer_coordinator, + ) components = { 'meta': TransferMeta(call_args, transfer_id=transfer_id), - 'coordinator': transfer_coordinator, + 'coordinator': transfer_coordinator, } transfer_future = TransferFuture(**components) return transfer_future, components def _get_submission_task_main_kwargs( - self, transfer_future, extra_main_kwargs - ): + self, transfer_future, extra_main_kwargs + ): main_kwargs = { 'client': self._client, 'config': self._config, 'osutil': self._osutil, 'request_executor': self._request_executor, - 'transfer_future': transfer_future, + 'transfer_future': transfer_future, } main_kwargs.update(extra_main_kwargs) return main_kwargs @@ -571,13 +571,13 @@ class TransferManager: # Register handlers to enable/disable callbacks on uploads. event_name = 'request-created.s3' self._client.meta.events.register_first( - event_name, - signal_not_transferring, - unique_id='s3upload-not-transferring', - ) + event_name, + signal_not_transferring, + unique_id='s3upload-not-transferring', + ) self._client.meta.events.register_last( - event_name, signal_transferring, unique_id='s3upload-transferring' - ) + event_name, signal_transferring, unique_id='s3upload-transferring' + ) def __enter__(self): return self @@ -590,7 +590,7 @@ class TransferManager: # all of the inprogress futures in the shutdown. if exc_type: cancel = True - cancel_msg = str(exc_value) + cancel_msg = str(exc_value) if not cancel_msg: cancel_msg = repr(exc_value) # If it was a KeyboardInterrupt, the cancellation was initiated @@ -641,7 +641,7 @@ class TransferManager: self._io_executor.shutdown() -class TransferCoordinatorController: +class TransferCoordinatorController: def __init__(self): """Abstraction to control all transfer coordinators @@ -671,7 +671,7 @@ class TransferCoordinatorController: self._tracked_transfer_coordinators.add(transfer_coordinator) def remove_transfer_coordinator(self, transfer_coordinator): - """Remove a transfer coordinator from cancellation consideration + """Remove a transfer coordinator from cancellation consideration Typically, this method is invoked by the transfer coordinator itself to remove its self when it completes its transfer. @@ -700,7 +700,7 @@ class TransferCoordinatorController: def wait(self): """Wait until there are no more inprogress transfers - This will not stop when failures are encountered and not propagate any + This will not stop when failures are encountered and not propagate any of these errors from failed transfers, but it can be interrupted with a KeyboardInterrupt. 
""" @@ -716,8 +716,8 @@ class TransferCoordinatorController: if transfer_coordinator: logger.debug( 'On KeyboardInterrupt was waiting for %s', - transfer_coordinator, - ) + transfer_coordinator, + ) raise except Exception: # A general exception could have been thrown because diff --git a/contrib/python/s3transfer/py3/s3transfer/processpool.py b/contrib/python/s3transfer/py3/s3transfer/processpool.py index 783518fd19..017eeb4499 100644 --- a/contrib/python/s3transfer/py3/s3transfer/processpool.py +++ b/contrib/python/s3transfer/py3/s3transfer/processpool.py @@ -1,1008 +1,1008 @@ -# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Speeds up S3 throughput by using processes - -Getting Started -=============== - -The :class:`ProcessPoolDownloader` can be used to download a single file by -calling :meth:`ProcessPoolDownloader.download_file`: - -.. code:: python - - from s3transfer.processpool import ProcessPoolDownloader - - with ProcessPoolDownloader() as downloader: - downloader.download_file('mybucket', 'mykey', 'myfile') - - -This snippet downloads the S3 object located in the bucket ``mybucket`` at the -key ``mykey`` to the local file ``myfile``. Any errors encountered during the -transfer are not propagated. To determine if a transfer succeeded or -failed, use the `Futures`_ interface. - - -The :class:`ProcessPoolDownloader` can be used to download multiple files as -well: - -.. code:: python - - from s3transfer.processpool import ProcessPoolDownloader - - with ProcessPoolDownloader() as downloader: - downloader.download_file('mybucket', 'mykey', 'myfile') - downloader.download_file('mybucket', 'myotherkey', 'myotherfile') - - -When running this snippet, the downloading of ``mykey`` and ``myotherkey`` -happen in parallel. The first ``download_file`` call does not block the -second ``download_file`` call. The snippet blocks when exiting -the context manager and blocks until both downloads are complete. - -Alternatively, the ``ProcessPoolDownloader`` can be instantiated -and explicitly be shutdown using :meth:`ProcessPoolDownloader.shutdown`: - -.. code:: python - - from s3transfer.processpool import ProcessPoolDownloader - - downloader = ProcessPoolDownloader() - downloader.download_file('mybucket', 'mykey', 'myfile') - downloader.download_file('mybucket', 'myotherkey', 'myotherfile') - downloader.shutdown() - - -For this code snippet, the call to ``shutdown`` blocks until both -downloads are complete. - - -Additional Parameters -===================== - -Additional parameters can be provided to the ``download_file`` method: - -* ``extra_args``: A dictionary containing any additional client arguments - to include in the - `GetObject <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object>`_ - API request. For example: - - .. 
code:: python - - from s3transfer.processpool import ProcessPoolDownloader - - with ProcessPoolDownloader() as downloader: - downloader.download_file( - 'mybucket', 'mykey', 'myfile', - extra_args={'VersionId': 'myversion'}) - - -* ``expected_size``: By default, the downloader will make a HeadObject - call to determine the size of the object. To opt-out of this additional - API call, you can provide the size of the object in bytes: - - .. code:: python - - from s3transfer.processpool import ProcessPoolDownloader - - MB = 1024 * 1024 - with ProcessPoolDownloader() as downloader: - downloader.download_file( - 'mybucket', 'mykey', 'myfile', expected_size=2 * MB) - - -Futures -======= - -When ``download_file`` is called, it immediately returns a -:class:`ProcessPoolTransferFuture`. The future can be used to poll the state -of a particular transfer. To get the result of the download, -call :meth:`ProcessPoolTransferFuture.result`. The method blocks -until the transfer completes, whether it succeeds or fails. For example: - -.. code:: python - - from s3transfer.processpool import ProcessPoolDownloader - - with ProcessPoolDownloader() as downloader: - future = downloader.download_file('mybucket', 'mykey', 'myfile') - print(future.result()) - - -If the download succeeds, the future returns ``None``: - -.. code:: python - - None - - -If the download fails, the exception causing the failure is raised. For -example, if ``mykey`` did not exist, the following error would be raised - - -.. code:: python - - botocore.exceptions.ClientError: An error occurred (404) when calling the HeadObject operation: Not Found - - -.. note:: - - :meth:`ProcessPoolTransferFuture.result` can only be called while the - ``ProcessPoolDownloader`` is running (e.g. before calling ``shutdown`` or - inside the context manager). - - -Process Pool Configuration -========================== - -By default, the downloader has the following configuration options: - -* ``multipart_threshold``: The threshold size for performing ranged downloads - in bytes. By default, ranged downloads happen for S3 objects that are - greater than or equal to 8 MB in size. - -* ``multipart_chunksize``: The size of each ranged download in bytes. By - default, the size of each ranged download is 8 MB. - -* ``max_request_processes``: The maximum number of processes used to download - S3 objects. By default, the maximum is 10 processes. - - -To change the default configuration, use the :class:`ProcessTransferConfig`: - -.. code:: python - - from s3transfer.processpool import ProcessPoolDownloader - from s3transfer.processpool import ProcessTransferConfig - - config = ProcessTransferConfig( - multipart_threshold=64 * 1024 * 1024, # 64 MB - max_request_processes=50 - ) - downloader = ProcessPoolDownloader(config=config) - - -Client Configuration -==================== - -The process pool downloader creates ``botocore`` clients on your behalf. In -order to affect how the client is created, pass the keyword arguments -that would have been used in the :meth:`botocore.Session.create_client` call: - -.. code:: python - - - from s3transfer.processpool import ProcessPoolDownloader - from s3transfer.processpool import ProcessTransferConfig - - downloader = ProcessPoolDownloader( - client_kwargs={'region_name': 'us-west-2'}) - - -This snippet ensures that all clients created by the ``ProcessPoolDownloader`` -are using ``us-west-2`` as their region. 
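The ``ProcessTransferConfig`` and ``client_kwargs`` options described above can be combined on a single downloader. The following is an illustrative sketch only; the bucket, key, file names, and the chosen limits are placeholders and are not taken from this change:

.. code:: python

    from s3transfer.processpool import (
        ProcessPoolDownloader,
        ProcessTransferConfig,
    )

    MB = 1024 * 1024
    config = ProcessTransferConfig(
        multipart_threshold=16 * MB,  # ranged downloads for objects >= 16 MB
        multipart_chunksize=8 * MB,   # each ranged GetObject covers 8 MB
        max_request_processes=20,     # at most 20 worker processes
    )
    downloader = ProcessPoolDownloader(
        client_kwargs={'region_name': 'us-west-2'},  # forwarded to create_client()
        config=config,
    )
    try:
        future = downloader.download_file('mybucket', 'mykey', 'myfile')
        future.result()  # blocks until the download finishes; raises on failure
    finally:
        downloader.shutdown()  # waits for all submitted downloads to complete

Note that ``result()`` is called before ``shutdown()``, matching the constraint documented above that results may only be polled while the downloader is still running.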
- -""" -import collections -import contextlib -import logging -import multiprocessing -import signal -import threading -from copy import deepcopy - -import botocore.session -from botocore.config import Config - -from s3transfer.compat import MAXINT, BaseManager -from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, MB, PROCESS_USER_AGENT -from s3transfer.exceptions import CancelledError, RetriesExceededError -from s3transfer.futures import BaseTransferFuture, BaseTransferMeta -from s3transfer.utils import ( - S3_RETRYABLE_DOWNLOAD_ERRORS, - CallArgs, - OSUtils, - calculate_num_parts, - calculate_range_parameter, -) - -logger = logging.getLogger(__name__) - -SHUTDOWN_SIGNAL = 'SHUTDOWN' - -# The DownloadFileRequest tuple is submitted from the ProcessPoolDownloader -# to the GetObjectSubmitter in order for the submitter to begin submitting -# GetObjectJobs to the GetObjectWorkers. -DownloadFileRequest = collections.namedtuple( - 'DownloadFileRequest', - [ - 'transfer_id', # The unique id for the transfer - 'bucket', # The bucket to download the object from - 'key', # The key to download the object from - 'filename', # The user-requested download location - 'extra_args', # Extra arguments to provide to client calls - 'expected_size', # The user-provided expected size of the download - ], -) - -# The GetObjectJob tuple is submitted from the GetObjectSubmitter -# to the GetObjectWorkers to download the file or parts of the file. -GetObjectJob = collections.namedtuple( - 'GetObjectJob', - [ - 'transfer_id', # The unique id for the transfer - 'bucket', # The bucket to download the object from - 'key', # The key to download the object from - 'temp_filename', # The temporary file to write the content to via - # completed GetObject calls. - 'extra_args', # Extra arguments to provide to the GetObject call - 'offset', # The offset to write the content for the temp file. - 'filename', # The user-requested download location. The worker - # of final GetObjectJob will move the file located at - # temp_filename to the location of filename. - ], -) - - -@contextlib.contextmanager -def ignore_ctrl_c(): - original_handler = _add_ignore_handler_for_interrupts() - yield - signal.signal(signal.SIGINT, original_handler) - - -def _add_ignore_handler_for_interrupts(): - # Windows is unable to pickle signal.signal directly so it needs to - # be wrapped in a function defined at the module level - return signal.signal(signal.SIGINT, signal.SIG_IGN) - - -class ProcessTransferConfig: - def __init__( - self, - multipart_threshold=8 * MB, - multipart_chunksize=8 * MB, - max_request_processes=10, - ): - """Configuration for the ProcessPoolDownloader - - :param multipart_threshold: The threshold for which ranged downloads - occur. - - :param multipart_chunksize: The chunk size of each ranged download. - - :param max_request_processes: The maximum number of processes that - will be making S3 API transfer-related requests at a time. - """ - self.multipart_threshold = multipart_threshold - self.multipart_chunksize = multipart_chunksize - self.max_request_processes = max_request_processes - - -class ProcessPoolDownloader: - def __init__(self, client_kwargs=None, config=None): - """Downloads S3 objects using process pools - - :type client_kwargs: dict - :param client_kwargs: The keyword arguments to provide when - instantiating S3 clients. The arguments must match the keyword - arguments provided to the - `botocore.session.Session.create_client()` method. 
- - :type config: ProcessTransferConfig - :param config: Configuration for the downloader - """ - if client_kwargs is None: - client_kwargs = {} - self._client_factory = ClientFactory(client_kwargs) - - self._transfer_config = config - if config is None: - self._transfer_config = ProcessTransferConfig() - - self._download_request_queue = multiprocessing.Queue(1000) - self._worker_queue = multiprocessing.Queue(1000) - self._osutil = OSUtils() - - self._started = False - self._start_lock = threading.Lock() - - # These below are initialized in the start() method - self._manager = None - self._transfer_monitor = None - self._submitter = None - self._workers = [] - - def download_file( - self, bucket, key, filename, extra_args=None, expected_size=None - ): - """Downloads the object's contents to a file - - :type bucket: str - :param bucket: The name of the bucket to download from - - :type key: str - :param key: The name of the key to download from - - :type filename: str - :param filename: The name of a file to download to. - - :type extra_args: dict - :param extra_args: Extra arguments that may be passed to the - client operation - - :type expected_size: int - :param expected_size: The expected size in bytes of the download. If - provided, the downloader will not call HeadObject to determine the - object's size and use the provided value instead. The size is - needed to determine whether to do a multipart download. - - :rtype: s3transfer.futures.TransferFuture - :returns: Transfer future representing the download - """ - self._start_if_needed() - if extra_args is None: - extra_args = {} - self._validate_all_known_args(extra_args) - transfer_id = self._transfer_monitor.notify_new_transfer() - download_file_request = DownloadFileRequest( - transfer_id=transfer_id, - bucket=bucket, - key=key, - filename=filename, - extra_args=extra_args, - expected_size=expected_size, - ) - logger.debug( - 'Submitting download file request: %s.', download_file_request - ) - self._download_request_queue.put(download_file_request) - call_args = CallArgs( - bucket=bucket, - key=key, - filename=filename, - extra_args=extra_args, - expected_size=expected_size, - ) - future = self._get_transfer_future(transfer_id, call_args) - return future - - def shutdown(self): - """Shutdown the downloader - - It will wait till all downloads are complete before returning. 
- """ - self._shutdown_if_needed() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, *args): - if isinstance(exc_value, KeyboardInterrupt): - if self._transfer_monitor is not None: - self._transfer_monitor.notify_cancel_all_in_progress() - self.shutdown() - - def _start_if_needed(self): - with self._start_lock: - if not self._started: - self._start() - - def _start(self): - self._start_transfer_monitor_manager() - self._start_submitter() - self._start_get_object_workers() - self._started = True - - def _validate_all_known_args(self, provided): - for kwarg in provided: - if kwarg not in ALLOWED_DOWNLOAD_ARGS: - download_args = ', '.join(ALLOWED_DOWNLOAD_ARGS) - raise ValueError( - f"Invalid extra_args key '{kwarg}', " - f"must be one of: {download_args}" - ) - - def _get_transfer_future(self, transfer_id, call_args): - meta = ProcessPoolTransferMeta( - call_args=call_args, transfer_id=transfer_id - ) - future = ProcessPoolTransferFuture( - monitor=self._transfer_monitor, meta=meta - ) - return future - - def _start_transfer_monitor_manager(self): - logger.debug('Starting the TransferMonitorManager.') - self._manager = TransferMonitorManager() - # We do not want Ctrl-C's to cause the manager to shutdown immediately - # as worker processes will still need to communicate with it when they - # are shutting down. So instead we ignore Ctrl-C and let the manager - # be explicitly shutdown when shutting down the downloader. - self._manager.start(_add_ignore_handler_for_interrupts) - self._transfer_monitor = self._manager.TransferMonitor() - - def _start_submitter(self): - logger.debug('Starting the GetObjectSubmitter.') - self._submitter = GetObjectSubmitter( - transfer_config=self._transfer_config, - client_factory=self._client_factory, - transfer_monitor=self._transfer_monitor, - osutil=self._osutil, - download_request_queue=self._download_request_queue, - worker_queue=self._worker_queue, - ) - self._submitter.start() - - def _start_get_object_workers(self): - logger.debug( - 'Starting %s GetObjectWorkers.', - self._transfer_config.max_request_processes, - ) - for _ in range(self._transfer_config.max_request_processes): - worker = GetObjectWorker( - queue=self._worker_queue, - client_factory=self._client_factory, - transfer_monitor=self._transfer_monitor, - osutil=self._osutil, - ) - worker.start() - self._workers.append(worker) - - def _shutdown_if_needed(self): - with self._start_lock: - if self._started: - self._shutdown() - - def _shutdown(self): - self._shutdown_submitter() - self._shutdown_get_object_workers() - self._shutdown_transfer_monitor_manager() - self._started = False - - def _shutdown_transfer_monitor_manager(self): - logger.debug('Shutting down the TransferMonitorManager.') - self._manager.shutdown() - - def _shutdown_submitter(self): - logger.debug('Shutting down the GetObjectSubmitter.') - self._download_request_queue.put(SHUTDOWN_SIGNAL) - self._submitter.join() - - def _shutdown_get_object_workers(self): - logger.debug('Shutting down the GetObjectWorkers.') - for _ in self._workers: - self._worker_queue.put(SHUTDOWN_SIGNAL) - for worker in self._workers: - worker.join() - - -class ProcessPoolTransferFuture(BaseTransferFuture): - def __init__(self, monitor, meta): - """The future associated to a submitted process pool transfer request - - :type monitor: TransferMonitor - :param monitor: The monitor associated to the process pool downloader - - :type meta: ProcessPoolTransferMeta - :param meta: The metadata associated to the request. 
This object - is visible to the requester. - """ - self._monitor = monitor - self._meta = meta - - @property - def meta(self): - return self._meta - - def done(self): - return self._monitor.is_done(self._meta.transfer_id) - - def result(self): - try: - return self._monitor.poll_for_result(self._meta.transfer_id) - except KeyboardInterrupt: - # For the multiprocessing Manager, a thread is given a single - # connection to reuse in communicating between the thread in the - # main process and the Manager's process. If a Ctrl-C happens when - # polling for the result, it will make the main thread stop trying - # to receive from the connection, but the Manager process will not - # know that the main process has stopped trying to receive and - # will not close the connection. As a result if another message is - # sent to the Manager process, the listener in the Manager - # processes will not process the new message as it is still trying - # trying to process the previous message (that was Ctrl-C'd) and - # thus cause the thread in the main process to hang on its send. - # The only way around this is to create a new connection and send - # messages from that new connection instead. - self._monitor._connect() - self.cancel() - raise - - def cancel(self): - self._monitor.notify_exception( - self._meta.transfer_id, CancelledError() - ) - - -class ProcessPoolTransferMeta(BaseTransferMeta): - """Holds metadata about the ProcessPoolTransferFuture""" - - def __init__(self, transfer_id, call_args): - self._transfer_id = transfer_id - self._call_args = call_args - self._user_context = {} - - @property - def call_args(self): - return self._call_args - - @property - def transfer_id(self): - return self._transfer_id - - @property - def user_context(self): - return self._user_context - - -class ClientFactory: - def __init__(self, client_kwargs=None): - """Creates S3 clients for processes - - Botocore sessions and clients are not pickleable so they cannot be - inherited across Process boundaries. Instead, they must be instantiated - once a process is running. - """ - self._client_kwargs = client_kwargs - if self._client_kwargs is None: - self._client_kwargs = {} - - client_config = deepcopy(self._client_kwargs.get('config', Config())) - if not client_config.user_agent_extra: - client_config.user_agent_extra = PROCESS_USER_AGENT - else: - client_config.user_agent_extra += " " + PROCESS_USER_AGENT - self._client_kwargs['config'] = client_config - - def create_client(self): - """Create a botocore S3 client""" - return botocore.session.Session().create_client( - 's3', **self._client_kwargs - ) - - -class TransferMonitor: - def __init__(self): - """Monitors transfers for cross-process communication - - Notifications can be sent to the monitor and information can be - retrieved from the monitor for a particular transfer. This abstraction - is ran in a ``multiprocessing.managers.BaseManager`` in order to be - shared across processes. - """ - # TODO: Add logic that removes the TransferState if the transfer is - # marked as done and the reference to the future is no longer being - # held onto. Without this logic, this dictionary will continue to - # grow in size with no limit. 
- self._transfer_states = {} - self._id_count = 0 - self._init_lock = threading.Lock() - - def notify_new_transfer(self): - with self._init_lock: - transfer_id = self._id_count - self._transfer_states[transfer_id] = TransferState() - self._id_count += 1 - return transfer_id - - def is_done(self, transfer_id): - """Determine a particular transfer is complete - - :param transfer_id: Unique identifier for the transfer - :return: True, if done. False, otherwise. - """ - return self._transfer_states[transfer_id].done - - def notify_done(self, transfer_id): - """Notify a particular transfer is complete - - :param transfer_id: Unique identifier for the transfer - """ - self._transfer_states[transfer_id].set_done() - - def poll_for_result(self, transfer_id): - """Poll for the result of a transfer - - :param transfer_id: Unique identifier for the transfer - :return: If the transfer succeeded, it will return the result. If the - transfer failed, it will raise the exception associated to the - failure. - """ - self._transfer_states[transfer_id].wait_till_done() - exception = self._transfer_states[transfer_id].exception - if exception: - raise exception - return None - - def notify_exception(self, transfer_id, exception): - """Notify an exception was encountered for a transfer - - :param transfer_id: Unique identifier for the transfer - :param exception: The exception encountered for that transfer - """ - # TODO: Not all exceptions are pickleable so if we are running - # this in a multiprocessing.BaseManager we will want to - # make sure to update this signature to ensure pickleability of the - # arguments or have the ProxyObject do the serialization. - self._transfer_states[transfer_id].exception = exception - - def notify_cancel_all_in_progress(self): - for transfer_state in self._transfer_states.values(): - if not transfer_state.done: - transfer_state.exception = CancelledError() - - def get_exception(self, transfer_id): - """Retrieve the exception encountered for the transfer - - :param transfer_id: Unique identifier for the transfer - :return: The exception encountered for that transfer. Otherwise - if there were no exceptions, returns None. - """ - return self._transfer_states[transfer_id].exception - - def notify_expected_jobs_to_complete(self, transfer_id, num_jobs): - """Notify the amount of jobs expected for a transfer - - :param transfer_id: Unique identifier for the transfer - :param num_jobs: The number of jobs to complete the transfer - """ - self._transfer_states[transfer_id].jobs_to_complete = num_jobs - - def notify_job_complete(self, transfer_id): - """Notify that a single job is completed for a transfer - - :param transfer_id: Unique identifier for the transfer - :return: The number of jobs remaining to complete the transfer - """ - return self._transfer_states[transfer_id].decrement_jobs_to_complete() - - -class TransferState: - """Represents the current state of an individual transfer""" - - # NOTE: Ideally the TransferState object would be used directly by the - # various different abstractions in the ProcessPoolDownloader and remove - # the need for the TransferMonitor. However, it would then impose the - # constraint that two hops are required to make or get any changes in the - # state of a transfer across processes: one hop to get a proxy object for - # the TransferState and then a second hop to communicate calling the - # specific TransferState method. 
- def __init__(self): - self._exception = None - self._done_event = threading.Event() - self._job_lock = threading.Lock() - self._jobs_to_complete = 0 - - @property - def done(self): - return self._done_event.is_set() - - def set_done(self): - self._done_event.set() - - def wait_till_done(self): - self._done_event.wait(MAXINT) - - @property - def exception(self): - return self._exception - - @exception.setter - def exception(self, val): - self._exception = val - - @property - def jobs_to_complete(self): - return self._jobs_to_complete - - @jobs_to_complete.setter - def jobs_to_complete(self, val): - self._jobs_to_complete = val - - def decrement_jobs_to_complete(self): - with self._job_lock: - self._jobs_to_complete -= 1 - return self._jobs_to_complete - - -class TransferMonitorManager(BaseManager): - pass - - -TransferMonitorManager.register('TransferMonitor', TransferMonitor) - - -class BaseS3TransferProcess(multiprocessing.Process): - def __init__(self, client_factory): - super().__init__() - self._client_factory = client_factory - self._client = None - - def run(self): - # Clients are not pickleable so their instantiation cannot happen - # in the __init__ for processes that are created under the - # spawn method. - self._client = self._client_factory.create_client() - with ignore_ctrl_c(): - # By default these processes are ran as child processes to the - # main process. Any Ctrl-c encountered in the main process is - # propagated to the child process and interrupt it at any time. - # To avoid any potentially bad states caused from an interrupt - # (i.e. a transfer failing to notify its done or making the - # communication protocol become out of sync with the - # TransferMonitor), we ignore all Ctrl-C's and allow the main - # process to notify these child processes when to stop processing - # jobs. - self._do_run() - - def _do_run(self): - raise NotImplementedError('_do_run()') - - -class GetObjectSubmitter(BaseS3TransferProcess): - def __init__( - self, - transfer_config, - client_factory, - transfer_monitor, - osutil, - download_request_queue, - worker_queue, - ): - """Submit GetObjectJobs to fulfill a download file request - - :param transfer_config: Configuration for transfers. - :param client_factory: ClientFactory for creating S3 clients. - :param transfer_monitor: Monitor for notifying and retrieving state - of transfer. - :param osutil: OSUtils object to use for os-related behavior when - performing the transfer. - :param download_request_queue: Queue to retrieve download file - requests. - :param worker_queue: Queue to submit GetObjectJobs for workers - to perform. 
- """ - super().__init__(client_factory) - self._transfer_config = transfer_config - self._transfer_monitor = transfer_monitor - self._osutil = osutil - self._download_request_queue = download_request_queue - self._worker_queue = worker_queue - - def _do_run(self): - while True: - download_file_request = self._download_request_queue.get() - if download_file_request == SHUTDOWN_SIGNAL: - logger.debug('Submitter shutdown signal received.') - return - try: - self._submit_get_object_jobs(download_file_request) - except Exception as e: - logger.debug( - 'Exception caught when submitting jobs for ' - 'download file request %s: %s', - download_file_request, - e, - exc_info=True, - ) - self._transfer_monitor.notify_exception( - download_file_request.transfer_id, e - ) - self._transfer_monitor.notify_done( - download_file_request.transfer_id - ) - - def _submit_get_object_jobs(self, download_file_request): - size = self._get_size(download_file_request) - temp_filename = self._allocate_temp_file(download_file_request, size) - if size < self._transfer_config.multipart_threshold: - self._submit_single_get_object_job( - download_file_request, temp_filename - ) - else: - self._submit_ranged_get_object_jobs( - download_file_request, temp_filename, size - ) - - def _get_size(self, download_file_request): - expected_size = download_file_request.expected_size - if expected_size is None: - expected_size = self._client.head_object( - Bucket=download_file_request.bucket, - Key=download_file_request.key, - **download_file_request.extra_args, - )['ContentLength'] - return expected_size - - def _allocate_temp_file(self, download_file_request, size): - temp_filename = self._osutil.get_temp_filename( - download_file_request.filename - ) - self._osutil.allocate(temp_filename, size) - return temp_filename - - def _submit_single_get_object_job( - self, download_file_request, temp_filename - ): - self._notify_jobs_to_complete(download_file_request.transfer_id, 1) - self._submit_get_object_job( - transfer_id=download_file_request.transfer_id, - bucket=download_file_request.bucket, - key=download_file_request.key, - temp_filename=temp_filename, - offset=0, - extra_args=download_file_request.extra_args, - filename=download_file_request.filename, - ) - - def _submit_ranged_get_object_jobs( - self, download_file_request, temp_filename, size - ): - part_size = self._transfer_config.multipart_chunksize - num_parts = calculate_num_parts(size, part_size) - self._notify_jobs_to_complete( - download_file_request.transfer_id, num_parts - ) - for i in range(num_parts): - offset = i * part_size - range_parameter = calculate_range_parameter( - part_size, i, num_parts - ) - get_object_kwargs = {'Range': range_parameter} - get_object_kwargs.update(download_file_request.extra_args) - self._submit_get_object_job( - transfer_id=download_file_request.transfer_id, - bucket=download_file_request.bucket, - key=download_file_request.key, - temp_filename=temp_filename, - offset=offset, - extra_args=get_object_kwargs, - filename=download_file_request.filename, - ) - - def _submit_get_object_job(self, **get_object_job_kwargs): - self._worker_queue.put(GetObjectJob(**get_object_job_kwargs)) - - def _notify_jobs_to_complete(self, transfer_id, jobs_to_complete): - logger.debug( - 'Notifying %s job(s) to complete for transfer_id %s.', - jobs_to_complete, - transfer_id, - ) - self._transfer_monitor.notify_expected_jobs_to_complete( - transfer_id, jobs_to_complete - ) - - -class GetObjectWorker(BaseS3TransferProcess): - # TODO: It may make sense to 
expose these class variables as configuration - # options if users want to tweak them. - _MAX_ATTEMPTS = 5 - _IO_CHUNKSIZE = 2 * MB - - def __init__(self, queue, client_factory, transfer_monitor, osutil): - """Fulfills GetObjectJobs - - Downloads the S3 object, writes it to the specified file, and - renames the file to its final location if it completes the final - job for a particular transfer. - - :param queue: Queue for retrieving GetObjectJob's - :param client_factory: ClientFactory for creating S3 clients - :param transfer_monitor: Monitor for notifying - :param osutil: OSUtils object to use for os-related behavior when - performing the transfer. - """ - super().__init__(client_factory) - self._queue = queue - self._client_factory = client_factory - self._transfer_monitor = transfer_monitor - self._osutil = osutil - - def _do_run(self): - while True: - job = self._queue.get() - if job == SHUTDOWN_SIGNAL: - logger.debug('Worker shutdown signal received.') - return - if not self._transfer_monitor.get_exception(job.transfer_id): - self._run_get_object_job(job) - else: - logger.debug( - 'Skipping get object job %s because there was a previous ' - 'exception.', - job, - ) - remaining = self._transfer_monitor.notify_job_complete( - job.transfer_id - ) - logger.debug( - '%s jobs remaining for transfer_id %s.', - remaining, - job.transfer_id, - ) - if not remaining: - self._finalize_download( - job.transfer_id, job.temp_filename, job.filename - ) - - def _run_get_object_job(self, job): - try: - self._do_get_object( - bucket=job.bucket, - key=job.key, - temp_filename=job.temp_filename, - extra_args=job.extra_args, - offset=job.offset, - ) - except Exception as e: - logger.debug( - 'Exception caught when downloading object for ' - 'get object job %s: %s', - job, - e, - exc_info=True, - ) - self._transfer_monitor.notify_exception(job.transfer_id, e) - - def _do_get_object(self, bucket, key, extra_args, temp_filename, offset): - last_exception = None - for i in range(self._MAX_ATTEMPTS): - try: - response = self._client.get_object( - Bucket=bucket, Key=key, **extra_args - ) - self._write_to_file(temp_filename, offset, response['Body']) - return - except S3_RETRYABLE_DOWNLOAD_ERRORS as e: - logger.debug( - 'Retrying exception caught (%s), ' - 'retrying request, (attempt %s / %s)', - e, - i + 1, - self._MAX_ATTEMPTS, - exc_info=True, - ) - last_exception = e - raise RetriesExceededError(last_exception) - - def _write_to_file(self, filename, offset, body): - with open(filename, 'rb+') as f: - f.seek(offset) - chunks = iter(lambda: body.read(self._IO_CHUNKSIZE), b'') - for chunk in chunks: - f.write(chunk) - - def _finalize_download(self, transfer_id, temp_filename, filename): - if self._transfer_monitor.get_exception(transfer_id): - self._osutil.remove_file(temp_filename) - else: - self._do_file_rename(transfer_id, temp_filename, filename) - self._transfer_monitor.notify_done(transfer_id) - - def _do_file_rename(self, transfer_id, temp_filename, filename): - try: - self._osutil.rename_file(temp_filename, filename) - except Exception as e: - self._transfer_monitor.notify_exception(transfer_id, e) - self._osutil.remove_file(temp_filename) +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Speeds up S3 throughput by using processes + +Getting Started +=============== + +The :class:`ProcessPoolDownloader` can be used to download a single file by +calling :meth:`ProcessPoolDownloader.download_file`: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + with ProcessPoolDownloader() as downloader: + downloader.download_file('mybucket', 'mykey', 'myfile') + + +This snippet downloads the S3 object located in the bucket ``mybucket`` at the +key ``mykey`` to the local file ``myfile``. Any errors encountered during the +transfer are not propagated. To determine if a transfer succeeded or +failed, use the `Futures`_ interface. + + +The :class:`ProcessPoolDownloader` can be used to download multiple files as +well: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + with ProcessPoolDownloader() as downloader: + downloader.download_file('mybucket', 'mykey', 'myfile') + downloader.download_file('mybucket', 'myotherkey', 'myotherfile') + + +When running this snippet, the downloading of ``mykey`` and ``myotherkey`` +happen in parallel. The first ``download_file`` call does not block the +second ``download_file`` call. The snippet blocks when exiting +the context manager and blocks until both downloads are complete. + +Alternatively, the ``ProcessPoolDownloader`` can be instantiated +and explicitly be shutdown using :meth:`ProcessPoolDownloader.shutdown`: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + downloader = ProcessPoolDownloader() + downloader.download_file('mybucket', 'mykey', 'myfile') + downloader.download_file('mybucket', 'myotherkey', 'myotherfile') + downloader.shutdown() + + +For this code snippet, the call to ``shutdown`` blocks until both +downloads are complete. + + +Additional Parameters +===================== + +Additional parameters can be provided to the ``download_file`` method: + +* ``extra_args``: A dictionary containing any additional client arguments + to include in the + `GetObject <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object>`_ + API request. For example: + + .. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + with ProcessPoolDownloader() as downloader: + downloader.download_file( + 'mybucket', 'mykey', 'myfile', + extra_args={'VersionId': 'myversion'}) + + +* ``expected_size``: By default, the downloader will make a HeadObject + call to determine the size of the object. To opt-out of this additional + API call, you can provide the size of the object in bytes: + + .. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + MB = 1024 * 1024 + with ProcessPoolDownloader() as downloader: + downloader.download_file( + 'mybucket', 'mykey', 'myfile', expected_size=2 * MB) + + +Futures +======= + +When ``download_file`` is called, it immediately returns a +:class:`ProcessPoolTransferFuture`. The future can be used to poll the state +of a particular transfer. To get the result of the download, +call :meth:`ProcessPoolTransferFuture.result`. The method blocks +until the transfer completes, whether it succeeds or fails. For example: + +.. 
code:: python + + from s3transfer.processpool import ProcessPoolDownloader + + with ProcessPoolDownloader() as downloader: + future = downloader.download_file('mybucket', 'mykey', 'myfile') + print(future.result()) + + +If the download succeeds, the future returns ``None``: + +.. code:: python + + None + + +If the download fails, the exception causing the failure is raised. For +example, if ``mykey`` did not exist, the following error would be raised + + +.. code:: python + + botocore.exceptions.ClientError: An error occurred (404) when calling the HeadObject operation: Not Found + + +.. note:: + + :meth:`ProcessPoolTransferFuture.result` can only be called while the + ``ProcessPoolDownloader`` is running (e.g. before calling ``shutdown`` or + inside the context manager). + + +Process Pool Configuration +========================== + +By default, the downloader has the following configuration options: + +* ``multipart_threshold``: The threshold size for performing ranged downloads + in bytes. By default, ranged downloads happen for S3 objects that are + greater than or equal to 8 MB in size. + +* ``multipart_chunksize``: The size of each ranged download in bytes. By + default, the size of each ranged download is 8 MB. + +* ``max_request_processes``: The maximum number of processes used to download + S3 objects. By default, the maximum is 10 processes. + + +To change the default configuration, use the :class:`ProcessTransferConfig`: + +.. code:: python + + from s3transfer.processpool import ProcessPoolDownloader + from s3transfer.processpool import ProcessTransferConfig + + config = ProcessTransferConfig( + multipart_threshold=64 * 1024 * 1024, # 64 MB + max_request_processes=50 + ) + downloader = ProcessPoolDownloader(config=config) + + +Client Configuration +==================== + +The process pool downloader creates ``botocore`` clients on your behalf. In +order to affect how the client is created, pass the keyword arguments +that would have been used in the :meth:`botocore.Session.create_client` call: + +.. code:: python + + + from s3transfer.processpool import ProcessPoolDownloader + from s3transfer.processpool import ProcessTransferConfig + + downloader = ProcessPoolDownloader( + client_kwargs={'region_name': 'us-west-2'}) + + +This snippet ensures that all clients created by the ``ProcessPoolDownloader`` +are using ``us-west-2`` as their region. + +""" +import collections +import contextlib +import logging +import multiprocessing +import signal +import threading +from copy import deepcopy + +import botocore.session +from botocore.config import Config + +from s3transfer.compat import MAXINT, BaseManager +from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, MB, PROCESS_USER_AGENT +from s3transfer.exceptions import CancelledError, RetriesExceededError +from s3transfer.futures import BaseTransferFuture, BaseTransferMeta +from s3transfer.utils import ( + S3_RETRYABLE_DOWNLOAD_ERRORS, + CallArgs, + OSUtils, + calculate_num_parts, + calculate_range_parameter, +) + +logger = logging.getLogger(__name__) + +SHUTDOWN_SIGNAL = 'SHUTDOWN' + +# The DownloadFileRequest tuple is submitted from the ProcessPoolDownloader +# to the GetObjectSubmitter in order for the submitter to begin submitting +# GetObjectJobs to the GetObjectWorkers. 
+DownloadFileRequest = collections.namedtuple( + 'DownloadFileRequest', + [ + 'transfer_id', # The unique id for the transfer + 'bucket', # The bucket to download the object from + 'key', # The key to download the object from + 'filename', # The user-requested download location + 'extra_args', # Extra arguments to provide to client calls + 'expected_size', # The user-provided expected size of the download + ], +) + +# The GetObjectJob tuple is submitted from the GetObjectSubmitter +# to the GetObjectWorkers to download the file or parts of the file. +GetObjectJob = collections.namedtuple( + 'GetObjectJob', + [ + 'transfer_id', # The unique id for the transfer + 'bucket', # The bucket to download the object from + 'key', # The key to download the object from + 'temp_filename', # The temporary file to write the content to via + # completed GetObject calls. + 'extra_args', # Extra arguments to provide to the GetObject call + 'offset', # The offset to write the content for the temp file. + 'filename', # The user-requested download location. The worker + # of final GetObjectJob will move the file located at + # temp_filename to the location of filename. + ], +) + + +@contextlib.contextmanager +def ignore_ctrl_c(): + original_handler = _add_ignore_handler_for_interrupts() + yield + signal.signal(signal.SIGINT, original_handler) + + +def _add_ignore_handler_for_interrupts(): + # Windows is unable to pickle signal.signal directly so it needs to + # be wrapped in a function defined at the module level + return signal.signal(signal.SIGINT, signal.SIG_IGN) + + +class ProcessTransferConfig: + def __init__( + self, + multipart_threshold=8 * MB, + multipart_chunksize=8 * MB, + max_request_processes=10, + ): + """Configuration for the ProcessPoolDownloader + + :param multipart_threshold: The threshold for which ranged downloads + occur. + + :param multipart_chunksize: The chunk size of each ranged download. + + :param max_request_processes: The maximum number of processes that + will be making S3 API transfer-related requests at a time. + """ + self.multipart_threshold = multipart_threshold + self.multipart_chunksize = multipart_chunksize + self.max_request_processes = max_request_processes + + +class ProcessPoolDownloader: + def __init__(self, client_kwargs=None, config=None): + """Downloads S3 objects using process pools + + :type client_kwargs: dict + :param client_kwargs: The keyword arguments to provide when + instantiating S3 clients. The arguments must match the keyword + arguments provided to the + `botocore.session.Session.create_client()` method. 
+ + :type config: ProcessTransferConfig + :param config: Configuration for the downloader + """ + if client_kwargs is None: + client_kwargs = {} + self._client_factory = ClientFactory(client_kwargs) + + self._transfer_config = config + if config is None: + self._transfer_config = ProcessTransferConfig() + + self._download_request_queue = multiprocessing.Queue(1000) + self._worker_queue = multiprocessing.Queue(1000) + self._osutil = OSUtils() + + self._started = False + self._start_lock = threading.Lock() + + # These below are initialized in the start() method + self._manager = None + self._transfer_monitor = None + self._submitter = None + self._workers = [] + + def download_file( + self, bucket, key, filename, extra_args=None, expected_size=None + ): + """Downloads the object's contents to a file + + :type bucket: str + :param bucket: The name of the bucket to download from + + :type key: str + :param key: The name of the key to download from + + :type filename: str + :param filename: The name of a file to download to. + + :type extra_args: dict + :param extra_args: Extra arguments that may be passed to the + client operation + + :type expected_size: int + :param expected_size: The expected size in bytes of the download. If + provided, the downloader will not call HeadObject to determine the + object's size and use the provided value instead. The size is + needed to determine whether to do a multipart download. + + :rtype: s3transfer.futures.TransferFuture + :returns: Transfer future representing the download + """ + self._start_if_needed() + if extra_args is None: + extra_args = {} + self._validate_all_known_args(extra_args) + transfer_id = self._transfer_monitor.notify_new_transfer() + download_file_request = DownloadFileRequest( + transfer_id=transfer_id, + bucket=bucket, + key=key, + filename=filename, + extra_args=extra_args, + expected_size=expected_size, + ) + logger.debug( + 'Submitting download file request: %s.', download_file_request + ) + self._download_request_queue.put(download_file_request) + call_args = CallArgs( + bucket=bucket, + key=key, + filename=filename, + extra_args=extra_args, + expected_size=expected_size, + ) + future = self._get_transfer_future(transfer_id, call_args) + return future + + def shutdown(self): + """Shutdown the downloader + + It will wait till all downloads are complete before returning. 
+ """ + self._shutdown_if_needed() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, *args): + if isinstance(exc_value, KeyboardInterrupt): + if self._transfer_monitor is not None: + self._transfer_monitor.notify_cancel_all_in_progress() + self.shutdown() + + def _start_if_needed(self): + with self._start_lock: + if not self._started: + self._start() + + def _start(self): + self._start_transfer_monitor_manager() + self._start_submitter() + self._start_get_object_workers() + self._started = True + + def _validate_all_known_args(self, provided): + for kwarg in provided: + if kwarg not in ALLOWED_DOWNLOAD_ARGS: + download_args = ', '.join(ALLOWED_DOWNLOAD_ARGS) + raise ValueError( + f"Invalid extra_args key '{kwarg}', " + f"must be one of: {download_args}" + ) + + def _get_transfer_future(self, transfer_id, call_args): + meta = ProcessPoolTransferMeta( + call_args=call_args, transfer_id=transfer_id + ) + future = ProcessPoolTransferFuture( + monitor=self._transfer_monitor, meta=meta + ) + return future + + def _start_transfer_monitor_manager(self): + logger.debug('Starting the TransferMonitorManager.') + self._manager = TransferMonitorManager() + # We do not want Ctrl-C's to cause the manager to shutdown immediately + # as worker processes will still need to communicate with it when they + # are shutting down. So instead we ignore Ctrl-C and let the manager + # be explicitly shutdown when shutting down the downloader. + self._manager.start(_add_ignore_handler_for_interrupts) + self._transfer_monitor = self._manager.TransferMonitor() + + def _start_submitter(self): + logger.debug('Starting the GetObjectSubmitter.') + self._submitter = GetObjectSubmitter( + transfer_config=self._transfer_config, + client_factory=self._client_factory, + transfer_monitor=self._transfer_monitor, + osutil=self._osutil, + download_request_queue=self._download_request_queue, + worker_queue=self._worker_queue, + ) + self._submitter.start() + + def _start_get_object_workers(self): + logger.debug( + 'Starting %s GetObjectWorkers.', + self._transfer_config.max_request_processes, + ) + for _ in range(self._transfer_config.max_request_processes): + worker = GetObjectWorker( + queue=self._worker_queue, + client_factory=self._client_factory, + transfer_monitor=self._transfer_monitor, + osutil=self._osutil, + ) + worker.start() + self._workers.append(worker) + + def _shutdown_if_needed(self): + with self._start_lock: + if self._started: + self._shutdown() + + def _shutdown(self): + self._shutdown_submitter() + self._shutdown_get_object_workers() + self._shutdown_transfer_monitor_manager() + self._started = False + + def _shutdown_transfer_monitor_manager(self): + logger.debug('Shutting down the TransferMonitorManager.') + self._manager.shutdown() + + def _shutdown_submitter(self): + logger.debug('Shutting down the GetObjectSubmitter.') + self._download_request_queue.put(SHUTDOWN_SIGNAL) + self._submitter.join() + + def _shutdown_get_object_workers(self): + logger.debug('Shutting down the GetObjectWorkers.') + for _ in self._workers: + self._worker_queue.put(SHUTDOWN_SIGNAL) + for worker in self._workers: + worker.join() + + +class ProcessPoolTransferFuture(BaseTransferFuture): + def __init__(self, monitor, meta): + """The future associated to a submitted process pool transfer request + + :type monitor: TransferMonitor + :param monitor: The monitor associated to the process pool downloader + + :type meta: ProcessPoolTransferMeta + :param meta: The metadata associated to the request. 
This object + is visible to the requester. + """ + self._monitor = monitor + self._meta = meta + + @property + def meta(self): + return self._meta + + def done(self): + return self._monitor.is_done(self._meta.transfer_id) + + def result(self): + try: + return self._monitor.poll_for_result(self._meta.transfer_id) + except KeyboardInterrupt: + # For the multiprocessing Manager, a thread is given a single + # connection to reuse in communicating between the thread in the + # main process and the Manager's process. If a Ctrl-C happens when + # polling for the result, it will make the main thread stop trying + # to receive from the connection, but the Manager process will not + # know that the main process has stopped trying to receive and + # will not close the connection. As a result if another message is + # sent to the Manager process, the listener in the Manager + # processes will not process the new message as it is still trying + # trying to process the previous message (that was Ctrl-C'd) and + # thus cause the thread in the main process to hang on its send. + # The only way around this is to create a new connection and send + # messages from that new connection instead. + self._monitor._connect() + self.cancel() + raise + + def cancel(self): + self._monitor.notify_exception( + self._meta.transfer_id, CancelledError() + ) + + +class ProcessPoolTransferMeta(BaseTransferMeta): + """Holds metadata about the ProcessPoolTransferFuture""" + + def __init__(self, transfer_id, call_args): + self._transfer_id = transfer_id + self._call_args = call_args + self._user_context = {} + + @property + def call_args(self): + return self._call_args + + @property + def transfer_id(self): + return self._transfer_id + + @property + def user_context(self): + return self._user_context + + +class ClientFactory: + def __init__(self, client_kwargs=None): + """Creates S3 clients for processes + + Botocore sessions and clients are not pickleable so they cannot be + inherited across Process boundaries. Instead, they must be instantiated + once a process is running. + """ + self._client_kwargs = client_kwargs + if self._client_kwargs is None: + self._client_kwargs = {} + + client_config = deepcopy(self._client_kwargs.get('config', Config())) + if not client_config.user_agent_extra: + client_config.user_agent_extra = PROCESS_USER_AGENT + else: + client_config.user_agent_extra += " " + PROCESS_USER_AGENT + self._client_kwargs['config'] = client_config + + def create_client(self): + """Create a botocore S3 client""" + return botocore.session.Session().create_client( + 's3', **self._client_kwargs + ) + + +class TransferMonitor: + def __init__(self): + """Monitors transfers for cross-process communication + + Notifications can be sent to the monitor and information can be + retrieved from the monitor for a particular transfer. This abstraction + is ran in a ``multiprocessing.managers.BaseManager`` in order to be + shared across processes. + """ + # TODO: Add logic that removes the TransferState if the transfer is + # marked as done and the reference to the future is no longer being + # held onto. Without this logic, this dictionary will continue to + # grow in size with no limit. 
+ self._transfer_states = {} + self._id_count = 0 + self._init_lock = threading.Lock() + + def notify_new_transfer(self): + with self._init_lock: + transfer_id = self._id_count + self._transfer_states[transfer_id] = TransferState() + self._id_count += 1 + return transfer_id + + def is_done(self, transfer_id): + """Determine a particular transfer is complete + + :param transfer_id: Unique identifier for the transfer + :return: True, if done. False, otherwise. + """ + return self._transfer_states[transfer_id].done + + def notify_done(self, transfer_id): + """Notify a particular transfer is complete + + :param transfer_id: Unique identifier for the transfer + """ + self._transfer_states[transfer_id].set_done() + + def poll_for_result(self, transfer_id): + """Poll for the result of a transfer + + :param transfer_id: Unique identifier for the transfer + :return: If the transfer succeeded, it will return the result. If the + transfer failed, it will raise the exception associated to the + failure. + """ + self._transfer_states[transfer_id].wait_till_done() + exception = self._transfer_states[transfer_id].exception + if exception: + raise exception + return None + + def notify_exception(self, transfer_id, exception): + """Notify an exception was encountered for a transfer + + :param transfer_id: Unique identifier for the transfer + :param exception: The exception encountered for that transfer + """ + # TODO: Not all exceptions are pickleable so if we are running + # this in a multiprocessing.BaseManager we will want to + # make sure to update this signature to ensure pickleability of the + # arguments or have the ProxyObject do the serialization. + self._transfer_states[transfer_id].exception = exception + + def notify_cancel_all_in_progress(self): + for transfer_state in self._transfer_states.values(): + if not transfer_state.done: + transfer_state.exception = CancelledError() + + def get_exception(self, transfer_id): + """Retrieve the exception encountered for the transfer + + :param transfer_id: Unique identifier for the transfer + :return: The exception encountered for that transfer. Otherwise + if there were no exceptions, returns None. + """ + return self._transfer_states[transfer_id].exception + + def notify_expected_jobs_to_complete(self, transfer_id, num_jobs): + """Notify the amount of jobs expected for a transfer + + :param transfer_id: Unique identifier for the transfer + :param num_jobs: The number of jobs to complete the transfer + """ + self._transfer_states[transfer_id].jobs_to_complete = num_jobs + + def notify_job_complete(self, transfer_id): + """Notify that a single job is completed for a transfer + + :param transfer_id: Unique identifier for the transfer + :return: The number of jobs remaining to complete the transfer + """ + return self._transfer_states[transfer_id].decrement_jobs_to_complete() + + +class TransferState: + """Represents the current state of an individual transfer""" + + # NOTE: Ideally the TransferState object would be used directly by the + # various different abstractions in the ProcessPoolDownloader and remove + # the need for the TransferMonitor. However, it would then impose the + # constraint that two hops are required to make or get any changes in the + # state of a transfer across processes: one hop to get a proxy object for + # the TransferState and then a second hop to communicate calling the + # specific TransferState method. 
+ def __init__(self): + self._exception = None + self._done_event = threading.Event() + self._job_lock = threading.Lock() + self._jobs_to_complete = 0 + + @property + def done(self): + return self._done_event.is_set() + + def set_done(self): + self._done_event.set() + + def wait_till_done(self): + self._done_event.wait(MAXINT) + + @property + def exception(self): + return self._exception + + @exception.setter + def exception(self, val): + self._exception = val + + @property + def jobs_to_complete(self): + return self._jobs_to_complete + + @jobs_to_complete.setter + def jobs_to_complete(self, val): + self._jobs_to_complete = val + + def decrement_jobs_to_complete(self): + with self._job_lock: + self._jobs_to_complete -= 1 + return self._jobs_to_complete + + +class TransferMonitorManager(BaseManager): + pass + + +TransferMonitorManager.register('TransferMonitor', TransferMonitor) + + +class BaseS3TransferProcess(multiprocessing.Process): + def __init__(self, client_factory): + super().__init__() + self._client_factory = client_factory + self._client = None + + def run(self): + # Clients are not pickleable so their instantiation cannot happen + # in the __init__ for processes that are created under the + # spawn method. + self._client = self._client_factory.create_client() + with ignore_ctrl_c(): + # By default these processes are ran as child processes to the + # main process. Any Ctrl-c encountered in the main process is + # propagated to the child process and interrupt it at any time. + # To avoid any potentially bad states caused from an interrupt + # (i.e. a transfer failing to notify its done or making the + # communication protocol become out of sync with the + # TransferMonitor), we ignore all Ctrl-C's and allow the main + # process to notify these child processes when to stop processing + # jobs. + self._do_run() + + def _do_run(self): + raise NotImplementedError('_do_run()') + + +class GetObjectSubmitter(BaseS3TransferProcess): + def __init__( + self, + transfer_config, + client_factory, + transfer_monitor, + osutil, + download_request_queue, + worker_queue, + ): + """Submit GetObjectJobs to fulfill a download file request + + :param transfer_config: Configuration for transfers. + :param client_factory: ClientFactory for creating S3 clients. + :param transfer_monitor: Monitor for notifying and retrieving state + of transfer. + :param osutil: OSUtils object to use for os-related behavior when + performing the transfer. + :param download_request_queue: Queue to retrieve download file + requests. + :param worker_queue: Queue to submit GetObjectJobs for workers + to perform. 
+ """ + super().__init__(client_factory) + self._transfer_config = transfer_config + self._transfer_monitor = transfer_monitor + self._osutil = osutil + self._download_request_queue = download_request_queue + self._worker_queue = worker_queue + + def _do_run(self): + while True: + download_file_request = self._download_request_queue.get() + if download_file_request == SHUTDOWN_SIGNAL: + logger.debug('Submitter shutdown signal received.') + return + try: + self._submit_get_object_jobs(download_file_request) + except Exception as e: + logger.debug( + 'Exception caught when submitting jobs for ' + 'download file request %s: %s', + download_file_request, + e, + exc_info=True, + ) + self._transfer_monitor.notify_exception( + download_file_request.transfer_id, e + ) + self._transfer_monitor.notify_done( + download_file_request.transfer_id + ) + + def _submit_get_object_jobs(self, download_file_request): + size = self._get_size(download_file_request) + temp_filename = self._allocate_temp_file(download_file_request, size) + if size < self._transfer_config.multipart_threshold: + self._submit_single_get_object_job( + download_file_request, temp_filename + ) + else: + self._submit_ranged_get_object_jobs( + download_file_request, temp_filename, size + ) + + def _get_size(self, download_file_request): + expected_size = download_file_request.expected_size + if expected_size is None: + expected_size = self._client.head_object( + Bucket=download_file_request.bucket, + Key=download_file_request.key, + **download_file_request.extra_args, + )['ContentLength'] + return expected_size + + def _allocate_temp_file(self, download_file_request, size): + temp_filename = self._osutil.get_temp_filename( + download_file_request.filename + ) + self._osutil.allocate(temp_filename, size) + return temp_filename + + def _submit_single_get_object_job( + self, download_file_request, temp_filename + ): + self._notify_jobs_to_complete(download_file_request.transfer_id, 1) + self._submit_get_object_job( + transfer_id=download_file_request.transfer_id, + bucket=download_file_request.bucket, + key=download_file_request.key, + temp_filename=temp_filename, + offset=0, + extra_args=download_file_request.extra_args, + filename=download_file_request.filename, + ) + + def _submit_ranged_get_object_jobs( + self, download_file_request, temp_filename, size + ): + part_size = self._transfer_config.multipart_chunksize + num_parts = calculate_num_parts(size, part_size) + self._notify_jobs_to_complete( + download_file_request.transfer_id, num_parts + ) + for i in range(num_parts): + offset = i * part_size + range_parameter = calculate_range_parameter( + part_size, i, num_parts + ) + get_object_kwargs = {'Range': range_parameter} + get_object_kwargs.update(download_file_request.extra_args) + self._submit_get_object_job( + transfer_id=download_file_request.transfer_id, + bucket=download_file_request.bucket, + key=download_file_request.key, + temp_filename=temp_filename, + offset=offset, + extra_args=get_object_kwargs, + filename=download_file_request.filename, + ) + + def _submit_get_object_job(self, **get_object_job_kwargs): + self._worker_queue.put(GetObjectJob(**get_object_job_kwargs)) + + def _notify_jobs_to_complete(self, transfer_id, jobs_to_complete): + logger.debug( + 'Notifying %s job(s) to complete for transfer_id %s.', + jobs_to_complete, + transfer_id, + ) + self._transfer_monitor.notify_expected_jobs_to_complete( + transfer_id, jobs_to_complete + ) + + +class GetObjectWorker(BaseS3TransferProcess): + # TODO: It may make sense to 
expose these class variables as configuration + # options if users want to tweak them. + _MAX_ATTEMPTS = 5 + _IO_CHUNKSIZE = 2 * MB + + def __init__(self, queue, client_factory, transfer_monitor, osutil): + """Fulfills GetObjectJobs + + Downloads the S3 object, writes it to the specified file, and + renames the file to its final location if it completes the final + job for a particular transfer. + + :param queue: Queue for retrieving GetObjectJob's + :param client_factory: ClientFactory for creating S3 clients + :param transfer_monitor: Monitor for notifying + :param osutil: OSUtils object to use for os-related behavior when + performing the transfer. + """ + super().__init__(client_factory) + self._queue = queue + self._client_factory = client_factory + self._transfer_monitor = transfer_monitor + self._osutil = osutil + + def _do_run(self): + while True: + job = self._queue.get() + if job == SHUTDOWN_SIGNAL: + logger.debug('Worker shutdown signal received.') + return + if not self._transfer_monitor.get_exception(job.transfer_id): + self._run_get_object_job(job) + else: + logger.debug( + 'Skipping get object job %s because there was a previous ' + 'exception.', + job, + ) + remaining = self._transfer_monitor.notify_job_complete( + job.transfer_id + ) + logger.debug( + '%s jobs remaining for transfer_id %s.', + remaining, + job.transfer_id, + ) + if not remaining: + self._finalize_download( + job.transfer_id, job.temp_filename, job.filename + ) + + def _run_get_object_job(self, job): + try: + self._do_get_object( + bucket=job.bucket, + key=job.key, + temp_filename=job.temp_filename, + extra_args=job.extra_args, + offset=job.offset, + ) + except Exception as e: + logger.debug( + 'Exception caught when downloading object for ' + 'get object job %s: %s', + job, + e, + exc_info=True, + ) + self._transfer_monitor.notify_exception(job.transfer_id, e) + + def _do_get_object(self, bucket, key, extra_args, temp_filename, offset): + last_exception = None + for i in range(self._MAX_ATTEMPTS): + try: + response = self._client.get_object( + Bucket=bucket, Key=key, **extra_args + ) + self._write_to_file(temp_filename, offset, response['Body']) + return + except S3_RETRYABLE_DOWNLOAD_ERRORS as e: + logger.debug( + 'Retrying exception caught (%s), ' + 'retrying request, (attempt %s / %s)', + e, + i + 1, + self._MAX_ATTEMPTS, + exc_info=True, + ) + last_exception = e + raise RetriesExceededError(last_exception) + + def _write_to_file(self, filename, offset, body): + with open(filename, 'rb+') as f: + f.seek(offset) + chunks = iter(lambda: body.read(self._IO_CHUNKSIZE), b'') + for chunk in chunks: + f.write(chunk) + + def _finalize_download(self, transfer_id, temp_filename, filename): + if self._transfer_monitor.get_exception(transfer_id): + self._osutil.remove_file(temp_filename) + else: + self._do_file_rename(transfer_id, temp_filename, filename) + self._transfer_monitor.notify_done(transfer_id) + + def _do_file_rename(self, transfer_id, temp_filename, filename): + try: + self._osutil.rename_file(temp_filename, filename) + except Exception as e: + self._transfer_monitor.notify_exception(transfer_id, e) + self._osutil.remove_file(temp_filename) diff --git a/contrib/python/s3transfer/py3/s3transfer/subscribers.py b/contrib/python/s3transfer/py3/s3transfer/subscribers.py index 27e8fe0c4b..cf0dbaa0d7 100644 --- a/contrib/python/s3transfer/py3/s3transfer/subscribers.py +++ b/contrib/python/s3transfer/py3/s3transfer/subscribers.py @@ -14,34 +14,34 @@ from s3transfer.compat import accepts_kwargs from 
s3transfer.exceptions import InvalidSubscriberMethodError -class BaseSubscriber: +class BaseSubscriber: """The base subscriber class It is recommended that all subscriber implementations subclass and then override the subscription methods (i.e. on_{subsribe_type}() methods). """ - VALID_SUBSCRIBER_TYPES = ['queued', 'progress', 'done'] - + VALID_SUBSCRIBER_TYPES = ['queued', 'progress', 'done'] + def __new__(cls, *args, **kwargs): cls._validate_subscriber_methods() - return super().__new__(cls) + return super().__new__(cls) @classmethod def _validate_subscriber_methods(cls): for subscriber_type in cls.VALID_SUBSCRIBER_TYPES: subscriber_method = getattr(cls, 'on_' + subscriber_type) - if not callable(subscriber_method): + if not callable(subscriber_method): raise InvalidSubscriberMethodError( - 'Subscriber method %s must be callable.' - % subscriber_method - ) + 'Subscriber method %s must be callable.' + % subscriber_method + ) if not accepts_kwargs(subscriber_method): raise InvalidSubscriberMethodError( 'Subscriber method %s must accept keyword ' - 'arguments (**kwargs)' % subscriber_method - ) + 'arguments (**kwargs)' % subscriber_method + ) def on_queued(self, future, **kwargs): """Callback to be invoked when transfer request gets queued diff --git a/contrib/python/s3transfer/py3/s3transfer/tasks.py b/contrib/python/s3transfer/py3/s3transfer/tasks.py index 7153ac547c..1bad981264 100644 --- a/contrib/python/s3transfer/py3/s3transfer/tasks.py +++ b/contrib/python/s3transfer/py3/s3transfer/tasks.py @@ -18,21 +18,21 @@ from s3transfer.utils import get_callbacks logger = logging.getLogger(__name__) -class Task: +class Task: """A task associated to a TransferFuture request This is a base class for other classes to subclass from. All subclassed classes must implement the main() method. """ - - def __init__( - self, - transfer_coordinator, - main_kwargs=None, - pending_main_kwargs=None, - done_callbacks=None, - is_final=False, - ): + + def __init__( + self, + transfer_coordinator, + main_kwargs=None, + pending_main_kwargs=None, + done_callbacks=None, + is_final=False, + ): """ :type transfer_coordinator: s3transfer.futures.TransferCoordinator :param transfer_coordinator: The context associated to the @@ -85,22 +85,22 @@ class Task: # These are the general main_kwarg parameters that we want to # display in the repr. params_to_display = [ - 'bucket', - 'key', - 'part_number', - 'final_filename', - 'transfer_future', - 'offset', - 'extra_args', + 'bucket', + 'key', + 'part_number', + 'final_filename', + 'transfer_future', + 'offset', + 'extra_args', ] main_kwargs_to_display = self._get_kwargs_with_params_to_include( - self._main_kwargs, params_to_display - ) - return '{}(transfer_id={}, {})'.format( - self.__class__.__name__, - self._transfer_coordinator.transfer_id, - main_kwargs_to_display, - ) + self._main_kwargs, params_to_display + ) + return '{}(transfer_id={}, {})'.format( + self.__class__.__name__, + self._transfer_coordinator.transfer_id, + main_kwargs_to_display, + ) @property def transfer_id(self): @@ -130,7 +130,7 @@ class Task: # Gather up all of the main keyword arguments for main(). # This includes the immediately provided main_kwargs and # the values for pending_main_kwargs that source from the return - # values from the task's dependent futures. + # values from the task's dependent futures. 
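# For example (as the multipart upload submission later in this diff does),
# a dependent task is constructed with futures in pending_main_kwargs:
#
#     pending_main_kwargs={'upload_id': create_multipart_future}
#
# By this point those dependent futures have completed, so each result()
# value is passed to _main() under the same keyword.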
kwargs = self._get_all_main_kwargs() # If the task is not done (really only if some other related # task to the TransferFuture had failed) then execute the task's @@ -154,10 +154,10 @@ class Task: # if they are going to make the logs hard to follow. params_to_exclude = ['data'] kwargs_to_display = self._get_kwargs_with_params_to_exclude( - kwargs, params_to_exclude - ) + kwargs, params_to_exclude + ) # Log what is about to be executed. - logger.debug(f"Executing task {self} with kwargs {kwargs_to_display}") + logger.debug(f"Executing task {self} with kwargs {kwargs_to_display}") return_value = self._main(**kwargs) # If the task is the final task, then set the TransferFuture's @@ -187,7 +187,7 @@ class Task: # If the pending main keyword arg is a list then extend the list. if isinstance(future, list): futures_to_wait_on.extend(future) - # If the pending main keyword arg is a future append it to the list. + # If the pending main keyword arg is a future append it to the list. else: futures_to_wait_on.append(future) # Now wait for all of the futures to complete. @@ -198,20 +198,20 @@ class Task: # # concurrent.futures.wait() is not used instead because of this # reported issue: https://bugs.python.org/issue20319. - # The issue would occasionally cause multipart uploads to hang + # The issue would occasionally cause multipart uploads to hang # when wait() was called. With this approach, it avoids the # concurrency bug by removing any association with concurrent.futures # implementation of waiters. logger.debug( - '%s about to wait for the following futures %s', self, futures - ) + '%s about to wait for the following futures %s', self, futures + ) for future in futures: try: logger.debug('%s about to wait for %s', self, future) future.result() except Exception: # result() can also produce exceptions. We want to ignore - # these to be deferred to error handling down the road. + # these to be deferred to error handling down the road. pass logger.debug('%s done waiting for dependent futures', self) @@ -243,7 +243,7 @@ class SubmissionTask(Task): Submission tasks are the top-level task used to submit a series of tasks to execute a particular transfer. """ - + def _main(self, transfer_future, **kwargs): """ :type transfer_future: s3transfer.futures.TransferFuture @@ -274,7 +274,7 @@ class SubmissionTask(Task): # the first place so we need to account accordingly. # # Note that BaseException is caught, instead of Exception, because - # for some implementations of executors, specifically the serial + # for some implementations of executors, specifically the serial # implementation, the SubmissionTask is directly exposed to # KeyboardInterupts and so needs to cleanup and signal done # for those as well. @@ -283,7 +283,7 @@ class SubmissionTask(Task): self._log_and_set_exception(e) # Wait for all possibly associated futures that may have spawned - # from this submission task have finished before we announce the + # from this submission task have finished before we announce the # transfer done. 
self._wait_for_all_submitted_futures_to_complete() @@ -292,7 +292,7 @@ class SubmissionTask(Task): self._transfer_coordinator.announce_done() def _submit(self, transfer_future, **kwargs): - """The submission method to be implemented + """The submission method to be implemented :type transfer_future: s3transfer.futures.TransferFuture :param transfer_future: The transfer future associated with the @@ -317,9 +317,9 @@ class SubmissionTask(Task): self._wait_until_all_complete(submitted_futures) # However, more futures may have been submitted as we waited so # we need to check again for any more associated futures. - possibly_more_submitted_futures = ( + possibly_more_submitted_futures = ( self._transfer_coordinator.associated_futures - ) + ) # If the current list of submitted futures is equal to the # the list of associated futures for when after the wait completes, # we can ensure no more futures were submitted in waiting on @@ -333,36 +333,36 @@ class SubmissionTask(Task): class CreateMultipartUploadTask(Task): """Task to initiate a multipart upload""" - + def _main(self, client, bucket, key, extra_args): """ :param client: The client to use when calling CreateMultipartUpload :param bucket: The name of the bucket to upload to :param key: The name of the key to upload to :param extra_args: A dictionary of any extra arguments that may be - used in the initialization. + used in the initialization. :returns: The upload id of the multipart upload """ # Create the multipart upload. response = client.create_multipart_upload( - Bucket=bucket, Key=key, **extra_args - ) + Bucket=bucket, Key=key, **extra_args + ) upload_id = response['UploadId'] # Add a cleanup if the multipart upload fails at any point. self._transfer_coordinator.add_failure_cleanup( - client.abort_multipart_upload, - Bucket=bucket, - Key=key, - UploadId=upload_id, + client.abort_multipart_upload, + Bucket=bucket, + Key=key, + UploadId=upload_id, ) return upload_id class CompleteMultipartUploadTask(Task): """Task to complete a multipart upload""" - + def _main(self, client, bucket, key, upload_id, parts, extra_args): """ :param client: The client to use when calling CompleteMultipartUpload @@ -379,9 +379,9 @@ class CompleteMultipartUploadTask(Task): used in completing the multipart transfer. """ client.complete_multipart_upload( - Bucket=bucket, - Key=key, - UploadId=upload_id, + Bucket=bucket, + Key=key, + UploadId=upload_id, MultipartUpload={'Parts': parts}, - **extra_args, - ) + **extra_args, + ) diff --git a/contrib/python/s3transfer/py3/s3transfer/upload.py b/contrib/python/s3transfer/py3/s3transfer/upload.py index 1407d14361..31ade051d7 100644 --- a/contrib/python/s3transfer/py3/s3transfer/upload.py +++ b/contrib/python/s3transfer/py3/s3transfer/upload.py @@ -11,25 +11,25 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
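For orientation, the multipart task classes above boil down to the following S3 call sequence. This is only a hand-rolled sketch with made-up bucket, key and file names (and assumes credentials and a default region are configured); the transfer manager itself drives the equivalent calls through the task classes and executors shown in this diff.

import botocore.session

client = botocore.session.get_session().create_client('s3')
bucket, key = 'example-bucket', 'example-key'

# CreateMultipartUploadTask: start the upload and remember the UploadId.
upload_id = client.create_multipart_upload(Bucket=bucket, Key=key)['UploadId']
parts = []
try:
    # UploadPartTask (one per chunk): upload each part and record its ETag.
    with open('/tmp/example.bin', 'rb') as f:
        part_number = 1
        while True:
            chunk = f.read(8 * 1024 * 1024)
            if not chunk:
                break
            response = client.upload_part(
                Bucket=bucket, Key=key, UploadId=upload_id,
                PartNumber=part_number, Body=chunk,
            )
            parts.append({'ETag': response['ETag'], 'PartNumber': part_number})
            part_number += 1
    # CompleteMultipartUploadTask: stitch the recorded parts together.
    client.complete_multipart_upload(
        Bucket=bucket, Key=key, UploadId=upload_id,
        MultipartUpload={'Parts': parts},
    )
except Exception:
    # Mirrors the failure cleanup registered by CreateMultipartUploadTask.
    client.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id)
    raise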
import math -from io import BytesIO +from io import BytesIO -from s3transfer.compat import readable, seekable +from s3transfer.compat import readable, seekable from s3transfer.futures import IN_MEMORY_UPLOAD_TAG -from s3transfer.tasks import ( - CompleteMultipartUploadTask, - CreateMultipartUploadTask, - SubmissionTask, - Task, -) -from s3transfer.utils import ( - ChunksizeAdjuster, - DeferredOpenFile, - get_callbacks, - get_filtered_dict, -) - - -class AggregatedProgressCallback: +from s3transfer.tasks import ( + CompleteMultipartUploadTask, + CreateMultipartUploadTask, + SubmissionTask, + Task, +) +from s3transfer.utils import ( + ChunksizeAdjuster, + DeferredOpenFile, + get_callbacks, + get_filtered_dict, +) + + +class AggregatedProgressCallback: def __init__(self, callbacks, threshold=1024 * 256): """Aggregates progress updates for every provided progress callback @@ -62,10 +62,10 @@ class AggregatedProgressCallback: self._bytes_seen = 0 -class InterruptReader: +class InterruptReader: """Wrapper that can interrupt reading using an error - It uses a transfer coordinator to propagate an error if it notices + It uses a transfer coordinator to propagate an error if it notices that a read is being made while the file is being read from. :type fileobj: file-like obj @@ -75,7 +75,7 @@ class InterruptReader: :param transfer_coordinator: The transfer coordinator to use if the reader needs to be interrupted. """ - + def __init__(self, fileobj, transfer_coordinator): self._fileobj = fileobj self._transfer_coordinator = transfer_coordinator @@ -90,8 +90,8 @@ class InterruptReader: raise self._transfer_coordinator.exception return self._fileobj.read(amount) - def seek(self, where, whence=0): - self._fileobj.seek(where, whence) + def seek(self, where, whence=0): + self._fileobj.seek(where, whence) def tell(self): return self._fileobj.tell() @@ -106,7 +106,7 @@ class InterruptReader: self.close() -class UploadInputManager: +class UploadInputManager: """Base manager class for handling various types of files for uploads This class is typically used for the UploadSubmissionTask class to help @@ -121,7 +121,7 @@ class UploadInputManager: that may be accepted. All implementations must subclass and override public methods from this class. """ - + def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None): self._osutil = osutil self._transfer_coordinator = transfer_coordinator @@ -152,7 +152,7 @@ class UploadInputManager: memory. False if the manager will not directly store the body in memory. """ - raise NotImplementedError('must implement store_body_in_memory()') + raise NotImplementedError('must implement store_body_in_memory()') def provide_transfer_size(self, transfer_future): """Provides the transfer size of an upload @@ -173,7 +173,7 @@ class UploadInputManager: :rtype: boolean :returns: True, if the upload should be multipart based on - configuration and size. False, otherwise. + configuration and size. False, otherwise. 
""" raise NotImplementedError('must implement requires_multipart_upload()') @@ -212,8 +212,8 @@ class UploadInputManager: fileobj = InterruptReader(fileobj, self._transfer_coordinator) if self._bandwidth_limiter: fileobj = self._bandwidth_limiter.get_bandwith_limited_stream( - fileobj, self._transfer_coordinator, enabled=False - ) + fileobj, self._transfer_coordinator, enabled=False + ) return fileobj def _get_progress_callbacks(self, transfer_future): @@ -231,18 +231,18 @@ class UploadInputManager: class UploadFilenameInputManager(UploadInputManager): """Upload utility for filenames""" - + @classmethod def is_compatible(cls, upload_source): - return isinstance(upload_source, str) + return isinstance(upload_source, str) def stores_body_in_memory(self, operation_name): return False def provide_transfer_size(self, transfer_future): transfer_future.meta.provide_transfer_size( - self._osutil.get_file_size(transfer_future.meta.call_args.fileobj) - ) + self._osutil.get_file_size(transfer_future.meta.call_args.fileobj) + ) def requires_multipart_upload(self, transfer_future, config): return transfer_future.meta.size >= config.multipart_threshold @@ -250,8 +250,8 @@ class UploadFilenameInputManager(UploadInputManager): def get_put_object_body(self, transfer_future): # Get a file-like object for the given input fileobj, full_size = self._get_put_object_fileobj_with_full_size( - transfer_future - ) + transfer_future + ) # Wrap fileobj with interrupt reader that will quickly cancel # uploads if needed instead of having to wait for the socket @@ -264,12 +264,12 @@ class UploadFilenameInputManager(UploadInputManager): # Return the file-like object wrapped into a ReadFileChunk to get # progress. return self._osutil.open_file_chunk_reader_from_fileobj( - fileobj=fileobj, - chunk_size=size, - full_file_size=full_size, - callbacks=callbacks, - close_callbacks=close_callbacks, - ) + fileobj=fileobj, + chunk_size=size, + full_file_size=full_size, + callbacks=callbacks, + close_callbacks=close_callbacks, + ) def yield_upload_part_bodies(self, transfer_future, chunksize): full_file_size = transfer_future.meta.size @@ -281,11 +281,11 @@ class UploadFilenameInputManager(UploadInputManager): # Get a file-like object for that part and the size of the full # file size for the associated file-like object for that part. fileobj, full_size = self._get_upload_part_fileobj_with_full_size( - transfer_future.meta.call_args.fileobj, - start_byte=start_byte, - part_size=chunksize, - full_file_size=full_file_size, - ) + transfer_future.meta.call_args.fileobj, + start_byte=start_byte, + part_size=chunksize, + full_file_size=full_file_size, + ) # Wrap fileobj with interrupt reader that will quickly cancel # uploads if needed instead of having to wait for the socket @@ -294,18 +294,18 @@ class UploadFilenameInputManager(UploadInputManager): # Wrap the file-like object into a ReadFileChunk to get progress. 
read_file_chunk = self._osutil.open_file_chunk_reader_from_fileobj( - fileobj=fileobj, - chunk_size=chunksize, - full_file_size=full_size, - callbacks=callbacks, - close_callbacks=close_callbacks, - ) + fileobj=fileobj, + chunk_size=chunksize, + full_file_size=full_size, + callbacks=callbacks, + close_callbacks=close_callbacks, + ) yield part_number, read_file_chunk def _get_deferred_open_file(self, fileobj, start_byte): fileobj = DeferredOpenFile( - fileobj, start_byte, open_function=self._osutil.open - ) + fileobj, start_byte, open_function=self._osutil.open + ) return fileobj def _get_put_object_fileobj_with_full_size(self, transfer_future): @@ -319,12 +319,12 @@ class UploadFilenameInputManager(UploadInputManager): return self._get_deferred_open_file(fileobj, start_byte), full_size def _get_num_parts(self, transfer_future, part_size): - return int(math.ceil(transfer_future.meta.size / float(part_size))) + return int(math.ceil(transfer_future.meta.size / float(part_size))) class UploadSeekableInputManager(UploadFilenameInputManager): """Upload utility for an open file object""" - + @classmethod def is_compatible(cls, upload_source): return readable(upload_source) and seekable(upload_source) @@ -345,8 +345,8 @@ class UploadSeekableInputManager(UploadFilenameInputManager): end_position = fileobj.tell() fileobj.seek(start_position) transfer_future.meta.provide_transfer_size( - end_position - start_position - ) + end_position - start_position + ) def _get_upload_part_fileobj_with_full_size(self, fileobj, **kwargs): # Note: It is unfortunate that in order to do a multithreaded @@ -354,14 +354,14 @@ class UploadSeekableInputManager(UploadFilenameInputManager): # since there is not really a mechanism in python (i.e. os.dup # points to the same OS filehandle which causes concurrency # issues). So instead we need to read from the fileobj and - # chunk the data out to separate file-like objects in memory. + # chunk the data out to separate file-like objects in memory. data = fileobj.read(kwargs['part_size']) # We return the length of the data instead of the full_file_size - # because we partitioned the data into separate BytesIO objects + # because we partitioned the data into separate BytesIO objects # meaning the BytesIO object has no knowledge of its start position # relative the input source nor access to the rest of the input # source. So we must treat it as its own standalone file. - return BytesIO(data), len(data) + return BytesIO(data), len(data) def _get_put_object_fileobj_with_full_size(self, transfer_future): fileobj = transfer_future.meta.call_args.fileobj @@ -373,9 +373,9 @@ class UploadSeekableInputManager(UploadFilenameInputManager): class UploadNonSeekableInputManager(UploadInputManager): """Upload utility for a file-like object that cannot seek.""" - + def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None): - super().__init__(osutil, transfer_coordinator, bandwidth_limiter) + super().__init__(osutil, transfer_coordinator, bandwidth_limiter) self._initial_data = b'' @classmethod @@ -413,8 +413,8 @@ class UploadNonSeekableInputManager(UploadInputManager): fileobj = transfer_future.meta.call_args.fileobj body = self._wrap_data( - self._initial_data + fileobj.read(), callbacks, close_callbacks - ) + self._initial_data + fileobj.read(), callbacks, close_callbacks + ) # Zero out the stored data so we don't have additional copies # hanging around in memory. 
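The three input managers above (filename, seekable, and non-seekable) are chosen per upload source via their is_compatible() classmethods; UploadSubmissionTask below walks them in order and takes the first match. A rough, self-contained illustration of that resolution follows; the file path and in-memory source are made-up examples.

import io

from s3transfer.upload import (
    UploadFilenameInputManager,
    UploadNonSeekableInputManager,
    UploadSeekableInputManager,
)

# Same ordering as UploadSubmissionTask._get_upload_input_manager_cls().
resolver_chain = [
    UploadFilenameInputManager,     # plain string paths
    UploadSeekableInputManager,     # open, seekable file objects
    UploadNonSeekableInputManager,  # read-only streams (pipes, sockets, ...)
]


def resolve(upload_source):
    for manager_cls in resolver_chain:
        if manager_cls.is_compatible(upload_source):
            return manager_cls
    raise RuntimeError(
        'Input {} of type: {} is not supported.'.format(
            upload_source, type(upload_source)
        )
    )


print(resolve('/tmp/example.txt').__name__)    # UploadFilenameInputManager
print(resolve(io.BytesIO(b'data')).__name__)   # UploadSeekableInputManager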
@@ -434,8 +434,8 @@ class UploadNonSeekableInputManager(UploadInputManager): if not part_content: break part_object = self._wrap_data( - part_content, callbacks, close_callbacks - ) + part_content, callbacks, close_callbacks + ) # Zero out part_content to avoid hanging on to additional data. part_content = None @@ -462,7 +462,7 @@ class UploadNonSeekableInputManager(UploadInputManager): if len(self._initial_data) == 0: return fileobj.read(amount) - # If the requested number of bytes is less than the amount of + # If the requested number of bytes is less than the amount of # initial data, pull entirely from initial data. if amount <= len(self._initial_data): data = self._initial_data[:amount] @@ -499,14 +499,14 @@ class UploadNonSeekableInputManager(UploadInputManager): :return: Fully wrapped data. """ - fileobj = self._wrap_fileobj(BytesIO(data)) + fileobj = self._wrap_fileobj(BytesIO(data)) return self._osutil.open_file_chunk_reader_from_fileobj( - fileobj=fileobj, - chunk_size=len(data), - full_file_size=len(data), - callbacks=callbacks, - close_callbacks=close_callbacks, - ) + fileobj=fileobj, + chunk_size=len(data), + full_file_size=len(data), + callbacks=callbacks, + close_callbacks=close_callbacks, + ) class UploadSubmissionTask(SubmissionTask): @@ -517,13 +517,13 @@ class UploadSubmissionTask(SubmissionTask): 'SSECustomerAlgorithm', 'SSECustomerKeyMD5', 'RequestPayer', - 'ExpectedBucketOwner', + 'ExpectedBucketOwner', ] - COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner'] + COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner'] def _get_upload_input_manager_cls(self, transfer_future): - """Retrieves a class for managing input for an upload based on file type + """Retrieves a class for managing input for an upload based on file type :type transfer_future: s3transfer.futures.TransferFuture :param transfer_future: The transfer future for the request @@ -535,7 +535,7 @@ class UploadSubmissionTask(SubmissionTask): upload_manager_resolver_chain = [ UploadFilenameInputManager, UploadSeekableInputManager, - UploadNonSeekableInputManager, + UploadNonSeekableInputManager, ] fileobj = transfer_future.meta.call_args.fileobj @@ -543,20 +543,20 @@ class UploadSubmissionTask(SubmissionTask): if upload_manager_cls.is_compatible(fileobj): return upload_manager_cls raise RuntimeError( - 'Input {} of type: {} is not supported.'.format( - fileobj, type(fileobj) - ) - ) - - def _submit( - self, - client, - config, - osutil, - request_executor, - transfer_future, - bandwidth_limiter=None, - ): + 'Input {} of type: {} is not supported.'.format( + fileobj, type(fileobj) + ) + ) + + def _submit( + self, + client, + config, + osutil, + request_executor, + transfer_future, + bandwidth_limiter=None, + ): """ :param client: The client associated with the transfer manager @@ -576,8 +576,8 @@ class UploadSubmissionTask(SubmissionTask): transfer request that tasks are being submitted for """ upload_input_manager = self._get_upload_input_manager_cls( - transfer_future - )(osutil, self._transfer_coordinator, bandwidth_limiter) + transfer_future + )(osutil, self._transfer_coordinator, bandwidth_limiter) # Determine the size if it was not provided if transfer_future.meta.size is None: @@ -585,41 +585,41 @@ class UploadSubmissionTask(SubmissionTask): # Do a multipart upload if needed, otherwise do a regular put object. 
if not upload_input_manager.requires_multipart_upload( - transfer_future, config - ): + transfer_future, config + ): self._submit_upload_request( - client, - config, - osutil, - request_executor, - transfer_future, - upload_input_manager, - ) + client, + config, + osutil, + request_executor, + transfer_future, + upload_input_manager, + ) else: self._submit_multipart_request( - client, - config, - osutil, - request_executor, - transfer_future, - upload_input_manager, - ) - - def _submit_upload_request( - self, - client, - config, - osutil, - request_executor, - transfer_future, - upload_input_manager, - ): + client, + config, + osutil, + request_executor, + transfer_future, + upload_input_manager, + ) + + def _submit_upload_request( + self, + client, + config, + osutil, + request_executor, + transfer_future, + upload_input_manager, + ): call_args = transfer_future.meta.call_args # Get any tags that need to be associated to the put object task put_object_tag = self._get_upload_task_tag( - upload_input_manager, 'put_object' - ) + upload_input_manager, 'put_object' + ) # Submit the request of a single upload. self._transfer_coordinator.submit( @@ -629,26 +629,26 @@ class UploadSubmissionTask(SubmissionTask): main_kwargs={ 'client': client, 'fileobj': upload_input_manager.get_put_object_body( - transfer_future - ), + transfer_future + ), 'bucket': call_args.bucket, 'key': call_args.key, - 'extra_args': call_args.extra_args, + 'extra_args': call_args.extra_args, }, - is_final=True, + is_final=True, ), - tag=put_object_tag, + tag=put_object_tag, ) - def _submit_multipart_request( - self, - client, - config, - osutil, - request_executor, - transfer_future, - upload_input_manager, - ): + def _submit_multipart_request( + self, + client, + config, + osutil, + request_executor, + transfer_future, + upload_input_manager, + ): call_args = transfer_future.meta.call_args # Submit the request to create a multipart upload. @@ -661,8 +661,8 @@ class UploadSubmissionTask(SubmissionTask): 'bucket': call_args.bucket, 'key': call_args.key, 'extra_args': call_args.extra_args, - }, - ), + }, + ), ) # Submit requests to upload the parts of the file. @@ -672,15 +672,15 @@ class UploadSubmissionTask(SubmissionTask): # Get any tags that need to be associated to the submitted task # for upload the data upload_part_tag = self._get_upload_task_tag( - upload_input_manager, 'upload_part' - ) + upload_input_manager, 'upload_part' + ) size = transfer_future.meta.size adjuster = ChunksizeAdjuster() chunksize = adjuster.adjust_chunksize(config.multipart_chunksize, size) part_iterator = upload_input_manager.yield_upload_part_bodies( - transfer_future, chunksize - ) + transfer_future, chunksize + ) for part_number, fileobj in part_iterator: part_futures.append( @@ -694,19 +694,19 @@ class UploadSubmissionTask(SubmissionTask): 'bucket': call_args.bucket, 'key': call_args.key, 'part_number': part_number, - 'extra_args': extra_part_args, + 'extra_args': extra_part_args, }, pending_main_kwargs={ 'upload_id': create_multipart_future - }, + }, ), - tag=upload_part_tag, + tag=upload_part_tag, ) ) complete_multipart_extra_args = self._extra_complete_multipart_args( - call_args.extra_args - ) + call_args.extra_args + ) # Submit the request to complete the multipart upload. 
self._transfer_coordinator.submit( request_executor, @@ -720,10 +720,10 @@ class UploadSubmissionTask(SubmissionTask): }, pending_main_kwargs={ 'upload_id': create_multipart_future, - 'parts': part_futures, + 'parts': part_futures, }, - is_final=True, - ), + is_final=True, + ), ) def _extra_upload_part_args(self, extra_args): @@ -743,7 +743,7 @@ class UploadSubmissionTask(SubmissionTask): class PutObjectTask(Task): """Task to do a nonmultipart upload""" - + def _main(self, client, fileobj, bucket, key, extra_args): """ :param client: The client to use when calling PutObject @@ -759,10 +759,10 @@ class PutObjectTask(Task): class UploadPartTask(Task): """Task to upload a part in a multipart upload""" - - def _main( - self, client, fileobj, bucket, key, upload_id, part_number, extra_args - ): + + def _main( + self, client, fileobj, bucket, key, upload_id, part_number, extra_args + ): """ :param client: The client to use when calling PutObject :param fileobj: The file to upload. @@ -784,12 +784,12 @@ class UploadPartTask(Task): """ with fileobj as body: response = client.upload_part( - Bucket=bucket, - Key=key, - UploadId=upload_id, - PartNumber=part_number, - Body=body, - **extra_args - ) + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=part_number, + Body=body, + **extra_args + ) etag = response['ETag'] return {'ETag': etag, 'PartNumber': part_number} diff --git a/contrib/python/s3transfer/py3/s3transfer/utils.py b/contrib/python/s3transfer/py3/s3transfer/utils.py index 31ada34c65..ba881c67dd 100644 --- a/contrib/python/s3transfer/py3/s3transfer/utils.py +++ b/contrib/python/s3transfer/py3/s3transfer/utils.py @@ -11,19 +11,19 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import functools -import logging +import logging import math import os -import random -import socket +import random +import socket import stat import string import threading from collections import defaultdict -from botocore.exceptions import IncompleteReadError, ReadTimeoutError - -from s3transfer.compat import SOCKET_ERROR, fallocate, rename_file +from botocore.exceptions import IncompleteReadError, ReadTimeoutError + +from s3transfer.compat import SOCKET_ERROR, fallocate, rename_file MAX_PARTS = 10000 # The maximum file size you can upload via S3 per request. 
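These constants cap multipart transfers at 10,000 parts. The helpers calculate_num_parts() and calculate_range_parameter() in the next hunk turn an object size into the Range headers used for the ranged GetObject jobs submitted by the processpool download code earlier in this diff; a short usage sketch with made-up sizes:

from s3transfer.utils import calculate_num_parts, calculate_range_parameter

size = 22 * 1024 * 1024       # a hypothetical 22 MiB object
part_size = 8 * 1024 * 1024   # 8 MiB parts

num_parts = calculate_num_parts(size, part_size)  # -> 3
for part_index in range(num_parts):
    print(calculate_range_parameter(part_size, part_index, num_parts, total_size=size))
# bytes=0-8388607
# bytes=8388608-16777215
# bytes=16777216-23068671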
@@ -34,39 +34,39 @@ MIN_UPLOAD_CHUNKSIZE = 5 * (1024 ** 2) logger = logging.getLogger(__name__) -S3_RETRYABLE_DOWNLOAD_ERRORS = ( - socket.timeout, - SOCKET_ERROR, - ReadTimeoutError, - IncompleteReadError, -) - - +S3_RETRYABLE_DOWNLOAD_ERRORS = ( + socket.timeout, + SOCKET_ERROR, + ReadTimeoutError, + IncompleteReadError, +) + + def random_file_extension(num_digits=8): return ''.join(random.choice(string.hexdigits) for _ in range(num_digits)) def signal_not_transferring(request, operation_name, **kwargs): - if operation_name in ['PutObject', 'UploadPart'] and hasattr( - request.body, 'signal_not_transferring' - ): + if operation_name in ['PutObject', 'UploadPart'] and hasattr( + request.body, 'signal_not_transferring' + ): request.body.signal_not_transferring() def signal_transferring(request, operation_name, **kwargs): - if operation_name in ['PutObject', 'UploadPart'] and hasattr( - request.body, 'signal_transferring' - ): + if operation_name in ['PutObject', 'UploadPart'] and hasattr( + request.body, 'signal_transferring' + ): request.body.signal_transferring() -def calculate_num_parts(size, part_size): - return int(math.ceil(size / float(part_size))) - - -def calculate_range_parameter( - part_size, part_index, num_parts, total_size=None -): +def calculate_num_parts(size, part_size): + return int(math.ceil(size / float(part_size))) + + +def calculate_range_parameter( + part_size, part_index, num_parts, total_size=None +): """Calculate the range parameter for multipart downloads/copies :type part_size: int @@ -90,7 +90,7 @@ def calculate_range_parameter( end_range = str(total_size - 1) else: end_range = start_range + part_size - 1 - range_param = f'bytes={start_range}-{end_range}' + range_param = f'bytes={start_range}-{end_range}' return range_param @@ -117,7 +117,7 @@ def get_callbacks(transfer_future, callback_type): if hasattr(subscriber, callback_name): callbacks.append( functools.partial( - getattr(subscriber, callback_name), future=transfer_future + getattr(subscriber, callback_name), future=transfer_future ) ) return callbacks @@ -157,7 +157,7 @@ def get_filtered_dict(original_dict, whitelisted_keys): return filtered_dict -class CallArgs: +class CallArgs: def __init__(self, **kwargs): """A class that records call arguments @@ -169,33 +169,33 @@ class CallArgs: setattr(self, arg, value) -class FunctionContainer: +class FunctionContainer: """An object that contains a function and any args or kwargs to call it When called the provided function will be called with provided args and kwargs. """ - + def __init__(self, func, *args, **kwargs): self._func = func self._args = args self._kwargs = kwargs def __repr__(self): - return 'Function: {} with args {} and kwargs {}'.format( - self._func, self._args, self._kwargs - ) + return 'Function: {} with args {} and kwargs {}'.format( + self._func, self._args, self._kwargs + ) def __call__(self): return self._func(*self._args, **self._kwargs) -class CountCallbackInvoker: +class CountCallbackInvoker: """An abstraction to invoke a callback when a shared count reaches zero :param callback: Callback invoke when finalized count reaches zero """ - + def __init__(self, callback): self._lock = threading.Lock() self._callback = callback @@ -222,8 +222,8 @@ class CountCallbackInvoker: with self._lock: if self._count == 0: raise RuntimeError( - 'Counter is at zero. It cannot dip below zero' - ) + 'Counter is at zero. 
It cannot dip below zero' + ) self._count -= 1 if self._is_finalized and self._count == 0: self._callback() @@ -240,33 +240,33 @@ class CountCallbackInvoker: self._callback() -class OSUtils: - _MAX_FILENAME_LEN = 255 - +class OSUtils: + _MAX_FILENAME_LEN = 255 + def get_file_size(self, filename): return os.path.getsize(filename) def open_file_chunk_reader(self, filename, start_byte, size, callbacks): - return ReadFileChunk.from_filename( - filename, start_byte, size, callbacks, enable_callbacks=False - ) - - def open_file_chunk_reader_from_fileobj( - self, - fileobj, - chunk_size, - full_file_size, - callbacks, - close_callbacks=None, - ): + return ReadFileChunk.from_filename( + filename, start_byte, size, callbacks, enable_callbacks=False + ) + + def open_file_chunk_reader_from_fileobj( + self, + fileobj, + chunk_size, + full_file_size, + callbacks, + close_callbacks=None, + ): return ReadFileChunk( - fileobj, - chunk_size, - full_file_size, - callbacks=callbacks, - enable_callbacks=False, - close_callbacks=close_callbacks, - ) + fileobj, + chunk_size, + full_file_size, + callbacks=callbacks, + enable_callbacks=False, + close_callbacks=close_callbacks, + ) def open(self, filename, mode): return open(filename, mode) @@ -312,23 +312,23 @@ class OSUtils: return True return False - def get_temp_filename(self, filename): - suffix = os.extsep + random_file_extension() - path = os.path.dirname(filename) - name = os.path.basename(filename) - temp_filename = name[: self._MAX_FILENAME_LEN - len(suffix)] + suffix - return os.path.join(path, temp_filename) - - def allocate(self, filename, size): - try: - with self.open(filename, 'wb') as f: - fallocate(f, size) - except OSError: - self.remove_file(filename) - raise - - -class DeferredOpenFile: + def get_temp_filename(self, filename): + suffix = os.extsep + random_file_extension() + path = os.path.dirname(filename) + name = os.path.basename(filename) + temp_filename = name[: self._MAX_FILENAME_LEN - len(suffix)] + suffix + return os.path.join(path, temp_filename) + + def allocate(self, filename, size): + try: + with self.open(filename, 'wb') as f: + fallocate(f, size) + except OSError: + self.remove_file(filename) + raise + + +class DeferredOpenFile: def __init__(self, filename, start_byte=0, mode='rb', open_function=open): """A class that defers the opening of a file till needed @@ -374,9 +374,9 @@ class DeferredOpenFile: self._open_if_needed() self._fileobj.write(data) - def seek(self, where, whence=0): + def seek(self, where, whence=0): self._open_if_needed() - self._fileobj.seek(where, whence) + self._fileobj.seek(where, whence) def tell(self): if self._fileobj is None: @@ -395,16 +395,16 @@ class DeferredOpenFile: self.close() -class ReadFileChunk: - def __init__( - self, - fileobj, - chunk_size, - full_file_size, - callbacks=None, - enable_callbacks=True, - close_callbacks=None, - ): +class ReadFileChunk: + def __init__( + self, + fileobj, + chunk_size, + full_file_size, + callbacks=None, + enable_callbacks=True, + close_callbacks=None, + ): """ Given a file object shown below:: @@ -441,13 +441,13 @@ class ReadFileChunk: self._fileobj = fileobj self._start_byte = self._fileobj.tell() self._size = self._calculate_file_size( - self._fileobj, - requested_size=chunk_size, - start_byte=self._start_byte, - actual_file_size=full_file_size, - ) - # _amount_read represents the position in the chunk and may exceed - # the chunk size, but won't allow reads out of bounds. 
+ self._fileobj, + requested_size=chunk_size, + start_byte=self._start_byte, + actual_file_size=full_file_size, + ) + # _amount_read represents the position in the chunk and may exceed + # the chunk size, but won't allow reads out of bounds. self._amount_read = 0 self._callbacks = callbacks if callbacks is None: @@ -458,14 +458,14 @@ class ReadFileChunk: self._close_callbacks = close_callbacks @classmethod - def from_filename( - cls, - filename, - start_byte, - chunk_size, - callbacks=None, - enable_callbacks=True, - ): + def from_filename( + cls, + filename, + start_byte, + chunk_size, + callbacks=None, + enable_callbacks=True, + ): """Convenience factory function to create from a filename. :type start_byte: int @@ -496,18 +496,18 @@ class ReadFileChunk: file_size = os.fstat(f.fileno()).st_size return cls(f, chunk_size, file_size, callbacks, enable_callbacks) - def _calculate_file_size( - self, fileobj, requested_size, start_byte, actual_file_size - ): + def _calculate_file_size( + self, fileobj, requested_size, start_byte, actual_file_size + ): max_chunk_size = actual_file_size - start_byte return min(max_chunk_size, requested_size) def read(self, amount=None): - amount_left = max(self._size - self._amount_read, 0) + amount_left = max(self._size - self._amount_read, 0) if amount is None: - amount_to_read = amount_left + amount_to_read = amount_left else: - amount_to_read = min(amount_left, amount) + amount_to_read = min(amount_left, amount) data = self._fileobj.read(amount_to_read) self._amount_read += len(data) if self._callbacks is not None and self._callbacks_enabled: @@ -530,29 +530,29 @@ class ReadFileChunk: def disable_callback(self): self._callbacks_enabled = False - def seek(self, where, whence=0): - if whence not in (0, 1, 2): - # Mimic io's error for invalid whence values - raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)") - - # Recalculate where based on chunk attributes so seek from file - # start (whence=0) is always used - where += self._start_byte - if whence == 1: - where += self._amount_read - elif whence == 2: - where += self._size - - self._fileobj.seek(max(where, self._start_byte)) + def seek(self, where, whence=0): + if whence not in (0, 1, 2): + # Mimic io's error for invalid whence values + raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)") + + # Recalculate where based on chunk attributes so seek from file + # start (whence=0) is always used + where += self._start_byte + if whence == 1: + where += self._amount_read + elif whence == 2: + where += self._size + + self._fileobj.seek(max(where, self._start_byte)) if self._callbacks is not None and self._callbacks_enabled: # To also rewind the callback() for an accurate progress report - bounded_where = max(min(where - self._start_byte, self._size), 0) - bounded_amount_read = min(self._amount_read, self._size) - amount = bounded_where - bounded_amount_read + bounded_where = max(min(where - self._start_byte, self._size), 0) + bounded_amount_read = min(self._amount_read, self._size) + amount = bounded_where - bounded_amount_read invoke_progress_callbacks( - self._callbacks, bytes_transferred=amount - ) - self._amount_read = max(where - self._start_byte, 0) + self._callbacks, bytes_transferred=amount + ) + self._amount_read = max(where - self._start_byte, 0) def close(self): if self._close_callbacks is not None and self._callbacks_enabled: @@ -586,9 +586,9 @@ class ReadFileChunk: return iter([]) -class StreamReaderProgress: +class StreamReaderProgress: """Wrapper for a read only stream 
that adds progress callbacks.""" - + def __init__(self, stream, callbacks=None): self._stream = stream self._callbacks = callbacks @@ -605,7 +605,7 @@ class NoResourcesAvailable(Exception): pass -class TaskSemaphore: +class TaskSemaphore: def __init__(self, count): """A semaphore for the purpose of limiting the number of tasks @@ -621,7 +621,7 @@ class TaskSemaphore: needed for API compatibility with the SlidingWindowSemaphore implementation. :param block: If True, block until it can be acquired. If False, - do not block and raise an exception if cannot be acquired. + do not block and raise an exception if cannot be acquired. :returns: A token (can be None) to use when releasing the semaphore """ @@ -638,7 +638,7 @@ class TaskSemaphore: class but is needed for API compatibility with the SlidingWindowSemaphore implementation. """ - logger.debug(f"Releasing acquire {tag}/{acquire_token}") + logger.debug(f"Releasing acquire {tag}/{acquire_token}") self._semaphore.release() @@ -664,7 +664,7 @@ class SlidingWindowSemaphore(TaskSemaphore): when the minimum sequence number for a tag is released. """ - + def __init__(self, count): self._count = count # Dict[tag, next_sequence_number]. @@ -727,26 +727,26 @@ class SlidingWindowSemaphore(TaskSemaphore): # We can't do anything right now because we're still waiting # for the min sequence for the tag to be released. We have # to queue this for pending release. - self._pending_release.setdefault(tag, []).append( - sequence_number - ) + self._pending_release.setdefault(tag, []).append( + sequence_number + ) self._pending_release[tag].sort(reverse=True) else: raise ValueError( "Attempted to release unknown sequence number " - "%s for tag: %s" % (sequence_number, tag) - ) + "%s for tag: %s" % (sequence_number, tag) + ) finally: self._condition.release() -class ChunksizeAdjuster: - def __init__( - self, - max_size=MAX_SINGLE_UPLOAD_SIZE, - min_size=MIN_UPLOAD_CHUNKSIZE, - max_parts=MAX_PARTS, - ): +class ChunksizeAdjuster: + def __init__( + self, + max_size=MAX_SINGLE_UPLOAD_SIZE, + min_size=MIN_UPLOAD_CHUNKSIZE, + max_parts=MAX_PARTS, + ): self.max_size = max_size self.min_size = min_size self.max_parts = max_parts @@ -772,14 +772,14 @@ class ChunksizeAdjuster: if current_chunksize > self.max_size: logger.debug( "Chunksize greater than maximum chunksize. " - "Setting to %s from %s." % (self.max_size, current_chunksize) - ) + "Setting to %s from %s." % (self.max_size, current_chunksize) + ) return self.max_size elif current_chunksize < self.min_size: logger.debug( "Chunksize less than minimum chunksize. " - "Setting to %s from %s." % (self.min_size, current_chunksize) - ) + "Setting to %s from %s." % (self.min_size, current_chunksize) + ) return self.min_size else: return current_chunksize @@ -795,8 +795,8 @@ class ChunksizeAdjuster: if chunksize != current_chunksize: logger.debug( "Chunksize would result in the number of parts exceeding the " - "maximum. Setting to %s from %s." - % (chunksize, current_chunksize) - ) + "maximum. Setting to %s from %s." + % (chunksize, current_chunksize) + ) return chunksize diff --git a/contrib/python/s3transfer/py3/tests/__init__.py b/contrib/python/s3transfer/py3/tests/__init__.py index 229371009a..e36c4936bf 100644 --- a/contrib/python/s3transfer/py3/tests/__init__.py +++ b/contrib/python/s3transfer/py3/tests/__init__.py @@ -1,531 +1,531 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). 
You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the 'license' file accompanying this file. This file is -# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -import hashlib -import io -import math -import os -import platform -import shutil -import string -import tempfile -import unittest -from unittest import mock # noqa: F401 - -import botocore.session -from botocore.stub import Stubber - -from s3transfer.futures import ( - IN_MEMORY_DOWNLOAD_TAG, - IN_MEMORY_UPLOAD_TAG, - BoundedExecutor, - NonThreadedExecutor, - TransferCoordinator, - TransferFuture, - TransferMeta, -) -from s3transfer.manager import TransferConfig -from s3transfer.subscribers import BaseSubscriber -from s3transfer.utils import ( - CallArgs, - OSUtils, - SlidingWindowSemaphore, - TaskSemaphore, -) - -ORIGINAL_EXECUTOR_CLS = BoundedExecutor.EXECUTOR_CLS -# Detect if CRT is available for use -try: - import awscrt.s3 # noqa: F401 - - HAS_CRT = True -except ImportError: - HAS_CRT = False - - -def setup_package(): - if is_serial_implementation(): - BoundedExecutor.EXECUTOR_CLS = NonThreadedExecutor - - -def teardown_package(): - BoundedExecutor.EXECUTOR_CLS = ORIGINAL_EXECUTOR_CLS - - -def is_serial_implementation(): - return os.environ.get('USE_SERIAL_EXECUTOR', False) - - -def assert_files_equal(first, second): - if os.path.getsize(first) != os.path.getsize(second): - raise AssertionError(f"Files are not equal: {first}, {second}") - first_md5 = md5_checksum(first) - second_md5 = md5_checksum(second) - if first_md5 != second_md5: - raise AssertionError( - "Files are not equal: {}(md5={}) != {}(md5={})".format( - first, first_md5, second, second_md5 - ) - ) - - -def md5_checksum(filename): - checksum = hashlib.md5() - with open(filename, 'rb') as f: - for chunk in iter(lambda: f.read(8192), b''): - checksum.update(chunk) - return checksum.hexdigest() - - -def random_bucket_name(prefix='s3transfer', num_chars=10): - base = string.ascii_lowercase + string.digits - random_bytes = bytearray(os.urandom(num_chars)) - return prefix + ''.join([base[b % len(base)] for b in random_bytes]) - - -def skip_if_windows(reason): - """Decorator to skip tests that should not be run on windows. - - Example usage: - - @skip_if_windows("Not valid") - def test_some_non_windows_stuff(self): - self.assertEqual(...) - - """ - - def decorator(func): - return unittest.skipIf( - platform.system() not in ['Darwin', 'Linux'], reason - )(func) - - return decorator - - -def skip_if_using_serial_implementation(reason): - """Decorator to skip tests when running as the serial implementation""" - - def decorator(func): - return unittest.skipIf(is_serial_implementation(), reason)(func) - - return decorator - - -def requires_crt(cls, reason=None): - if reason is None: - reason = "Test requires awscrt to be installed." - return unittest.skipIf(not HAS_CRT, reason)(cls) - - -class StreamWithError: - """A wrapper to simulate errors while reading from a stream - - :param stream: The underlying stream to read from - :param exception_type: The exception type to throw - :param num_reads: The number of times to allow a read before raising - the exception. A value of zero indicates to raise the error on the - first read. 
- """ - - def __init__(self, stream, exception_type, num_reads=0): - self._stream = stream - self._exception_type = exception_type - self._num_reads = num_reads - self._count = 0 - - def read(self, n=-1): - if self._count == self._num_reads: - raise self._exception_type - self._count += 1 - return self._stream.read(n) - - -class FileSizeProvider: - def __init__(self, file_size): - self.file_size = file_size - - def on_queued(self, future, **kwargs): - future.meta.provide_transfer_size(self.file_size) - - -class FileCreator: - def __init__(self): - self.rootdir = tempfile.mkdtemp() - - def remove_all(self): - shutil.rmtree(self.rootdir) - - def create_file(self, filename, contents, mode='w'): - """Creates a file in a tmpdir - ``filename`` should be a relative path, e.g. "foo/bar/baz.txt" - It will be translated into a full path in a tmp dir. - ``mode`` is the mode the file should be opened either as ``w`` or - `wb``. - Returns the full path to the file. - """ - full_path = os.path.join(self.rootdir, filename) - if not os.path.isdir(os.path.dirname(full_path)): - os.makedirs(os.path.dirname(full_path)) - with open(full_path, mode) as f: - f.write(contents) - return full_path - - def create_file_with_size(self, filename, filesize): - filename = self.create_file(filename, contents='') - chunksize = 8192 - with open(filename, 'wb') as f: - for i in range(int(math.ceil(filesize / float(chunksize)))): - f.write(b'a' * chunksize) - return filename - - def append_file(self, filename, contents): - """Append contents to a file - ``filename`` should be a relative path, e.g. "foo/bar/baz.txt" - It will be translated into a full path in a tmp dir. - Returns the full path to the file. - """ - full_path = os.path.join(self.rootdir, filename) - if not os.path.isdir(os.path.dirname(full_path)): - os.makedirs(os.path.dirname(full_path)) - with open(full_path, 'a') as f: - f.write(contents) - return full_path - - def full_path(self, filename): - """Translate relative path to full path in temp dir. 
- f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt - """ - return os.path.join(self.rootdir, filename) - - -class RecordingOSUtils(OSUtils): - """An OSUtil abstraction that records openings and renamings""" - - def __init__(self): - super().__init__() - self.open_records = [] - self.rename_records = [] - - def open(self, filename, mode): - self.open_records.append((filename, mode)) - return super().open(filename, mode) - - def rename_file(self, current_filename, new_filename): - self.rename_records.append((current_filename, new_filename)) - super().rename_file(current_filename, new_filename) - - -class RecordingSubscriber(BaseSubscriber): - def __init__(self): - self.on_queued_calls = [] - self.on_progress_calls = [] - self.on_done_calls = [] - - def on_queued(self, **kwargs): - self.on_queued_calls.append(kwargs) - - def on_progress(self, **kwargs): - self.on_progress_calls.append(kwargs) - - def on_done(self, **kwargs): - self.on_done_calls.append(kwargs) - - def calculate_bytes_seen(self, **kwargs): - amount_seen = 0 - for call in self.on_progress_calls: - amount_seen += call['bytes_transferred'] - return amount_seen - - -class TransferCoordinatorWithInterrupt(TransferCoordinator): - """Used to inject keyboard interrupts""" - - def result(self): - raise KeyboardInterrupt() - - -class RecordingExecutor: - """A wrapper on an executor to record calls made to submit() - - You can access the submissions property to receive a list of dictionaries - that represents all submissions where the dictionary is formatted:: - - { - 'fn': function - 'args': positional args (as tuple) - 'kwargs': keyword args (as dict) - } - """ - - def __init__(self, executor): - self._executor = executor - self.submissions = [] - - def submit(self, task, tag=None, block=True): - future = self._executor.submit(task, tag, block) - self.submissions.append({'task': task, 'tag': tag, 'block': block}) - return future - - def shutdown(self): - self._executor.shutdown() - - -class StubbedClientTest(unittest.TestCase): - def setUp(self): - self.session = botocore.session.get_session() - self.region = 'us-west-2' - self.client = self.session.create_client( - 's3', - self.region, - aws_access_key_id='foo', - aws_secret_access_key='bar', - ) - self.stubber = Stubber(self.client) - self.stubber.activate() - - def tearDown(self): - self.stubber.deactivate() - - def reset_stubber_with_new_client(self, override_client_kwargs): - client_kwargs = { - 'service_name': 's3', - 'region_name': self.region, - 'aws_access_key_id': 'foo', - 'aws_secret_access_key': 'bar', - } - client_kwargs.update(override_client_kwargs) - self.client = self.session.create_client(**client_kwargs) - self.stubber = Stubber(self.client) - self.stubber.activate() - - -class BaseTaskTest(StubbedClientTest): - def setUp(self): - super().setUp() - self.transfer_coordinator = TransferCoordinator() - - def get_task(self, task_cls, **kwargs): - if 'transfer_coordinator' not in kwargs: - kwargs['transfer_coordinator'] = self.transfer_coordinator - return task_cls(**kwargs) - - def get_transfer_future(self, call_args=None): - return TransferFuture( - meta=TransferMeta(call_args), coordinator=self.transfer_coordinator - ) - - -class BaseSubmissionTaskTest(BaseTaskTest): - def setUp(self): - super().setUp() - self.config = TransferConfig() - self.osutil = OSUtils() - self.executor = BoundedExecutor( - 1000, - 1, - { - IN_MEMORY_UPLOAD_TAG: TaskSemaphore(10), - IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore(10), - }, - ) - - def tearDown(self): - super().tearDown() - 
self.executor.shutdown() - - -class BaseGeneralInterfaceTest(StubbedClientTest): - """A general test class to ensure consistency across TransferManger methods - - This test should never be called and should be subclassed from to pick up - the various tests that all TransferManager method must pass from a - functionality standpoint. - """ - - __test__ = False - - def manager(self): - """The transfer manager to use""" - raise NotImplementedError('method is not implemented') - - @property - def method(self): - """The transfer manager method to invoke i.e. upload()""" - raise NotImplementedError('method is not implemented') - - def create_call_kwargs(self): - """The kwargs to be passed to the transfer manager method""" - raise NotImplementedError('create_call_kwargs is not implemented') - - def create_invalid_extra_args(self): - """A value for extra_args that will cause validation errors""" - raise NotImplementedError( - 'create_invalid_extra_args is not implemented' - ) - - def create_stubbed_responses(self): - """A list of stubbed responses that will cause the request to succeed - - The elements of this list is a dictionary that will be used as key - word arguments to botocore.Stubber.add_response(). For example:: - - [{'method': 'put_object', 'service_response': {}}] - """ - raise NotImplementedError( - 'create_stubbed_responses is not implemented' - ) - - def create_expected_progress_callback_info(self): - """A list of kwargs expected to be passed to each progress callback - - Note that the future kwargs does not need to be added to each - dictionary provided in the list. This is injected for you. An example - is:: - - [ - {'bytes_transferred': 4}, - {'bytes_transferred': 4}, - {'bytes_transferred': 2} - ] - - This indicates that the progress callback will be called three - times and pass along the specified keyword arguments and corresponding - values. - """ - raise NotImplementedError( - 'create_expected_progress_callback_info is not implemented' - ) - - def _setup_default_stubbed_responses(self): - for stubbed_response in self.create_stubbed_responses(): - self.stubber.add_response(**stubbed_response) - - def test_returns_future_with_meta(self): - self._setup_default_stubbed_responses() - future = self.method(**self.create_call_kwargs()) - # The result is called so we ensure that the entire process executes - # before we try to clean up resources in the tearDown. - future.result() - - # Assert the return value is a future with metadata associated to it. - self.assertIsInstance(future, TransferFuture) - self.assertIsInstance(future.meta, TransferMeta) - - def test_returns_correct_call_args(self): - self._setup_default_stubbed_responses() - call_kwargs = self.create_call_kwargs() - future = self.method(**call_kwargs) - # The result is called so we ensure that the entire process executes - # before we try to clean up resources in the tearDown. - future.result() - - # Assert that there are call args associated to the metadata - self.assertIsInstance(future.meta.call_args, CallArgs) - # Assert that all of the arguments passed to the method exist and - # are of the correct value in call_args. 
- for param, value in call_kwargs.items(): - self.assertEqual(value, getattr(future.meta.call_args, param)) - - def test_has_transfer_id_associated_to_future(self): - self._setup_default_stubbed_responses() - call_kwargs = self.create_call_kwargs() - future = self.method(**call_kwargs) - # The result is called so we ensure that the entire process executes - # before we try to clean up resources in the tearDown. - future.result() - - # Assert that an transfer id was associated to the future. - # Since there is only one transfer request is made for that transfer - # manager the id will be zero since it will be the first transfer - # request made for that transfer manager. - self.assertEqual(future.meta.transfer_id, 0) - - # If we make a second request, the transfer id should have incremented - # by one for that new TransferFuture. - self._setup_default_stubbed_responses() - future = self.method(**call_kwargs) - future.result() - self.assertEqual(future.meta.transfer_id, 1) - - def test_invalid_extra_args(self): - with self.assertRaisesRegex(ValueError, 'Invalid extra_args'): - self.method( - extra_args=self.create_invalid_extra_args(), - **self.create_call_kwargs(), - ) - - def test_for_callback_kwargs_correctness(self): - # Add the stubbed responses before invoking the method - self._setup_default_stubbed_responses() - - subscriber = RecordingSubscriber() - future = self.method( - subscribers=[subscriber], **self.create_call_kwargs() - ) - # We call shutdown instead of result on future because the future - # could be finished but the done callback could still be going. - # The manager's shutdown method ensures everything completes. - self.manager.shutdown() - - # Assert the various subscribers were called with the - # expected kwargs - expected_progress_calls = self.create_expected_progress_callback_info() - for expected_progress_call in expected_progress_calls: - expected_progress_call['future'] = future - - self.assertEqual(subscriber.on_queued_calls, [{'future': future}]) - self.assertEqual(subscriber.on_progress_calls, expected_progress_calls) - self.assertEqual(subscriber.on_done_calls, [{'future': future}]) - - -class NonSeekableReader(io.RawIOBase): - def __init__(self, b=b''): - super().__init__() - self._data = io.BytesIO(b) - - def seekable(self): - return False - - def writable(self): - return False - - def readable(self): - return True - - def write(self, b): - # This is needed because python will not always return the correct - # kind of error even though writeable returns False. - raise io.UnsupportedOperation("write") - - def read(self, n=-1): - return self._data.read(n) - - -class NonSeekableWriter(io.RawIOBase): - def __init__(self, fileobj): - super().__init__() - self._fileobj = fileobj - - def seekable(self): - return False - - def writable(self): - return True - - def readable(self): - return False - - def write(self, b): - self._fileobj.write(b) - - def read(self, n=-1): - raise io.UnsupportedOperation("read") +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +import hashlib +import io +import math +import os +import platform +import shutil +import string +import tempfile +import unittest +from unittest import mock # noqa: F401 + +import botocore.session +from botocore.stub import Stubber + +from s3transfer.futures import ( + IN_MEMORY_DOWNLOAD_TAG, + IN_MEMORY_UPLOAD_TAG, + BoundedExecutor, + NonThreadedExecutor, + TransferCoordinator, + TransferFuture, + TransferMeta, +) +from s3transfer.manager import TransferConfig +from s3transfer.subscribers import BaseSubscriber +from s3transfer.utils import ( + CallArgs, + OSUtils, + SlidingWindowSemaphore, + TaskSemaphore, +) + +ORIGINAL_EXECUTOR_CLS = BoundedExecutor.EXECUTOR_CLS +# Detect if CRT is available for use +try: + import awscrt.s3 # noqa: F401 + + HAS_CRT = True +except ImportError: + HAS_CRT = False + + +def setup_package(): + if is_serial_implementation(): + BoundedExecutor.EXECUTOR_CLS = NonThreadedExecutor + + +def teardown_package(): + BoundedExecutor.EXECUTOR_CLS = ORIGINAL_EXECUTOR_CLS + + +def is_serial_implementation(): + return os.environ.get('USE_SERIAL_EXECUTOR', False) + + +def assert_files_equal(first, second): + if os.path.getsize(first) != os.path.getsize(second): + raise AssertionError(f"Files are not equal: {first}, {second}") + first_md5 = md5_checksum(first) + second_md5 = md5_checksum(second) + if first_md5 != second_md5: + raise AssertionError( + "Files are not equal: {}(md5={}) != {}(md5={})".format( + first, first_md5, second, second_md5 + ) + ) + + +def md5_checksum(filename): + checksum = hashlib.md5() + with open(filename, 'rb') as f: + for chunk in iter(lambda: f.read(8192), b''): + checksum.update(chunk) + return checksum.hexdigest() + + +def random_bucket_name(prefix='s3transfer', num_chars=10): + base = string.ascii_lowercase + string.digits + random_bytes = bytearray(os.urandom(num_chars)) + return prefix + ''.join([base[b % len(base)] for b in random_bytes]) + + +def skip_if_windows(reason): + """Decorator to skip tests that should not be run on windows. + + Example usage: + + @skip_if_windows("Not valid") + def test_some_non_windows_stuff(self): + self.assertEqual(...) + + """ + + def decorator(func): + return unittest.skipIf( + platform.system() not in ['Darwin', 'Linux'], reason + )(func) + + return decorator + + +def skip_if_using_serial_implementation(reason): + """Decorator to skip tests when running as the serial implementation""" + + def decorator(func): + return unittest.skipIf(is_serial_implementation(), reason)(func) + + return decorator + + +def requires_crt(cls, reason=None): + if reason is None: + reason = "Test requires awscrt to be installed." + return unittest.skipIf(not HAS_CRT, reason)(cls) + + +class StreamWithError: + """A wrapper to simulate errors while reading from a stream + + :param stream: The underlying stream to read from + :param exception_type: The exception type to throw + :param num_reads: The number of times to allow a read before raising + the exception. A value of zero indicates to raise the error on the + first read. 
+ """ + + def __init__(self, stream, exception_type, num_reads=0): + self._stream = stream + self._exception_type = exception_type + self._num_reads = num_reads + self._count = 0 + + def read(self, n=-1): + if self._count == self._num_reads: + raise self._exception_type + self._count += 1 + return self._stream.read(n) + + +class FileSizeProvider: + def __init__(self, file_size): + self.file_size = file_size + + def on_queued(self, future, **kwargs): + future.meta.provide_transfer_size(self.file_size) + + +class FileCreator: + def __init__(self): + self.rootdir = tempfile.mkdtemp() + + def remove_all(self): + shutil.rmtree(self.rootdir) + + def create_file(self, filename, contents, mode='w'): + """Creates a file in a tmpdir + ``filename`` should be a relative path, e.g. "foo/bar/baz.txt" + It will be translated into a full path in a tmp dir. + ``mode`` is the mode the file should be opened either as ``w`` or + `wb``. + Returns the full path to the file. + """ + full_path = os.path.join(self.rootdir, filename) + if not os.path.isdir(os.path.dirname(full_path)): + os.makedirs(os.path.dirname(full_path)) + with open(full_path, mode) as f: + f.write(contents) + return full_path + + def create_file_with_size(self, filename, filesize): + filename = self.create_file(filename, contents='') + chunksize = 8192 + with open(filename, 'wb') as f: + for i in range(int(math.ceil(filesize / float(chunksize)))): + f.write(b'a' * chunksize) + return filename + + def append_file(self, filename, contents): + """Append contents to a file + ``filename`` should be a relative path, e.g. "foo/bar/baz.txt" + It will be translated into a full path in a tmp dir. + Returns the full path to the file. + """ + full_path = os.path.join(self.rootdir, filename) + if not os.path.isdir(os.path.dirname(full_path)): + os.makedirs(os.path.dirname(full_path)) + with open(full_path, 'a') as f: + f.write(contents) + return full_path + + def full_path(self, filename): + """Translate relative path to full path in temp dir. 
+ f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt + """ + return os.path.join(self.rootdir, filename) + + +class RecordingOSUtils(OSUtils): + """An OSUtil abstraction that records openings and renamings""" + + def __init__(self): + super().__init__() + self.open_records = [] + self.rename_records = [] + + def open(self, filename, mode): + self.open_records.append((filename, mode)) + return super().open(filename, mode) + + def rename_file(self, current_filename, new_filename): + self.rename_records.append((current_filename, new_filename)) + super().rename_file(current_filename, new_filename) + + +class RecordingSubscriber(BaseSubscriber): + def __init__(self): + self.on_queued_calls = [] + self.on_progress_calls = [] + self.on_done_calls = [] + + def on_queued(self, **kwargs): + self.on_queued_calls.append(kwargs) + + def on_progress(self, **kwargs): + self.on_progress_calls.append(kwargs) + + def on_done(self, **kwargs): + self.on_done_calls.append(kwargs) + + def calculate_bytes_seen(self, **kwargs): + amount_seen = 0 + for call in self.on_progress_calls: + amount_seen += call['bytes_transferred'] + return amount_seen + + +class TransferCoordinatorWithInterrupt(TransferCoordinator): + """Used to inject keyboard interrupts""" + + def result(self): + raise KeyboardInterrupt() + + +class RecordingExecutor: + """A wrapper on an executor to record calls made to submit() + + You can access the submissions property to receive a list of dictionaries + that represents all submissions where the dictionary is formatted:: + + { + 'fn': function + 'args': positional args (as tuple) + 'kwargs': keyword args (as dict) + } + """ + + def __init__(self, executor): + self._executor = executor + self.submissions = [] + + def submit(self, task, tag=None, block=True): + future = self._executor.submit(task, tag, block) + self.submissions.append({'task': task, 'tag': tag, 'block': block}) + return future + + def shutdown(self): + self._executor.shutdown() + + +class StubbedClientTest(unittest.TestCase): + def setUp(self): + self.session = botocore.session.get_session() + self.region = 'us-west-2' + self.client = self.session.create_client( + 's3', + self.region, + aws_access_key_id='foo', + aws_secret_access_key='bar', + ) + self.stubber = Stubber(self.client) + self.stubber.activate() + + def tearDown(self): + self.stubber.deactivate() + + def reset_stubber_with_new_client(self, override_client_kwargs): + client_kwargs = { + 'service_name': 's3', + 'region_name': self.region, + 'aws_access_key_id': 'foo', + 'aws_secret_access_key': 'bar', + } + client_kwargs.update(override_client_kwargs) + self.client = self.session.create_client(**client_kwargs) + self.stubber = Stubber(self.client) + self.stubber.activate() + + +class BaseTaskTest(StubbedClientTest): + def setUp(self): + super().setUp() + self.transfer_coordinator = TransferCoordinator() + + def get_task(self, task_cls, **kwargs): + if 'transfer_coordinator' not in kwargs: + kwargs['transfer_coordinator'] = self.transfer_coordinator + return task_cls(**kwargs) + + def get_transfer_future(self, call_args=None): + return TransferFuture( + meta=TransferMeta(call_args), coordinator=self.transfer_coordinator + ) + + +class BaseSubmissionTaskTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.config = TransferConfig() + self.osutil = OSUtils() + self.executor = BoundedExecutor( + 1000, + 1, + { + IN_MEMORY_UPLOAD_TAG: TaskSemaphore(10), + IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore(10), + }, + ) + + def tearDown(self): + super().tearDown() + 
self.executor.shutdown() + + +class BaseGeneralInterfaceTest(StubbedClientTest): + """A general test class to ensure consistency across TransferManger methods + + This test should never be called and should be subclassed from to pick up + the various tests that all TransferManager method must pass from a + functionality standpoint. + """ + + __test__ = False + + def manager(self): + """The transfer manager to use""" + raise NotImplementedError('method is not implemented') + + @property + def method(self): + """The transfer manager method to invoke i.e. upload()""" + raise NotImplementedError('method is not implemented') + + def create_call_kwargs(self): + """The kwargs to be passed to the transfer manager method""" + raise NotImplementedError('create_call_kwargs is not implemented') + + def create_invalid_extra_args(self): + """A value for extra_args that will cause validation errors""" + raise NotImplementedError( + 'create_invalid_extra_args is not implemented' + ) + + def create_stubbed_responses(self): + """A list of stubbed responses that will cause the request to succeed + + The elements of this list is a dictionary that will be used as key + word arguments to botocore.Stubber.add_response(). For example:: + + [{'method': 'put_object', 'service_response': {}}] + """ + raise NotImplementedError( + 'create_stubbed_responses is not implemented' + ) + + def create_expected_progress_callback_info(self): + """A list of kwargs expected to be passed to each progress callback + + Note that the future kwargs does not need to be added to each + dictionary provided in the list. This is injected for you. An example + is:: + + [ + {'bytes_transferred': 4}, + {'bytes_transferred': 4}, + {'bytes_transferred': 2} + ] + + This indicates that the progress callback will be called three + times and pass along the specified keyword arguments and corresponding + values. + """ + raise NotImplementedError( + 'create_expected_progress_callback_info is not implemented' + ) + + def _setup_default_stubbed_responses(self): + for stubbed_response in self.create_stubbed_responses(): + self.stubber.add_response(**stubbed_response) + + def test_returns_future_with_meta(self): + self._setup_default_stubbed_responses() + future = self.method(**self.create_call_kwargs()) + # The result is called so we ensure that the entire process executes + # before we try to clean up resources in the tearDown. + future.result() + + # Assert the return value is a future with metadata associated to it. + self.assertIsInstance(future, TransferFuture) + self.assertIsInstance(future.meta, TransferMeta) + + def test_returns_correct_call_args(self): + self._setup_default_stubbed_responses() + call_kwargs = self.create_call_kwargs() + future = self.method(**call_kwargs) + # The result is called so we ensure that the entire process executes + # before we try to clean up resources in the tearDown. + future.result() + + # Assert that there are call args associated to the metadata + self.assertIsInstance(future.meta.call_args, CallArgs) + # Assert that all of the arguments passed to the method exist and + # are of the correct value in call_args. 
+ for param, value in call_kwargs.items(): + self.assertEqual(value, getattr(future.meta.call_args, param)) + + def test_has_transfer_id_associated_to_future(self): + self._setup_default_stubbed_responses() + call_kwargs = self.create_call_kwargs() + future = self.method(**call_kwargs) + # The result is called so we ensure that the entire process executes + # before we try to clean up resources in the tearDown. + future.result() + + # Assert that an transfer id was associated to the future. + # Since there is only one transfer request is made for that transfer + # manager the id will be zero since it will be the first transfer + # request made for that transfer manager. + self.assertEqual(future.meta.transfer_id, 0) + + # If we make a second request, the transfer id should have incremented + # by one for that new TransferFuture. + self._setup_default_stubbed_responses() + future = self.method(**call_kwargs) + future.result() + self.assertEqual(future.meta.transfer_id, 1) + + def test_invalid_extra_args(self): + with self.assertRaisesRegex(ValueError, 'Invalid extra_args'): + self.method( + extra_args=self.create_invalid_extra_args(), + **self.create_call_kwargs(), + ) + + def test_for_callback_kwargs_correctness(self): + # Add the stubbed responses before invoking the method + self._setup_default_stubbed_responses() + + subscriber = RecordingSubscriber() + future = self.method( + subscribers=[subscriber], **self.create_call_kwargs() + ) + # We call shutdown instead of result on future because the future + # could be finished but the done callback could still be going. + # The manager's shutdown method ensures everything completes. + self.manager.shutdown() + + # Assert the various subscribers were called with the + # expected kwargs + expected_progress_calls = self.create_expected_progress_callback_info() + for expected_progress_call in expected_progress_calls: + expected_progress_call['future'] = future + + self.assertEqual(subscriber.on_queued_calls, [{'future': future}]) + self.assertEqual(subscriber.on_progress_calls, expected_progress_calls) + self.assertEqual(subscriber.on_done_calls, [{'future': future}]) + + +class NonSeekableReader(io.RawIOBase): + def __init__(self, b=b''): + super().__init__() + self._data = io.BytesIO(b) + + def seekable(self): + return False + + def writable(self): + return False + + def readable(self): + return True + + def write(self, b): + # This is needed because python will not always return the correct + # kind of error even though writeable returns False. + raise io.UnsupportedOperation("write") + + def read(self, n=-1): + return self._data.read(n) + + +class NonSeekableWriter(io.RawIOBase): + def __init__(self, fileobj): + super().__init__() + self._fileobj = fileobj + + def seekable(self): + return False + + def writable(self): + return True + + def readable(self): + return False + + def write(self, b): + self._fileobj.write(b) + + def read(self, n=-1): + raise io.UnsupportedOperation("read") diff --git a/contrib/python/s3transfer/py3/tests/functional/__init__.py b/contrib/python/s3transfer/py3/tests/functional/__init__.py index d69af8958d..fa58dbdb55 100644 --- a/contrib/python/s3transfer/py3/tests/functional/__init__.py +++ b/contrib/python/s3transfer/py3/tests/functional/__init__.py @@ -1,12 +1,12 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. 
A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. diff --git a/contrib/python/s3transfer/py3/tests/functional/test_copy.py b/contrib/python/s3transfer/py3/tests/functional/test_copy.py index 5f6aeabf83..801c9003bb 100644 --- a/contrib/python/s3transfer/py3/tests/functional/test_copy.py +++ b/contrib/python/s3transfer/py3/tests/functional/test_copy.py @@ -1,554 +1,554 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
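The functional copy tests that follow build on the shared harness defined above: each test class subclasses ``BaseGeneralInterfaceTest`` and wires the stubbed S3 client into a ``TransferManager``. One helper worth highlighting is ``FileSizeProvider``, a subscriber that hands the transfer size to the future as soon as it is queued, which lets the manager skip the ``HeadObject`` call (exactly what ``test_can_provide_file_size`` below verifies). A minimal usage sketch, with placeholder bucket/key names and an assumed botocore S3 ``client``::

    from s3transfer.manager import TransferManager
    from __tests__ import FileSizeProvider  # helper defined earlier in this diff

    expected_size = 13107200               # known size of the source object, in bytes
    manager = TransferManager(client)      # 'client' is a botocore S3 client (placeholder)
    future = manager.copy(
        copy_source={'Bucket': 'src-bucket', 'Key': 'src-key'},
        bucket='dest-bucket',
        key='dest-key',
        subscribers=[FileSizeProvider(expected_size)],
    )
    future.result()  # the provided size is used, so no HeadObject is issued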
-from botocore.exceptions import ClientError -from botocore.stub import Stubber - -from s3transfer.manager import TransferConfig, TransferManager -from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE -from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider - - -class BaseCopyTest(BaseGeneralInterfaceTest): - def setUp(self): - super().setUp() - self.config = TransferConfig( - max_request_concurrency=1, - multipart_chunksize=MIN_UPLOAD_CHUNKSIZE, - multipart_threshold=MIN_UPLOAD_CHUNKSIZE * 4, - ) - self._manager = TransferManager(self.client, self.config) - - # Initialize some default arguments - self.bucket = 'mybucket' - self.key = 'mykey' - self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'} - self.extra_args = {} - self.subscribers = [] - - self.half_chunksize = int(MIN_UPLOAD_CHUNKSIZE / 2) - self.content = b'0' * (2 * MIN_UPLOAD_CHUNKSIZE + self.half_chunksize) - - @property - def manager(self): - return self._manager - - @property - def method(self): - return self.manager.copy - - def create_call_kwargs(self): - return { - 'copy_source': self.copy_source, - 'bucket': self.bucket, - 'key': self.key, - } - - def create_invalid_extra_args(self): - return {'Foo': 'bar'} - - def create_stubbed_responses(self): - return [ - { - 'method': 'head_object', - 'service_response': {'ContentLength': len(self.content)}, - }, - {'method': 'copy_object', 'service_response': {}}, - ] - - def create_expected_progress_callback_info(self): - return [ - {'bytes_transferred': len(self.content)}, - ] - - def add_head_object_response(self, expected_params=None, stubber=None): - if not stubber: - stubber = self.stubber - head_response = self.create_stubbed_responses()[0] - if expected_params: - head_response['expected_params'] = expected_params - stubber.add_response(**head_response) - - def add_successful_copy_responses( - self, - expected_copy_params=None, - expected_create_mpu_params=None, - expected_complete_mpu_params=None, - ): - - # Add all responses needed to do the copy of the object. - # Should account for both ranged and nonranged downloads. - stubbed_responses = self.create_stubbed_responses()[1:] - - # If the length of copy responses is greater than one then it is - # a multipart copy. - copy_responses = stubbed_responses[0:1] - if len(stubbed_responses) > 1: - copy_responses = stubbed_responses[1:-1] - - # Add the expected create multipart upload params. - if expected_create_mpu_params: - stubbed_responses[0][ - 'expected_params' - ] = expected_create_mpu_params - - # Add any expected copy parameters. - if expected_copy_params: - for i, copy_response in enumerate(copy_responses): - if isinstance(expected_copy_params, list): - copy_response['expected_params'] = expected_copy_params[i] - else: - copy_response['expected_params'] = expected_copy_params - - # Add the expected complete multipart upload params. - if expected_complete_mpu_params: - stubbed_responses[-1][ - 'expected_params' - ] = expected_complete_mpu_params - - # Add the responses to the stubber. - for stubbed_response in stubbed_responses: - self.stubber.add_response(**stubbed_response) - - def test_can_provide_file_size(self): - self.add_successful_copy_responses() - - call_kwargs = self.create_call_kwargs() - call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))] - - future = self.manager.copy(**call_kwargs) - future.result() - - # The HeadObject should have not happened and should have been able - # to successfully copy the file. 
- self.stubber.assert_no_pending_responses() - - def test_provide_copy_source_as_dict(self): - self.copy_source['VersionId'] = 'mysourceversionid' - expected_params = { - 'Bucket': 'mysourcebucket', - 'Key': 'mysourcekey', - 'VersionId': 'mysourceversionid', - } - - self.add_head_object_response(expected_params=expected_params) - self.add_successful_copy_responses() - - future = self.manager.copy(**self.create_call_kwargs()) - future.result() - self.stubber.assert_no_pending_responses() - - def test_invalid_copy_source(self): - self.copy_source = ['bucket', 'key'] - future = self.manager.copy(**self.create_call_kwargs()) - with self.assertRaises(TypeError): - future.result() - - def test_provide_copy_source_client(self): - source_client = self.session.create_client( - 's3', - 'eu-central-1', - aws_access_key_id='foo', - aws_secret_access_key='bar', - ) - source_stubber = Stubber(source_client) - source_stubber.activate() - self.addCleanup(source_stubber.deactivate) - - self.add_head_object_response(stubber=source_stubber) - self.add_successful_copy_responses() - - call_kwargs = self.create_call_kwargs() - call_kwargs['source_client'] = source_client - future = self.manager.copy(**call_kwargs) - future.result() - - # Make sure that all of the responses were properly - # used for both clients. - source_stubber.assert_no_pending_responses() - self.stubber.assert_no_pending_responses() - - -class TestNonMultipartCopy(BaseCopyTest): - __test__ = True - - def test_copy(self): - expected_head_params = { - 'Bucket': 'mysourcebucket', - 'Key': 'mysourcekey', - } - expected_copy_object = { - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - } - self.add_head_object_response(expected_params=expected_head_params) - self.add_successful_copy_responses( - expected_copy_params=expected_copy_object - ) - - future = self.manager.copy(**self.create_call_kwargs()) - future.result() - self.stubber.assert_no_pending_responses() - - def test_copy_with_extra_args(self): - self.extra_args['MetadataDirective'] = 'REPLACE' - - expected_head_params = { - 'Bucket': 'mysourcebucket', - 'Key': 'mysourcekey', - } - expected_copy_object = { - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - 'MetadataDirective': 'REPLACE', - } - - self.add_head_object_response(expected_params=expected_head_params) - self.add_successful_copy_responses( - expected_copy_params=expected_copy_object - ) - - call_kwargs = self.create_call_kwargs() - call_kwargs['extra_args'] = self.extra_args - future = self.manager.copy(**call_kwargs) - future.result() - self.stubber.assert_no_pending_responses() - - def test_copy_maps_extra_args_to_head_object(self): - self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256' - - expected_head_params = { - 'Bucket': 'mysourcebucket', - 'Key': 'mysourcekey', - 'SSECustomerAlgorithm': 'AES256', - } - expected_copy_object = { - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - 'CopySourceSSECustomerAlgorithm': 'AES256', - } - - self.add_head_object_response(expected_params=expected_head_params) - self.add_successful_copy_responses( - expected_copy_params=expected_copy_object - ) - - call_kwargs = self.create_call_kwargs() - call_kwargs['extra_args'] = self.extra_args - future = self.manager.copy(**call_kwargs) - future.result() - self.stubber.assert_no_pending_responses() - - def test_allowed_copy_params_are_valid(self): - op_model = self.client.meta.service_model.operation_model('CopyObject') - for allowed_upload_arg in 
self._manager.ALLOWED_COPY_ARGS: - self.assertIn(allowed_upload_arg, op_model.input_shape.members) - - def test_copy_with_tagging(self): - extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'} - self.add_head_object_response() - self.add_successful_copy_responses( - expected_copy_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - 'Tagging': 'tag1=val1', - 'TaggingDirective': 'REPLACE', - } - ) - future = self.manager.copy( - self.copy_source, self.bucket, self.key, extra_args - ) - future.result() - self.stubber.assert_no_pending_responses() - - def test_raise_exception_on_s3_object_lambda_resource(self): - s3_object_lambda_arn = ( - 'arn:aws:s3-object-lambda:us-west-2:123456789012:' - 'accesspoint:my-accesspoint' - ) - with self.assertRaisesRegex(ValueError, 'methods do not support'): - self.manager.copy(self.copy_source, s3_object_lambda_arn, self.key) - - def test_raise_exception_on_s3_object_lambda_resource_as_source(self): - source = { - 'Bucket': 'arn:aws:s3-object-lambda:us-west-2:123456789012:' - 'accesspoint:my-accesspoint' - } - with self.assertRaisesRegex(ValueError, 'methods do not support'): - self.manager.copy(source, self.bucket, self.key) - - -class TestMultipartCopy(BaseCopyTest): - __test__ = True - - def setUp(self): - super().setUp() - self.config = TransferConfig( - max_request_concurrency=1, - multipart_threshold=1, - multipart_chunksize=4, - ) - self._manager = TransferManager(self.client, self.config) - - def create_stubbed_responses(self): - return [ - { - 'method': 'head_object', - 'service_response': {'ContentLength': len(self.content)}, - }, - { - 'method': 'create_multipart_upload', - 'service_response': {'UploadId': 'my-upload-id'}, - }, - { - 'method': 'upload_part_copy', - 'service_response': {'CopyPartResult': {'ETag': 'etag-1'}}, - }, - { - 'method': 'upload_part_copy', - 'service_response': {'CopyPartResult': {'ETag': 'etag-2'}}, - }, - { - 'method': 'upload_part_copy', - 'service_response': {'CopyPartResult': {'ETag': 'etag-3'}}, - }, - {'method': 'complete_multipart_upload', 'service_response': {}}, - ] - - def create_expected_progress_callback_info(self): - # Note that last read is from the empty sentinel indicating - # that the stream is done. 
- return [ - {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE}, - {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE}, - {'bytes_transferred': self.half_chunksize}, - ] - - def add_create_multipart_upload_response(self): - self.stubber.add_response(**self.create_stubbed_responses()[1]) - - def _get_expected_params(self): - upload_id = 'my-upload-id' - - # Add expected parameters to the head object - expected_head_params = { - 'Bucket': 'mysourcebucket', - 'Key': 'mysourcekey', - } - - # Add expected parameters for the create multipart - expected_create_mpu_params = { - 'Bucket': self.bucket, - 'Key': self.key, - } - - expected_copy_params = [] - # Add expected parameters to the copy part - ranges = [ - 'bytes=0-5242879', - 'bytes=5242880-10485759', - 'bytes=10485760-13107199', - ] - for i, range_val in enumerate(ranges): - expected_copy_params.append( - { - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - 'UploadId': upload_id, - 'PartNumber': i + 1, - 'CopySourceRange': range_val, - } - ) - - # Add expected parameters for the complete multipart - expected_complete_mpu_params = { - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': upload_id, - 'MultipartUpload': { - 'Parts': [ - {'ETag': 'etag-1', 'PartNumber': 1}, - {'ETag': 'etag-2', 'PartNumber': 2}, - {'ETag': 'etag-3', 'PartNumber': 3}, - ] - }, - } - - return expected_head_params, { - 'expected_create_mpu_params': expected_create_mpu_params, - 'expected_copy_params': expected_copy_params, - 'expected_complete_mpu_params': expected_complete_mpu_params, - } - - def _add_params_to_expected_params( - self, add_copy_kwargs, operation_types, new_params - ): - - expected_params_to_update = [] - for operation_type in operation_types: - add_copy_kwargs_key = 'expected_' + operation_type + '_params' - expected_params = add_copy_kwargs[add_copy_kwargs_key] - if isinstance(expected_params, list): - expected_params_to_update.extend(expected_params) - else: - expected_params_to_update.append(expected_params) - - for expected_params in expected_params_to_update: - expected_params.update(new_params) - - def test_copy(self): - head_params, add_copy_kwargs = self._get_expected_params() - self.add_head_object_response(expected_params=head_params) - self.add_successful_copy_responses(**add_copy_kwargs) - - future = self.manager.copy(**self.create_call_kwargs()) - future.result() - self.stubber.assert_no_pending_responses() - - def test_copy_with_extra_args(self): - # This extra argument should be added to the head object, - # the create multipart upload, and upload part copy. 
- self.extra_args['RequestPayer'] = 'requester' - - head_params, add_copy_kwargs = self._get_expected_params() - head_params.update(self.extra_args) - self.add_head_object_response(expected_params=head_params) - - self._add_params_to_expected_params( - add_copy_kwargs, - ['create_mpu', 'copy', 'complete_mpu'], - self.extra_args, - ) - self.add_successful_copy_responses(**add_copy_kwargs) - - call_kwargs = self.create_call_kwargs() - call_kwargs['extra_args'] = self.extra_args - future = self.manager.copy(**call_kwargs) - future.result() - self.stubber.assert_no_pending_responses() - - def test_copy_blacklists_args_to_create_multipart(self): - # This argument can never be used for multipart uploads - self.extra_args['MetadataDirective'] = 'COPY' - - head_params, add_copy_kwargs = self._get_expected_params() - self.add_head_object_response(expected_params=head_params) - self.add_successful_copy_responses(**add_copy_kwargs) - - call_kwargs = self.create_call_kwargs() - call_kwargs['extra_args'] = self.extra_args - future = self.manager.copy(**call_kwargs) - future.result() - self.stubber.assert_no_pending_responses() - - def test_copy_args_to_only_create_multipart(self): - self.extra_args['ACL'] = 'private' - - head_params, add_copy_kwargs = self._get_expected_params() - self.add_head_object_response(expected_params=head_params) - - self._add_params_to_expected_params( - add_copy_kwargs, ['create_mpu'], self.extra_args - ) - self.add_successful_copy_responses(**add_copy_kwargs) - - call_kwargs = self.create_call_kwargs() - call_kwargs['extra_args'] = self.extra_args - future = self.manager.copy(**call_kwargs) - future.result() - self.stubber.assert_no_pending_responses() - - def test_copy_passes_args_to_create_multipart_and_upload_part(self): - # This will only be used for the complete multipart upload - # and upload part. - self.extra_args['SSECustomerAlgorithm'] = 'AES256' - - head_params, add_copy_kwargs = self._get_expected_params() - self.add_head_object_response(expected_params=head_params) - - self._add_params_to_expected_params( - add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args - ) - self.add_successful_copy_responses(**add_copy_kwargs) - - call_kwargs = self.create_call_kwargs() - call_kwargs['extra_args'] = self.extra_args - future = self.manager.copy(**call_kwargs) - future.result() - self.stubber.assert_no_pending_responses() - - def test_copy_maps_extra_args_to_head_object(self): - self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256' - - head_params, add_copy_kwargs = self._get_expected_params() - - # The CopySourceSSECustomerAlgorithm needs to get mapped to - # SSECustomerAlgorithm for HeadObject - head_params['SSECustomerAlgorithm'] = 'AES256' - self.add_head_object_response(expected_params=head_params) - - # However, it needs to remain the same for UploadPartCopy. 
- self._add_params_to_expected_params( - add_copy_kwargs, ['copy'], self.extra_args - ) - self.add_successful_copy_responses(**add_copy_kwargs) - - call_kwargs = self.create_call_kwargs() - call_kwargs['extra_args'] = self.extra_args - future = self.manager.copy(**call_kwargs) - future.result() - self.stubber.assert_no_pending_responses() - - def test_abort_on_failure(self): - # First add the head object and create multipart upload - self.add_head_object_response() - self.add_create_multipart_upload_response() - - # Cause an error on upload_part_copy - self.stubber.add_client_error('upload_part_copy', 'ArbitraryFailure') - - # Add the abort multipart to ensure it gets cleaned up on failure - self.stubber.add_response( - 'abort_multipart_upload', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': 'my-upload-id', - }, - ) - - future = self.manager.copy(**self.create_call_kwargs()) - with self.assertRaisesRegex(ClientError, 'ArbitraryFailure'): - future.result() - self.stubber.assert_no_pending_responses() - - def test_mp_copy_with_tagging_directive(self): - extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'} - self.add_head_object_response() - self.add_successful_copy_responses( - expected_create_mpu_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'Tagging': 'tag1=val1', - } - ) - future = self.manager.copy( - self.copy_source, self.bucket, self.key, extra_args - ) - future.result() - self.stubber.assert_no_pending_responses() +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
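A note on the multipart numbers used later in this module: ``TestMultipartCopy._get_expected_params`` hard-codes the part ranges ``bytes=0-5242879``, ``bytes=5242880-10485759`` and ``bytes=10485760-13107199``. Those values follow from a 5 MiB ``MIN_UPLOAD_CHUNKSIZE`` and the 2.5-chunk payload built in ``BaseCopyTest.setUp``; a small sketch of the arithmetic, with the sizes written out explicitly as assumptions::

    chunksize = 5 * 1024 * 1024                       # MIN_UPLOAD_CHUNKSIZE (5 MiB)
    content_length = 2 * chunksize + chunksize // 2   # 13107200 bytes of b'0'

    ranges = []
    start = 0
    while start < content_length:
        end = min(start + chunksize, content_length) - 1  # inclusive end offset
        ranges.append(f'bytes={start}-{end}')
        start += chunksize

    # ranges == ['bytes=0-5242879',
    #            'bytes=5242880-10485759',
    #            'bytes=10485760-13107199']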
+from botocore.exceptions import ClientError +from botocore.stub import Stubber + +from s3transfer.manager import TransferConfig, TransferManager +from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE +from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider + + +class BaseCopyTest(BaseGeneralInterfaceTest): + def setUp(self): + super().setUp() + self.config = TransferConfig( + max_request_concurrency=1, + multipart_chunksize=MIN_UPLOAD_CHUNKSIZE, + multipart_threshold=MIN_UPLOAD_CHUNKSIZE * 4, + ) + self._manager = TransferManager(self.client, self.config) + + # Initialize some default arguments + self.bucket = 'mybucket' + self.key = 'mykey' + self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'} + self.extra_args = {} + self.subscribers = [] + + self.half_chunksize = int(MIN_UPLOAD_CHUNKSIZE / 2) + self.content = b'0' * (2 * MIN_UPLOAD_CHUNKSIZE + self.half_chunksize) + + @property + def manager(self): + return self._manager + + @property + def method(self): + return self.manager.copy + + def create_call_kwargs(self): + return { + 'copy_source': self.copy_source, + 'bucket': self.bucket, + 'key': self.key, + } + + def create_invalid_extra_args(self): + return {'Foo': 'bar'} + + def create_stubbed_responses(self): + return [ + { + 'method': 'head_object', + 'service_response': {'ContentLength': len(self.content)}, + }, + {'method': 'copy_object', 'service_response': {}}, + ] + + def create_expected_progress_callback_info(self): + return [ + {'bytes_transferred': len(self.content)}, + ] + + def add_head_object_response(self, expected_params=None, stubber=None): + if not stubber: + stubber = self.stubber + head_response = self.create_stubbed_responses()[0] + if expected_params: + head_response['expected_params'] = expected_params + stubber.add_response(**head_response) + + def add_successful_copy_responses( + self, + expected_copy_params=None, + expected_create_mpu_params=None, + expected_complete_mpu_params=None, + ): + + # Add all responses needed to do the copy of the object. + # Should account for both ranged and nonranged downloads. + stubbed_responses = self.create_stubbed_responses()[1:] + + # If the length of copy responses is greater than one then it is + # a multipart copy. + copy_responses = stubbed_responses[0:1] + if len(stubbed_responses) > 1: + copy_responses = stubbed_responses[1:-1] + + # Add the expected create multipart upload params. + if expected_create_mpu_params: + stubbed_responses[0][ + 'expected_params' + ] = expected_create_mpu_params + + # Add any expected copy parameters. + if expected_copy_params: + for i, copy_response in enumerate(copy_responses): + if isinstance(expected_copy_params, list): + copy_response['expected_params'] = expected_copy_params[i] + else: + copy_response['expected_params'] = expected_copy_params + + # Add the expected complete multipart upload params. + if expected_complete_mpu_params: + stubbed_responses[-1][ + 'expected_params' + ] = expected_complete_mpu_params + + # Add the responses to the stubber. + for stubbed_response in stubbed_responses: + self.stubber.add_response(**stubbed_response) + + def test_can_provide_file_size(self): + self.add_successful_copy_responses() + + call_kwargs = self.create_call_kwargs() + call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))] + + future = self.manager.copy(**call_kwargs) + future.result() + + # The HeadObject should have not happened and should have been able + # to successfully copy the file. 
+ self.stubber.assert_no_pending_responses() + + def test_provide_copy_source_as_dict(self): + self.copy_source['VersionId'] = 'mysourceversionid' + expected_params = { + 'Bucket': 'mysourcebucket', + 'Key': 'mysourcekey', + 'VersionId': 'mysourceversionid', + } + + self.add_head_object_response(expected_params=expected_params) + self.add_successful_copy_responses() + + future = self.manager.copy(**self.create_call_kwargs()) + future.result() + self.stubber.assert_no_pending_responses() + + def test_invalid_copy_source(self): + self.copy_source = ['bucket', 'key'] + future = self.manager.copy(**self.create_call_kwargs()) + with self.assertRaises(TypeError): + future.result() + + def test_provide_copy_source_client(self): + source_client = self.session.create_client( + 's3', + 'eu-central-1', + aws_access_key_id='foo', + aws_secret_access_key='bar', + ) + source_stubber = Stubber(source_client) + source_stubber.activate() + self.addCleanup(source_stubber.deactivate) + + self.add_head_object_response(stubber=source_stubber) + self.add_successful_copy_responses() + + call_kwargs = self.create_call_kwargs() + call_kwargs['source_client'] = source_client + future = self.manager.copy(**call_kwargs) + future.result() + + # Make sure that all of the responses were properly + # used for both clients. + source_stubber.assert_no_pending_responses() + self.stubber.assert_no_pending_responses() + + +class TestNonMultipartCopy(BaseCopyTest): + __test__ = True + + def test_copy(self): + expected_head_params = { + 'Bucket': 'mysourcebucket', + 'Key': 'mysourcekey', + } + expected_copy_object = { + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + } + self.add_head_object_response(expected_params=expected_head_params) + self.add_successful_copy_responses( + expected_copy_params=expected_copy_object + ) + + future = self.manager.copy(**self.create_call_kwargs()) + future.result() + self.stubber.assert_no_pending_responses() + + def test_copy_with_extra_args(self): + self.extra_args['MetadataDirective'] = 'REPLACE' + + expected_head_params = { + 'Bucket': 'mysourcebucket', + 'Key': 'mysourcekey', + } + expected_copy_object = { + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + 'MetadataDirective': 'REPLACE', + } + + self.add_head_object_response(expected_params=expected_head_params) + self.add_successful_copy_responses( + expected_copy_params=expected_copy_object + ) + + call_kwargs = self.create_call_kwargs() + call_kwargs['extra_args'] = self.extra_args + future = self.manager.copy(**call_kwargs) + future.result() + self.stubber.assert_no_pending_responses() + + def test_copy_maps_extra_args_to_head_object(self): + self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256' + + expected_head_params = { + 'Bucket': 'mysourcebucket', + 'Key': 'mysourcekey', + 'SSECustomerAlgorithm': 'AES256', + } + expected_copy_object = { + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + 'CopySourceSSECustomerAlgorithm': 'AES256', + } + + self.add_head_object_response(expected_params=expected_head_params) + self.add_successful_copy_responses( + expected_copy_params=expected_copy_object + ) + + call_kwargs = self.create_call_kwargs() + call_kwargs['extra_args'] = self.extra_args + future = self.manager.copy(**call_kwargs) + future.result() + self.stubber.assert_no_pending_responses() + + def test_allowed_copy_params_are_valid(self): + op_model = self.client.meta.service_model.operation_model('CopyObject') + for allowed_upload_arg in 
self._manager.ALLOWED_COPY_ARGS: + self.assertIn(allowed_upload_arg, op_model.input_shape.members) + + def test_copy_with_tagging(self): + extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'} + self.add_head_object_response() + self.add_successful_copy_responses( + expected_copy_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + 'Tagging': 'tag1=val1', + 'TaggingDirective': 'REPLACE', + } + ) + future = self.manager.copy( + self.copy_source, self.bucket, self.key, extra_args + ) + future.result() + self.stubber.assert_no_pending_responses() + + def test_raise_exception_on_s3_object_lambda_resource(self): + s3_object_lambda_arn = ( + 'arn:aws:s3-object-lambda:us-west-2:123456789012:' + 'accesspoint:my-accesspoint' + ) + with self.assertRaisesRegex(ValueError, 'methods do not support'): + self.manager.copy(self.copy_source, s3_object_lambda_arn, self.key) + + def test_raise_exception_on_s3_object_lambda_resource_as_source(self): + source = { + 'Bucket': 'arn:aws:s3-object-lambda:us-west-2:123456789012:' + 'accesspoint:my-accesspoint' + } + with self.assertRaisesRegex(ValueError, 'methods do not support'): + self.manager.copy(source, self.bucket, self.key) + + +class TestMultipartCopy(BaseCopyTest): + __test__ = True + + def setUp(self): + super().setUp() + self.config = TransferConfig( + max_request_concurrency=1, + multipart_threshold=1, + multipart_chunksize=4, + ) + self._manager = TransferManager(self.client, self.config) + + def create_stubbed_responses(self): + return [ + { + 'method': 'head_object', + 'service_response': {'ContentLength': len(self.content)}, + }, + { + 'method': 'create_multipart_upload', + 'service_response': {'UploadId': 'my-upload-id'}, + }, + { + 'method': 'upload_part_copy', + 'service_response': {'CopyPartResult': {'ETag': 'etag-1'}}, + }, + { + 'method': 'upload_part_copy', + 'service_response': {'CopyPartResult': {'ETag': 'etag-2'}}, + }, + { + 'method': 'upload_part_copy', + 'service_response': {'CopyPartResult': {'ETag': 'etag-3'}}, + }, + {'method': 'complete_multipart_upload', 'service_response': {}}, + ] + + def create_expected_progress_callback_info(self): + # Note that last read is from the empty sentinel indicating + # that the stream is done. 
+ return [ + {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE}, + {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE}, + {'bytes_transferred': self.half_chunksize}, + ] + + def add_create_multipart_upload_response(self): + self.stubber.add_response(**self.create_stubbed_responses()[1]) + + def _get_expected_params(self): + upload_id = 'my-upload-id' + + # Add expected parameters to the head object + expected_head_params = { + 'Bucket': 'mysourcebucket', + 'Key': 'mysourcekey', + } + + # Add expected parameters for the create multipart + expected_create_mpu_params = { + 'Bucket': self.bucket, + 'Key': self.key, + } + + expected_copy_params = [] + # Add expected parameters to the copy part + ranges = [ + 'bytes=0-5242879', + 'bytes=5242880-10485759', + 'bytes=10485760-13107199', + ] + for i, range_val in enumerate(ranges): + expected_copy_params.append( + { + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + 'UploadId': upload_id, + 'PartNumber': i + 1, + 'CopySourceRange': range_val, + } + ) + + # Add expected parameters for the complete multipart + expected_complete_mpu_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + 'MultipartUpload': { + 'Parts': [ + {'ETag': 'etag-1', 'PartNumber': 1}, + {'ETag': 'etag-2', 'PartNumber': 2}, + {'ETag': 'etag-3', 'PartNumber': 3}, + ] + }, + } + + return expected_head_params, { + 'expected_create_mpu_params': expected_create_mpu_params, + 'expected_copy_params': expected_copy_params, + 'expected_complete_mpu_params': expected_complete_mpu_params, + } + + def _add_params_to_expected_params( + self, add_copy_kwargs, operation_types, new_params + ): + + expected_params_to_update = [] + for operation_type in operation_types: + add_copy_kwargs_key = 'expected_' + operation_type + '_params' + expected_params = add_copy_kwargs[add_copy_kwargs_key] + if isinstance(expected_params, list): + expected_params_to_update.extend(expected_params) + else: + expected_params_to_update.append(expected_params) + + for expected_params in expected_params_to_update: + expected_params.update(new_params) + + def test_copy(self): + head_params, add_copy_kwargs = self._get_expected_params() + self.add_head_object_response(expected_params=head_params) + self.add_successful_copy_responses(**add_copy_kwargs) + + future = self.manager.copy(**self.create_call_kwargs()) + future.result() + self.stubber.assert_no_pending_responses() + + def test_copy_with_extra_args(self): + # This extra argument should be added to the head object, + # the create multipart upload, and upload part copy. 
+ self.extra_args['RequestPayer'] = 'requester' + + head_params, add_copy_kwargs = self._get_expected_params() + head_params.update(self.extra_args) + self.add_head_object_response(expected_params=head_params) + + self._add_params_to_expected_params( + add_copy_kwargs, + ['create_mpu', 'copy', 'complete_mpu'], + self.extra_args, + ) + self.add_successful_copy_responses(**add_copy_kwargs) + + call_kwargs = self.create_call_kwargs() + call_kwargs['extra_args'] = self.extra_args + future = self.manager.copy(**call_kwargs) + future.result() + self.stubber.assert_no_pending_responses() + + def test_copy_blacklists_args_to_create_multipart(self): + # This argument can never be used for multipart uploads + self.extra_args['MetadataDirective'] = 'COPY' + + head_params, add_copy_kwargs = self._get_expected_params() + self.add_head_object_response(expected_params=head_params) + self.add_successful_copy_responses(**add_copy_kwargs) + + call_kwargs = self.create_call_kwargs() + call_kwargs['extra_args'] = self.extra_args + future = self.manager.copy(**call_kwargs) + future.result() + self.stubber.assert_no_pending_responses() + + def test_copy_args_to_only_create_multipart(self): + self.extra_args['ACL'] = 'private' + + head_params, add_copy_kwargs = self._get_expected_params() + self.add_head_object_response(expected_params=head_params) + + self._add_params_to_expected_params( + add_copy_kwargs, ['create_mpu'], self.extra_args + ) + self.add_successful_copy_responses(**add_copy_kwargs) + + call_kwargs = self.create_call_kwargs() + call_kwargs['extra_args'] = self.extra_args + future = self.manager.copy(**call_kwargs) + future.result() + self.stubber.assert_no_pending_responses() + + def test_copy_passes_args_to_create_multipart_and_upload_part(self): + # This will only be used for the complete multipart upload + # and upload part. + self.extra_args['SSECustomerAlgorithm'] = 'AES256' + + head_params, add_copy_kwargs = self._get_expected_params() + self.add_head_object_response(expected_params=head_params) + + self._add_params_to_expected_params( + add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args + ) + self.add_successful_copy_responses(**add_copy_kwargs) + + call_kwargs = self.create_call_kwargs() + call_kwargs['extra_args'] = self.extra_args + future = self.manager.copy(**call_kwargs) + future.result() + self.stubber.assert_no_pending_responses() + + def test_copy_maps_extra_args_to_head_object(self): + self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256' + + head_params, add_copy_kwargs = self._get_expected_params() + + # The CopySourceSSECustomerAlgorithm needs to get mapped to + # SSECustomerAlgorithm for HeadObject + head_params['SSECustomerAlgorithm'] = 'AES256' + self.add_head_object_response(expected_params=head_params) + + # However, it needs to remain the same for UploadPartCopy. 
+ self._add_params_to_expected_params( + add_copy_kwargs, ['copy'], self.extra_args + ) + self.add_successful_copy_responses(**add_copy_kwargs) + + call_kwargs = self.create_call_kwargs() + call_kwargs['extra_args'] = self.extra_args + future = self.manager.copy(**call_kwargs) + future.result() + self.stubber.assert_no_pending_responses() + + def test_abort_on_failure(self): + # First add the head object and create multipart upload + self.add_head_object_response() + self.add_create_multipart_upload_response() + + # Cause an error on upload_part_copy + self.stubber.add_client_error('upload_part_copy', 'ArbitraryFailure') + + # Add the abort multipart to ensure it gets cleaned up on failure + self.stubber.add_response( + 'abort_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': 'my-upload-id', + }, + ) + + future = self.manager.copy(**self.create_call_kwargs()) + with self.assertRaisesRegex(ClientError, 'ArbitraryFailure'): + future.result() + self.stubber.assert_no_pending_responses() + + def test_mp_copy_with_tagging_directive(self): + extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'} + self.add_head_object_response() + self.add_successful_copy_responses( + expected_create_mpu_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'Tagging': 'tag1=val1', + } + ) + future = self.manager.copy( + self.copy_source, self.bucket, self.key, extra_args + ) + future.result() + self.stubber.assert_no_pending_responses() diff --git a/contrib/python/s3transfer/py3/tests/functional/test_crt.py b/contrib/python/s3transfer/py3/tests/functional/test_crt.py index 9ebd48650d..fad0f4b23b 100644 --- a/contrib/python/s3transfer/py3/tests/functional/test_crt.py +++ b/contrib/python/s3transfer/py3/tests/functional/test_crt.py @@ -1,267 +1,267 @@ -# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
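The CRT functional tests below never reach S3: they wrap a ``mock.Mock(awscrt.s3.S3Client)`` and complete each transfer by invoking the ``on_done`` callback that ``CRTTransferManager`` hands to ``make_request``. Roughly, the pattern reduces to the sketch below; ``fake_make_request`` is an illustrative name, not part of the test suite::

    from unittest import mock

    import awscrt.s3  # only importable when the 'crt' extra is installed

    crt_client = mock.Mock(awscrt.s3.S3Client)

    def fake_make_request(**kwargs):
        # Signal success immediately so the transfer future can resolve.
        kwargs['on_done'](error=None)
        return mock.DEFAULT  # fall back to the mock's canned return_value

    crt_client.make_request.side_effect = fake_make_request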
-import fnmatch -import threading -import time -from concurrent.futures import Future - -from botocore.session import Session - -from s3transfer.subscribers import BaseSubscriber -from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest - -if HAS_CRT: - import awscrt - - import s3transfer.crt - - -class submitThread(threading.Thread): - def __init__(self, transfer_manager, futures, callargs): - threading.Thread.__init__(self) - self._transfer_manager = transfer_manager - self._futures = futures - self._callargs = callargs - - def run(self): - self._futures.append(self._transfer_manager.download(*self._callargs)) - - -class RecordingSubscriber(BaseSubscriber): - def __init__(self): - self.on_queued_called = False - self.on_done_called = False - self.bytes_transferred = 0 - self.on_queued_future = None - self.on_done_future = None - - def on_queued(self, future, **kwargs): - self.on_queued_called = True - self.on_queued_future = future - - def on_done(self, future, **kwargs): - self.on_done_called = True - self.on_done_future = future - - -@requires_crt -class TestCRTTransferManager(unittest.TestCase): - def setUp(self): - self.region = 'us-west-2' - self.bucket = "test_bucket" - self.key = "test_key" - self.files = FileCreator() - self.filename = self.files.create_file('myfile', 'my content') - self.expected_path = "/" + self.bucket + "/" + self.key - self.expected_host = "s3.%s.amazonaws.com" % (self.region) - self.s3_request = mock.Mock(awscrt.s3.S3Request) - self.s3_crt_client = mock.Mock(awscrt.s3.S3Client) - self.s3_crt_client.make_request.return_value = self.s3_request - self.session = Session() - self.session.set_config_variable('region', self.region) - self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( - self.session - ) - self.transfer_manager = s3transfer.crt.CRTTransferManager( - crt_s3_client=self.s3_crt_client, - crt_request_serializer=self.request_serializer, - ) - self.record_subscriber = RecordingSubscriber() - - def tearDown(self): - self.files.remove_all() - - def _assert_subscribers_called(self, expected_future=None): - self.assertTrue(self.record_subscriber.on_queued_called) - self.assertTrue(self.record_subscriber.on_done_called) - if expected_future: - self.assertIs( - self.record_subscriber.on_queued_future, expected_future - ) - self.assertIs( - self.record_subscriber.on_done_future, expected_future - ) - - def _invoke_done_callbacks(self, **kwargs): - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - on_done = callargs_kwargs["on_done"] - on_done(error=None) - - def _simulate_file_download(self, recv_filepath): - self.files.create_file(recv_filepath, "fake response") - - def _simulate_make_request_side_effect(self, **kwargs): - if kwargs.get('recv_filepath'): - self._simulate_file_download(kwargs['recv_filepath']) - self._invoke_done_callbacks() - return mock.DEFAULT - - def test_upload(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect - ) - future = self.transfer_manager.upload( - self.filename, self.bucket, self.key, {}, [self.record_subscriber] - ) - future.result() - - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertEqual(callargs_kwargs["send_filepath"], self.filename) - self.assertIsNone(callargs_kwargs["recv_filepath"]) - self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT - ) - crt_request = callargs_kwargs["request"] - self.assertEqual("PUT", 
crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) - self._assert_subscribers_called(future) - - def test_download(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect - ) - future = self.transfer_manager.download( - self.bucket, self.key, self.filename, {}, [self.record_subscriber] - ) - future.result() - - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - # the recv_filepath will be set to a temporary file path with some - # random suffix - self.assertTrue( - fnmatch.fnmatch( - callargs_kwargs["recv_filepath"], - f'{self.filename}.*', - ) - ) - self.assertIsNone(callargs_kwargs["send_filepath"]) - self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT - ) - crt_request = callargs_kwargs["request"] - self.assertEqual("GET", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) - self._assert_subscribers_called(future) - with open(self.filename, 'rb') as f: - # Check the fake response overwrites the file because of download - self.assertEqual(f.read(), b'fake response') - - def test_delete(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect - ) - future = self.transfer_manager.delete( - self.bucket, self.key, {}, [self.record_subscriber] - ) - future.result() - - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertIsNone(callargs_kwargs["send_filepath"]) - self.assertIsNone(callargs_kwargs["recv_filepath"]) - self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.DEFAULT - ) - crt_request = callargs_kwargs["request"] - self.assertEqual("DELETE", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) - self._assert_subscribers_called(future) - - def test_blocks_when_max_requests_processes_reached(self): - futures = [] - callargs = (self.bucket, self.key, self.filename, {}, []) - max_request_processes = 128 # the hard coded max processes - all_concurrent = max_request_processes + 1 - threads = [] - for i in range(0, all_concurrent): - thread = submitThread(self.transfer_manager, futures, callargs) - thread.start() - threads.append(thread) - # Sleep until the expected max requests has been reached - while len(futures) < max_request_processes: - time.sleep(0.05) - self.assertLessEqual( - self.s3_crt_client.make_request.call_count, max_request_processes - ) - # Release lock - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - on_done = callargs_kwargs["on_done"] - on_done(error=None) - for thread in threads: - thread.join() - self.assertEqual( - self.s3_crt_client.make_request.call_count, all_concurrent - ) - - def _cancel_function(self): - self.cancel_called = True - self.s3_request.finished_future.set_exception( - awscrt.exceptions.from_code(0) - ) - self._invoke_done_callbacks() - - def test_cancel(self): - self.s3_request.finished_future = Future() - self.cancel_called = False - self.s3_request.cancel = self._cancel_function - try: - with self.transfer_manager: - future = self.transfer_manager.upload( - self.filename, self.bucket, self.key, {}, [] - ) - raise KeyboardInterrupt() - except KeyboardInterrupt: - pass - - with 
self.assertRaises(awscrt.exceptions.AwsCrtError): - future.result() - self.assertTrue(self.cancel_called) - - def test_serializer_error_handling(self): - class SerializationException(Exception): - pass - - class ExceptionRaisingSerializer( - s3transfer.crt.BaseCRTRequestSerializer - ): - def serialize_http_request(self, transfer_type, future): - raise SerializationException() - - not_impl_serializer = ExceptionRaisingSerializer() - transfer_manager = s3transfer.crt.CRTTransferManager( - crt_s3_client=self.s3_crt_client, - crt_request_serializer=not_impl_serializer, - ) - future = transfer_manager.upload( - self.filename, self.bucket, self.key, {}, [] - ) - - with self.assertRaises(SerializationException): - future.result() - - def test_crt_s3_client_error_handling(self): - self.s3_crt_client.make_request.side_effect = ( - awscrt.exceptions.from_code(0) - ) - future = self.transfer_manager.upload( - self.filename, self.bucket, self.key, {}, [] - ) - with self.assertRaises(awscrt.exceptions.AwsCrtError): - future.result() +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import fnmatch +import threading +import time +from concurrent.futures import Future + +from botocore.session import Session + +from s3transfer.subscribers import BaseSubscriber +from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest + +if HAS_CRT: + import awscrt + + import s3transfer.crt + + +class submitThread(threading.Thread): + def __init__(self, transfer_manager, futures, callargs): + threading.Thread.__init__(self) + self._transfer_manager = transfer_manager + self._futures = futures + self._callargs = callargs + + def run(self): + self._futures.append(self._transfer_manager.download(*self._callargs)) + + +class RecordingSubscriber(BaseSubscriber): + def __init__(self): + self.on_queued_called = False + self.on_done_called = False + self.bytes_transferred = 0 + self.on_queued_future = None + self.on_done_future = None + + def on_queued(self, future, **kwargs): + self.on_queued_called = True + self.on_queued_future = future + + def on_done(self, future, **kwargs): + self.on_done_called = True + self.on_done_future = future + + +@requires_crt +class TestCRTTransferManager(unittest.TestCase): + def setUp(self): + self.region = 'us-west-2' + self.bucket = "test_bucket" + self.key = "test_key" + self.files = FileCreator() + self.filename = self.files.create_file('myfile', 'my content') + self.expected_path = "/" + self.bucket + "/" + self.key + self.expected_host = "s3.%s.amazonaws.com" % (self.region) + self.s3_request = mock.Mock(awscrt.s3.S3Request) + self.s3_crt_client = mock.Mock(awscrt.s3.S3Client) + self.s3_crt_client.make_request.return_value = self.s3_request + self.session = Session() + self.session.set_config_variable('region', self.region) + self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( + self.session + ) + self.transfer_manager = s3transfer.crt.CRTTransferManager( + crt_s3_client=self.s3_crt_client, + crt_request_serializer=self.request_serializer, + ) 
+ self.record_subscriber = RecordingSubscriber() + + def tearDown(self): + self.files.remove_all() + + def _assert_subscribers_called(self, expected_future=None): + self.assertTrue(self.record_subscriber.on_queued_called) + self.assertTrue(self.record_subscriber.on_done_called) + if expected_future: + self.assertIs( + self.record_subscriber.on_queued_future, expected_future + ) + self.assertIs( + self.record_subscriber.on_done_future, expected_future + ) + + def _invoke_done_callbacks(self, **kwargs): + callargs = self.s3_crt_client.make_request.call_args + callargs_kwargs = callargs[1] + on_done = callargs_kwargs["on_done"] + on_done(error=None) + + def _simulate_file_download(self, recv_filepath): + self.files.create_file(recv_filepath, "fake response") + + def _simulate_make_request_side_effect(self, **kwargs): + if kwargs.get('recv_filepath'): + self._simulate_file_download(kwargs['recv_filepath']) + self._invoke_done_callbacks() + return mock.DEFAULT + + def test_upload(self): + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) + future = self.transfer_manager.upload( + self.filename, self.bucket, self.key, {}, [self.record_subscriber] + ) + future.result() + + callargs = self.s3_crt_client.make_request.call_args + callargs_kwargs = callargs[1] + self.assertEqual(callargs_kwargs["send_filepath"], self.filename) + self.assertIsNone(callargs_kwargs["recv_filepath"]) + self.assertEqual( + callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT + ) + crt_request = callargs_kwargs["request"] + self.assertEqual("PUT", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self._assert_subscribers_called(future) + + def test_download(self): + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) + future = self.transfer_manager.download( + self.bucket, self.key, self.filename, {}, [self.record_subscriber] + ) + future.result() + + callargs = self.s3_crt_client.make_request.call_args + callargs_kwargs = callargs[1] + # the recv_filepath will be set to a temporary file path with some + # random suffix + self.assertTrue( + fnmatch.fnmatch( + callargs_kwargs["recv_filepath"], + f'{self.filename}.*', + ) + ) + self.assertIsNone(callargs_kwargs["send_filepath"]) + self.assertEqual( + callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT + ) + crt_request = callargs_kwargs["request"] + self.assertEqual("GET", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self._assert_subscribers_called(future) + with open(self.filename, 'rb') as f: + # Check the fake response overwrites the file because of download + self.assertEqual(f.read(), b'fake response') + + def test_delete(self): + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) + future = self.transfer_manager.delete( + self.bucket, self.key, {}, [self.record_subscriber] + ) + future.result() + + callargs = self.s3_crt_client.make_request.call_args + callargs_kwargs = callargs[1] + self.assertIsNone(callargs_kwargs["send_filepath"]) + self.assertIsNone(callargs_kwargs["recv_filepath"]) + self.assertEqual( + callargs_kwargs["type"], awscrt.s3.S3RequestType.DEFAULT + ) + crt_request = callargs_kwargs["request"] + self.assertEqual("DELETE", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + 
self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self._assert_subscribers_called(future) + + def test_blocks_when_max_requests_processes_reached(self): + futures = [] + callargs = (self.bucket, self.key, self.filename, {}, []) + max_request_processes = 128 # the hard coded max processes + all_concurrent = max_request_processes + 1 + threads = [] + for i in range(0, all_concurrent): + thread = submitThread(self.transfer_manager, futures, callargs) + thread.start() + threads.append(thread) + # Sleep until the expected max requests has been reached + while len(futures) < max_request_processes: + time.sleep(0.05) + self.assertLessEqual( + self.s3_crt_client.make_request.call_count, max_request_processes + ) + # Release lock + callargs = self.s3_crt_client.make_request.call_args + callargs_kwargs = callargs[1] + on_done = callargs_kwargs["on_done"] + on_done(error=None) + for thread in threads: + thread.join() + self.assertEqual( + self.s3_crt_client.make_request.call_count, all_concurrent + ) + + def _cancel_function(self): + self.cancel_called = True + self.s3_request.finished_future.set_exception( + awscrt.exceptions.from_code(0) + ) + self._invoke_done_callbacks() + + def test_cancel(self): + self.s3_request.finished_future = Future() + self.cancel_called = False + self.s3_request.cancel = self._cancel_function + try: + with self.transfer_manager: + future = self.transfer_manager.upload( + self.filename, self.bucket, self.key, {}, [] + ) + raise KeyboardInterrupt() + except KeyboardInterrupt: + pass + + with self.assertRaises(awscrt.exceptions.AwsCrtError): + future.result() + self.assertTrue(self.cancel_called) + + def test_serializer_error_handling(self): + class SerializationException(Exception): + pass + + class ExceptionRaisingSerializer( + s3transfer.crt.BaseCRTRequestSerializer + ): + def serialize_http_request(self, transfer_type, future): + raise SerializationException() + + not_impl_serializer = ExceptionRaisingSerializer() + transfer_manager = s3transfer.crt.CRTTransferManager( + crt_s3_client=self.s3_crt_client, + crt_request_serializer=not_impl_serializer, + ) + future = transfer_manager.upload( + self.filename, self.bucket, self.key, {}, [] + ) + + with self.assertRaises(SerializationException): + future.result() + + def test_crt_s3_client_error_handling(self): + self.s3_crt_client.make_request.side_effect = ( + awscrt.exceptions.from_code(0) + ) + future = self.transfer_manager.upload( + self.filename, self.bucket, self.key, {}, [] + ) + with self.assertRaises(awscrt.exceptions.AwsCrtError): + future.result() diff --git a/contrib/python/s3transfer/py3/tests/functional/test_delete.py b/contrib/python/s3transfer/py3/tests/functional/test_delete.py index 6d53537448..28587a47a4 100644 --- a/contrib/python/s3transfer/py3/tests/functional/test_delete.py +++ b/contrib/python/s3transfer/py3/tests/functional/test_delete.py @@ -1,76 +1,76 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-from s3transfer.manager import TransferManager +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from s3transfer.manager import TransferManager from __tests__ import BaseGeneralInterfaceTest - - -class TestDeleteObject(BaseGeneralInterfaceTest): - - __test__ = True - - def setUp(self): - super().setUp() - self.bucket = 'mybucket' - self.key = 'mykey' - self.manager = TransferManager(self.client) - - @property - def method(self): - """The transfer manager method to invoke i.e. upload()""" - return self.manager.delete - - def create_call_kwargs(self): - """The kwargs to be passed to the transfer manager method""" - return { - 'bucket': self.bucket, - 'key': self.key, - } - - def create_invalid_extra_args(self): - return { - 'BadKwargs': True, - } - - def create_stubbed_responses(self): - """A list of stubbed responses that will cause the request to succeed - - The elements of this list is a dictionary that will be used as key - word arguments to botocore.Stubber.add_response(). For example:: - - [{'method': 'put_object', 'service_response': {}}] - """ - return [ - { - 'method': 'delete_object', - 'service_response': {}, - 'expected_params': {'Bucket': self.bucket, 'Key': self.key}, - } - ] - - def create_expected_progress_callback_info(self): - return [] - - def test_known_allowed_args_in_input_shape(self): - op_model = self.client.meta.service_model.operation_model( - 'DeleteObject' - ) - for allowed_arg in self.manager.ALLOWED_DELETE_ARGS: - self.assertIn(allowed_arg, op_model.input_shape.members) - - def test_raise_exception_on_s3_object_lambda_resource(self): - s3_object_lambda_arn = ( - 'arn:aws:s3-object-lambda:us-west-2:123456789012:' - 'accesspoint:my-accesspoint' - ) - with self.assertRaisesRegex(ValueError, 'methods do not support'): - self.manager.delete(s3_object_lambda_arn, self.key) + + +class TestDeleteObject(BaseGeneralInterfaceTest): + + __test__ = True + + def setUp(self): + super().setUp() + self.bucket = 'mybucket' + self.key = 'mykey' + self.manager = TransferManager(self.client) + + @property + def method(self): + """The transfer manager method to invoke i.e. upload()""" + return self.manager.delete + + def create_call_kwargs(self): + """The kwargs to be passed to the transfer manager method""" + return { + 'bucket': self.bucket, + 'key': self.key, + } + + def create_invalid_extra_args(self): + return { + 'BadKwargs': True, + } + + def create_stubbed_responses(self): + """A list of stubbed responses that will cause the request to succeed + + The elements of this list is a dictionary that will be used as key + word arguments to botocore.Stubber.add_response(). 
For example:: + + [{'method': 'put_object', 'service_response': {}}] + """ + return [ + { + 'method': 'delete_object', + 'service_response': {}, + 'expected_params': {'Bucket': self.bucket, 'Key': self.key}, + } + ] + + def create_expected_progress_callback_info(self): + return [] + + def test_known_allowed_args_in_input_shape(self): + op_model = self.client.meta.service_model.operation_model( + 'DeleteObject' + ) + for allowed_arg in self.manager.ALLOWED_DELETE_ARGS: + self.assertIn(allowed_arg, op_model.input_shape.members) + + def test_raise_exception_on_s3_object_lambda_resource(self): + s3_object_lambda_arn = ( + 'arn:aws:s3-object-lambda:us-west-2:123456789012:' + 'accesspoint:my-accesspoint' + ) + with self.assertRaisesRegex(ValueError, 'methods do not support'): + self.manager.delete(s3_object_lambda_arn, self.key) diff --git a/contrib/python/s3transfer/py3/tests/functional/test_download.py b/contrib/python/s3transfer/py3/tests/functional/test_download.py index 746c040d42..64a8a1309d 100644 --- a/contrib/python/s3transfer/py3/tests/functional/test_download.py +++ b/contrib/python/s3transfer/py3/tests/functional/test_download.py @@ -1,497 +1,497 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -import copy -import glob -import os -import shutil -import tempfile -import time -from io import BytesIO - -from botocore.exceptions import ClientError - -from s3transfer.compat import SOCKET_ERROR -from s3transfer.exceptions import RetriesExceededError -from s3transfer.manager import TransferConfig, TransferManager -from __tests__ import ( - BaseGeneralInterfaceTest, - FileSizeProvider, - NonSeekableWriter, - RecordingOSUtils, - RecordingSubscriber, - StreamWithError, - skip_if_using_serial_implementation, - skip_if_windows, -) - - -class BaseDownloadTest(BaseGeneralInterfaceTest): - def setUp(self): - super().setUp() - self.config = TransferConfig(max_request_concurrency=1) - self._manager = TransferManager(self.client, self.config) - - # Create a temporary directory to write to - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'myfile') - - # Initialize some default arguments - self.bucket = 'mybucket' - self.key = 'mykey' - self.extra_args = {} - self.subscribers = [] - - # Create a stream to read from - self.content = b'my content' - self.stream = BytesIO(self.content) - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.tempdir) - - @property - def manager(self): - return self._manager - - @property - def method(self): - return self.manager.download - - def create_call_kwargs(self): - return { - 'bucket': self.bucket, - 'key': self.key, - 'fileobj': self.filename, - } - - def create_invalid_extra_args(self): - return {'Foo': 'bar'} - - def create_stubbed_responses(self): - # We want to make sure the beginning of the stream is always used - # in case this gets called twice. 
- self.stream.seek(0) - return [ - { - 'method': 'head_object', - 'service_response': {'ContentLength': len(self.content)}, - }, - { - 'method': 'get_object', - 'service_response': {'Body': self.stream}, - }, - ] - - def create_expected_progress_callback_info(self): - # Note that last read is from the empty sentinel indicating - # that the stream is done. - return [{'bytes_transferred': 10}] - - def add_head_object_response(self, expected_params=None): - head_response = self.create_stubbed_responses()[0] - if expected_params: - head_response['expected_params'] = expected_params - self.stubber.add_response(**head_response) - - def add_successful_get_object_responses( - self, expected_params=None, expected_ranges=None - ): - # Add all get_object responses needed to complete the download. - # Should account for both ranged and nonranged downloads. - for i, stubbed_response in enumerate( - self.create_stubbed_responses()[1:] - ): - if expected_params: - stubbed_response['expected_params'] = copy.deepcopy( - expected_params - ) - if expected_ranges: - stubbed_response['expected_params'][ - 'Range' - ] = expected_ranges[i] - self.stubber.add_response(**stubbed_response) - - def add_n_retryable_get_object_responses(self, n, num_reads=0): - for _ in range(n): - self.stubber.add_response( - method='get_object', - service_response={ - 'Body': StreamWithError( - copy.deepcopy(self.stream), SOCKET_ERROR, num_reads - ) - }, - ) - - def test_download_temporary_file_does_not_exist(self): - self.add_head_object_response() - self.add_successful_get_object_responses() - - future = self.manager.download(**self.create_call_kwargs()) - future.result() - # Make sure the file exists - self.assertTrue(os.path.exists(self.filename)) - # Make sure the random temporary file does not exist - possible_matches = glob.glob('%s*' % self.filename + os.extsep) - self.assertEqual(possible_matches, []) - - def test_download_for_fileobj(self): - self.add_head_object_response() - self.add_successful_get_object_responses() - - with open(self.filename, 'wb') as f: - future = self.manager.download( - self.bucket, self.key, f, self.extra_args - ) - future.result() - - # Ensure that the contents are correct - with open(self.filename, 'rb') as f: - self.assertEqual(self.content, f.read()) - - def test_download_for_seekable_filelike_obj(self): - self.add_head_object_response() - self.add_successful_get_object_responses() - - # Create a file-like object to test. In this case, it is a BytesIO - # object. 
- bytes_io = BytesIO() - - future = self.manager.download( - self.bucket, self.key, bytes_io, self.extra_args - ) - future.result() - - # Ensure that the contents are correct - bytes_io.seek(0) - self.assertEqual(self.content, bytes_io.read()) - - def test_download_for_nonseekable_filelike_obj(self): - self.add_head_object_response() - self.add_successful_get_object_responses() - - with open(self.filename, 'wb') as f: - future = self.manager.download( - self.bucket, self.key, NonSeekableWriter(f), self.extra_args - ) - future.result() - - # Ensure that the contents are correct - with open(self.filename, 'rb') as f: - self.assertEqual(self.content, f.read()) - - def test_download_cleanup_on_failure(self): - self.add_head_object_response() - - # Throw an error on the download - self.stubber.add_client_error('get_object') - - future = self.manager.download(**self.create_call_kwargs()) - - with self.assertRaises(ClientError): - future.result() - # Make sure the actual file and the temporary do not exist - # by globbing for the file and any of its extensions - possible_matches = glob.glob('%s*' % self.filename) - self.assertEqual(possible_matches, []) - - def test_download_with_nonexistent_directory(self): - self.add_head_object_response() - self.add_successful_get_object_responses() - - call_kwargs = self.create_call_kwargs() - call_kwargs['fileobj'] = os.path.join( - self.tempdir, 'missing-directory', 'myfile' - ) - future = self.manager.download(**call_kwargs) - with self.assertRaises(IOError): - future.result() - - def test_retries_and_succeeds(self): - self.add_head_object_response() - # Insert a response that will trigger a retry. - self.add_n_retryable_get_object_responses(1) - # Add the normal responses to simulate the download proceeding - # as normal after the retry. - self.add_successful_get_object_responses() - - future = self.manager.download(**self.create_call_kwargs()) - future.result() - - # The retry should have been consumed and the process should have - # continued using the successful responses. - self.stubber.assert_no_pending_responses() - with open(self.filename, 'rb') as f: - self.assertEqual(self.content, f.read()) - - def test_retry_failure(self): - self.add_head_object_response() - - max_retries = 3 - self.config.num_download_attempts = max_retries - self._manager = TransferManager(self.client, self.config) - # Add responses that fill up the maximum number of retries. - self.add_n_retryable_get_object_responses(max_retries) - - future = self.manager.download(**self.create_call_kwargs()) - - # A retry exceeded error should have happened. - with self.assertRaises(RetriesExceededError): - future.result() - - # All of the retries should have been used up. - self.stubber.assert_no_pending_responses() - - def test_retry_rewinds_callbacks(self): - self.add_head_object_response() - # Insert a response that will trigger a retry after one read of the - # stream has been made. - self.add_n_retryable_get_object_responses(1, num_reads=1) - # Add the normal responses to simulate the download proceeding - # as normal after the retry. - self.add_successful_get_object_responses() - - recorder_subscriber = RecordingSubscriber() - # Set the streaming to a size that is smaller than the data we - # currently provide to it to simulate rewinds of callbacks. 
- self.config.io_chunksize = 3 - future = self.manager.download( - subscribers=[recorder_subscriber], **self.create_call_kwargs() - ) - future.result() - - # Ensure that there is no more remaining responses and that contents - # are correct. - self.stubber.assert_no_pending_responses() - with open(self.filename, 'rb') as f: - self.assertEqual(self.content, f.read()) - - # Assert that the number of bytes seen is equal to the length of - # downloaded content. - self.assertEqual( - recorder_subscriber.calculate_bytes_seen(), len(self.content) - ) - - # Also ensure that the second progress invocation was negative three - # because a retry happened on the second read of the stream and we - # know that the chunk size for each read is 3. - progress_byte_amts = [ - call['bytes_transferred'] - for call in recorder_subscriber.on_progress_calls - ] - self.assertEqual(-3, progress_byte_amts[1]) - - def test_can_provide_file_size(self): - self.add_successful_get_object_responses() - - call_kwargs = self.create_call_kwargs() - call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))] - - future = self.manager.download(**call_kwargs) - future.result() - - # The HeadObject should have not happened and should have been able - # to successfully download the file. - self.stubber.assert_no_pending_responses() - with open(self.filename, 'rb') as f: - self.assertEqual(self.content, f.read()) - - def test_uses_provided_osutil(self): - osutil = RecordingOSUtils() - # Use the recording os utility for the transfer manager - self._manager = TransferManager(self.client, self.config, osutil) - - self.add_head_object_response() - self.add_successful_get_object_responses() - - future = self.manager.download(**self.create_call_kwargs()) - future.result() - # The osutil should have had its open() method invoked when opening - # a temporary file and its rename_file() method invoked when the - # the temporary file was moved to its final location. - self.assertEqual(len(osutil.open_records), 1) - self.assertEqual(len(osutil.rename_records), 1) - - @skip_if_windows('Windows does not support UNIX special files') - @skip_if_using_serial_implementation( - 'A separate thread is needed to read from the fifo' - ) - def test_download_for_fifo_file(self): - self.add_head_object_response() - self.add_successful_get_object_responses() - - # Create the fifo file - os.mkfifo(self.filename) - - future = self.manager.download( - self.bucket, self.key, self.filename, self.extra_args - ) - - # The call to open a fifo will block until there is both a reader - # and a writer, so we need to open it for reading after we've - # started the transfer. - with open(self.filename, 'rb') as fifo: - future.result() - self.assertEqual(fifo.read(), self.content) - - def test_raise_exception_on_s3_object_lambda_resource(self): - s3_object_lambda_arn = ( - 'arn:aws:s3-object-lambda:us-west-2:123456789012:' - 'accesspoint:my-accesspoint' - ) - with self.assertRaisesRegex(ValueError, 'methods do not support'): - self.manager.download( - s3_object_lambda_arn, self.key, self.filename, self.extra_args - ) - - -class TestNonRangedDownload(BaseDownloadTest): - # TODO: If you want to add tests outside of this test class and still - # subclass from BaseDownloadTest you need to set ``__test__ = True``. If - # you do not, your tests will not get picked up by the test runner! This - # needs to be done until we find a better way to ignore running test cases - # from the general test base class, which we do not want ran. 
- __test__ = True - - def test_download(self): - self.extra_args['RequestPayer'] = 'requester' - expected_params = { - 'Bucket': self.bucket, - 'Key': self.key, - 'RequestPayer': 'requester', - } - self.add_head_object_response(expected_params) - self.add_successful_get_object_responses(expected_params) - future = self.manager.download( - self.bucket, self.key, self.filename, self.extra_args - ) - future.result() - - # Ensure that the contents are correct - with open(self.filename, 'rb') as f: - self.assertEqual(self.content, f.read()) - - def test_allowed_copy_params_are_valid(self): - op_model = self.client.meta.service_model.operation_model('GetObject') - for allowed_upload_arg in self._manager.ALLOWED_DOWNLOAD_ARGS: - self.assertIn(allowed_upload_arg, op_model.input_shape.members) - - def test_download_empty_object(self): - self.content = b'' - self.stream = BytesIO(self.content) - self.add_head_object_response() - self.add_successful_get_object_responses() - future = self.manager.download( - self.bucket, self.key, self.filename, self.extra_args - ) - future.result() - - # Ensure that the empty file exists - with open(self.filename, 'rb') as f: - self.assertEqual(b'', f.read()) - - def test_uses_bandwidth_limiter(self): - self.content = b'a' * 1024 * 1024 - self.stream = BytesIO(self.content) - self.config = TransferConfig( - max_request_concurrency=1, max_bandwidth=len(self.content) / 2 - ) - self._manager = TransferManager(self.client, self.config) - - self.add_head_object_response() - self.add_successful_get_object_responses() - - start = time.time() - future = self.manager.download( - self.bucket, self.key, self.filename, self.extra_args - ) - future.result() - # This is just a smoke test to make sure that the limiter is - # being used and not necessary its exactness. So we set the maximum - # bandwidth to len(content)/2 per sec and make sure that it is - # noticeably slower. Ideally it will take more than two seconds, but - # given tracking at the beginning of transfers are not entirely - # accurate setting at the initial start of a transfer, we give us - # some flexibility by setting the expected time to half of the - # theoretical time to take. - self.assertGreaterEqual(time.time() - start, 1) - - # Ensure that the contents are correct - with open(self.filename, 'rb') as f: - self.assertEqual(self.content, f.read()) - - -class TestRangedDownload(BaseDownloadTest): - # TODO: If you want to add tests outside of this test class and still - # subclass from BaseDownloadTest you need to set ``__test__ = True``. If - # you do not, your tests will not get picked up by the test runner! This - # needs to be done until we find a better way to ignore running test cases - # from the general test base class, which we do not want ran. 
- __test__ = True - - def setUp(self): - super().setUp() - self.config = TransferConfig( - max_request_concurrency=1, - multipart_threshold=1, - multipart_chunksize=4, - ) - self._manager = TransferManager(self.client, self.config) - - def create_stubbed_responses(self): - return [ - { - 'method': 'head_object', - 'service_response': {'ContentLength': len(self.content)}, - }, - { - 'method': 'get_object', - 'service_response': {'Body': BytesIO(self.content[0:4])}, - }, - { - 'method': 'get_object', - 'service_response': {'Body': BytesIO(self.content[4:8])}, - }, - { - 'method': 'get_object', - 'service_response': {'Body': BytesIO(self.content[8:])}, - }, - ] - - def create_expected_progress_callback_info(self): - return [ - {'bytes_transferred': 4}, - {'bytes_transferred': 4}, - {'bytes_transferred': 2}, - ] - - def test_download(self): - self.extra_args['RequestPayer'] = 'requester' - expected_params = { - 'Bucket': self.bucket, - 'Key': self.key, - 'RequestPayer': 'requester', - } - expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] - self.add_head_object_response(expected_params) - self.add_successful_get_object_responses( - expected_params, expected_ranges - ) - - future = self.manager.download( - self.bucket, self.key, self.filename, self.extra_args - ) - future.result() - - # Ensure that the contents are correct - with open(self.filename, 'rb') as f: - self.assertEqual(self.content, f.read()) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import copy +import glob +import os +import shutil +import tempfile +import time +from io import BytesIO + +from botocore.exceptions import ClientError + +from s3transfer.compat import SOCKET_ERROR +from s3transfer.exceptions import RetriesExceededError +from s3transfer.manager import TransferConfig, TransferManager +from __tests__ import ( + BaseGeneralInterfaceTest, + FileSizeProvider, + NonSeekableWriter, + RecordingOSUtils, + RecordingSubscriber, + StreamWithError, + skip_if_using_serial_implementation, + skip_if_windows, +) + + +class BaseDownloadTest(BaseGeneralInterfaceTest): + def setUp(self): + super().setUp() + self.config = TransferConfig(max_request_concurrency=1) + self._manager = TransferManager(self.client, self.config) + + # Create a temporary directory to write to + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + + # Initialize some default arguments + self.bucket = 'mybucket' + self.key = 'mykey' + self.extra_args = {} + self.subscribers = [] + + # Create a stream to read from + self.content = b'my content' + self.stream = BytesIO(self.content) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + + @property + def manager(self): + return self._manager + + @property + def method(self): + return self.manager.download + + def create_call_kwargs(self): + return { + 'bucket': self.bucket, + 'key': self.key, + 'fileobj': self.filename, + } + + def create_invalid_extra_args(self): + return {'Foo': 'bar'} + + def create_stubbed_responses(self): + # We want to make sure the beginning of the stream is always used + # in case this gets called twice. + self.stream.seek(0) + return [ + { + 'method': 'head_object', + 'service_response': {'ContentLength': len(self.content)}, + }, + { + 'method': 'get_object', + 'service_response': {'Body': self.stream}, + }, + ] + + def create_expected_progress_callback_info(self): + # Note that last read is from the empty sentinel indicating + # that the stream is done. + return [{'bytes_transferred': 10}] + + def add_head_object_response(self, expected_params=None): + head_response = self.create_stubbed_responses()[0] + if expected_params: + head_response['expected_params'] = expected_params + self.stubber.add_response(**head_response) + + def add_successful_get_object_responses( + self, expected_params=None, expected_ranges=None + ): + # Add all get_object responses needed to complete the download. + # Should account for both ranged and nonranged downloads. 
+ for i, stubbed_response in enumerate( + self.create_stubbed_responses()[1:] + ): + if expected_params: + stubbed_response['expected_params'] = copy.deepcopy( + expected_params + ) + if expected_ranges: + stubbed_response['expected_params'][ + 'Range' + ] = expected_ranges[i] + self.stubber.add_response(**stubbed_response) + + def add_n_retryable_get_object_responses(self, n, num_reads=0): + for _ in range(n): + self.stubber.add_response( + method='get_object', + service_response={ + 'Body': StreamWithError( + copy.deepcopy(self.stream), SOCKET_ERROR, num_reads + ) + }, + ) + + def test_download_temporary_file_does_not_exist(self): + self.add_head_object_response() + self.add_successful_get_object_responses() + + future = self.manager.download(**self.create_call_kwargs()) + future.result() + # Make sure the file exists + self.assertTrue(os.path.exists(self.filename)) + # Make sure the random temporary file does not exist + possible_matches = glob.glob('%s*' % self.filename + os.extsep) + self.assertEqual(possible_matches, []) + + def test_download_for_fileobj(self): + self.add_head_object_response() + self.add_successful_get_object_responses() + + with open(self.filename, 'wb') as f: + future = self.manager.download( + self.bucket, self.key, f, self.extra_args + ) + future.result() + + # Ensure that the contents are correct + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + def test_download_for_seekable_filelike_obj(self): + self.add_head_object_response() + self.add_successful_get_object_responses() + + # Create a file-like object to test. In this case, it is a BytesIO + # object. + bytes_io = BytesIO() + + future = self.manager.download( + self.bucket, self.key, bytes_io, self.extra_args + ) + future.result() + + # Ensure that the contents are correct + bytes_io.seek(0) + self.assertEqual(self.content, bytes_io.read()) + + def test_download_for_nonseekable_filelike_obj(self): + self.add_head_object_response() + self.add_successful_get_object_responses() + + with open(self.filename, 'wb') as f: + future = self.manager.download( + self.bucket, self.key, NonSeekableWriter(f), self.extra_args + ) + future.result() + + # Ensure that the contents are correct + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + def test_download_cleanup_on_failure(self): + self.add_head_object_response() + + # Throw an error on the download + self.stubber.add_client_error('get_object') + + future = self.manager.download(**self.create_call_kwargs()) + + with self.assertRaises(ClientError): + future.result() + # Make sure the actual file and the temporary do not exist + # by globbing for the file and any of its extensions + possible_matches = glob.glob('%s*' % self.filename) + self.assertEqual(possible_matches, []) + + def test_download_with_nonexistent_directory(self): + self.add_head_object_response() + self.add_successful_get_object_responses() + + call_kwargs = self.create_call_kwargs() + call_kwargs['fileobj'] = os.path.join( + self.tempdir, 'missing-directory', 'myfile' + ) + future = self.manager.download(**call_kwargs) + with self.assertRaises(IOError): + future.result() + + def test_retries_and_succeeds(self): + self.add_head_object_response() + # Insert a response that will trigger a retry. + self.add_n_retryable_get_object_responses(1) + # Add the normal responses to simulate the download proceeding + # as normal after the retry. 
+ self.add_successful_get_object_responses() + + future = self.manager.download(**self.create_call_kwargs()) + future.result() + + # The retry should have been consumed and the process should have + # continued using the successful responses. + self.stubber.assert_no_pending_responses() + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + def test_retry_failure(self): + self.add_head_object_response() + + max_retries = 3 + self.config.num_download_attempts = max_retries + self._manager = TransferManager(self.client, self.config) + # Add responses that fill up the maximum number of retries. + self.add_n_retryable_get_object_responses(max_retries) + + future = self.manager.download(**self.create_call_kwargs()) + + # A retry exceeded error should have happened. + with self.assertRaises(RetriesExceededError): + future.result() + + # All of the retries should have been used up. + self.stubber.assert_no_pending_responses() + + def test_retry_rewinds_callbacks(self): + self.add_head_object_response() + # Insert a response that will trigger a retry after one read of the + # stream has been made. + self.add_n_retryable_get_object_responses(1, num_reads=1) + # Add the normal responses to simulate the download proceeding + # as normal after the retry. + self.add_successful_get_object_responses() + + recorder_subscriber = RecordingSubscriber() + # Set the streaming to a size that is smaller than the data we + # currently provide to it to simulate rewinds of callbacks. + self.config.io_chunksize = 3 + future = self.manager.download( + subscribers=[recorder_subscriber], **self.create_call_kwargs() + ) + future.result() + + # Ensure that there is no more remaining responses and that contents + # are correct. + self.stubber.assert_no_pending_responses() + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + # Assert that the number of bytes seen is equal to the length of + # downloaded content. + self.assertEqual( + recorder_subscriber.calculate_bytes_seen(), len(self.content) + ) + + # Also ensure that the second progress invocation was negative three + # because a retry happened on the second read of the stream and we + # know that the chunk size for each read is 3. + progress_byte_amts = [ + call['bytes_transferred'] + for call in recorder_subscriber.on_progress_calls + ] + self.assertEqual(-3, progress_byte_amts[1]) + + def test_can_provide_file_size(self): + self.add_successful_get_object_responses() + + call_kwargs = self.create_call_kwargs() + call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))] + + future = self.manager.download(**call_kwargs) + future.result() + + # The HeadObject should have not happened and should have been able + # to successfully download the file. + self.stubber.assert_no_pending_responses() + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + def test_uses_provided_osutil(self): + osutil = RecordingOSUtils() + # Use the recording os utility for the transfer manager + self._manager = TransferManager(self.client, self.config, osutil) + + self.add_head_object_response() + self.add_successful_get_object_responses() + + future = self.manager.download(**self.create_call_kwargs()) + future.result() + # The osutil should have had its open() method invoked when opening + # a temporary file and its rename_file() method invoked when the + # the temporary file was moved to its final location. 
+ self.assertEqual(len(osutil.open_records), 1) + self.assertEqual(len(osutil.rename_records), 1) + + @skip_if_windows('Windows does not support UNIX special files') + @skip_if_using_serial_implementation( + 'A separate thread is needed to read from the fifo' + ) + def test_download_for_fifo_file(self): + self.add_head_object_response() + self.add_successful_get_object_responses() + + # Create the fifo file + os.mkfifo(self.filename) + + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args + ) + + # The call to open a fifo will block until there is both a reader + # and a writer, so we need to open it for reading after we've + # started the transfer. + with open(self.filename, 'rb') as fifo: + future.result() + self.assertEqual(fifo.read(), self.content) + + def test_raise_exception_on_s3_object_lambda_resource(self): + s3_object_lambda_arn = ( + 'arn:aws:s3-object-lambda:us-west-2:123456789012:' + 'accesspoint:my-accesspoint' + ) + with self.assertRaisesRegex(ValueError, 'methods do not support'): + self.manager.download( + s3_object_lambda_arn, self.key, self.filename, self.extra_args + ) + + +class TestNonRangedDownload(BaseDownloadTest): + # TODO: If you want to add tests outside of this test class and still + # subclass from BaseDownloadTest you need to set ``__test__ = True``. If + # you do not, your tests will not get picked up by the test runner! This + # needs to be done until we find a better way to ignore running test cases + # from the general test base class, which we do not want ran. + __test__ = True + + def test_download(self): + self.extra_args['RequestPayer'] = 'requester' + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'RequestPayer': 'requester', + } + self.add_head_object_response(expected_params) + self.add_successful_get_object_responses(expected_params) + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args + ) + future.result() + + # Ensure that the contents are correct + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + def test_allowed_copy_params_are_valid(self): + op_model = self.client.meta.service_model.operation_model('GetObject') + for allowed_upload_arg in self._manager.ALLOWED_DOWNLOAD_ARGS: + self.assertIn(allowed_upload_arg, op_model.input_shape.members) + + def test_download_empty_object(self): + self.content = b'' + self.stream = BytesIO(self.content) + self.add_head_object_response() + self.add_successful_get_object_responses() + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args + ) + future.result() + + # Ensure that the empty file exists + with open(self.filename, 'rb') as f: + self.assertEqual(b'', f.read()) + + def test_uses_bandwidth_limiter(self): + self.content = b'a' * 1024 * 1024 + self.stream = BytesIO(self.content) + self.config = TransferConfig( + max_request_concurrency=1, max_bandwidth=len(self.content) / 2 + ) + self._manager = TransferManager(self.client, self.config) + + self.add_head_object_response() + self.add_successful_get_object_responses() + + start = time.time() + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args + ) + future.result() + # This is just a smoke test to make sure that the limiter is + # being used and not necessary its exactness. So we set the maximum + # bandwidth to len(content)/2 per sec and make sure that it is + # noticeably slower. 
Ideally it will take more than two seconds, but + # given tracking at the beginning of transfers are not entirely + # accurate setting at the initial start of a transfer, we give us + # some flexibility by setting the expected time to half of the + # theoretical time to take. + self.assertGreaterEqual(time.time() - start, 1) + + # Ensure that the contents are correct + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + + +class TestRangedDownload(BaseDownloadTest): + # TODO: If you want to add tests outside of this test class and still + # subclass from BaseDownloadTest you need to set ``__test__ = True``. If + # you do not, your tests will not get picked up by the test runner! This + # needs to be done until we find a better way to ignore running test cases + # from the general test base class, which we do not want ran. + __test__ = True + + def setUp(self): + super().setUp() + self.config = TransferConfig( + max_request_concurrency=1, + multipart_threshold=1, + multipart_chunksize=4, + ) + self._manager = TransferManager(self.client, self.config) + + def create_stubbed_responses(self): + return [ + { + 'method': 'head_object', + 'service_response': {'ContentLength': len(self.content)}, + }, + { + 'method': 'get_object', + 'service_response': {'Body': BytesIO(self.content[0:4])}, + }, + { + 'method': 'get_object', + 'service_response': {'Body': BytesIO(self.content[4:8])}, + }, + { + 'method': 'get_object', + 'service_response': {'Body': BytesIO(self.content[8:])}, + }, + ] + + def create_expected_progress_callback_info(self): + return [ + {'bytes_transferred': 4}, + {'bytes_transferred': 4}, + {'bytes_transferred': 2}, + ] + + def test_download(self): + self.extra_args['RequestPayer'] = 'requester' + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'RequestPayer': 'requester', + } + expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-'] + self.add_head_object_response(expected_params) + self.add_successful_get_object_responses( + expected_params, expected_ranges + ) + + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args + ) + future.result() + + # Ensure that the contents are correct + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) diff --git a/contrib/python/s3transfer/py3/tests/functional/test_manager.py b/contrib/python/s3transfer/py3/tests/functional/test_manager.py index bde2c10201..1c980e7bc6 100644 --- a/contrib/python/s3transfer/py3/tests/functional/test_manager.py +++ b/contrib/python/s3transfer/py3/tests/functional/test_manager.py @@ -1,191 +1,191 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the 'license' file accompanying this file. This file is -# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-from io import BytesIO - -from botocore.awsrequest import create_request_object - -from s3transfer.exceptions import CancelledError, FatalError -from s3transfer.futures import BaseExecutor -from s3transfer.manager import TransferConfig, TransferManager -from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation - - -class ArbitraryException(Exception): - pass - - -class SignalTransferringBody(BytesIO): - """A mocked body with the ability to signal when transfers occur""" - - def __init__(self): - super().__init__() - self.signal_transferring_call_count = 0 - self.signal_not_transferring_call_count = 0 - - def signal_transferring(self): - self.signal_transferring_call_count += 1 - - def signal_not_transferring(self): - self.signal_not_transferring_call_count += 1 - - def seek(self, where, whence=0): - pass - - def tell(self): - return 0 - - def read(self, amount=0): - return b'' - - -class TestTransferManager(StubbedClientTest): - @skip_if_using_serial_implementation( - 'Exception is thrown once all transfers are submitted. ' - 'However for the serial implementation, transfers are performed ' - 'in main thread meaning all transfers will complete before the ' - 'exception being thrown.' - ) - def test_error_in_context_manager_cancels_incomplete_transfers(self): - # The purpose of this test is to make sure if an error is raised - # in the body of the context manager, incomplete transfers will - # be cancelled with value of the exception wrapped by a CancelledError - - # NOTE: The fact that delete() was chosen to test this is arbitrary - # other than it is the easiet to set up for the stubber. - # The specific operation is not important to the purpose of this test. - num_transfers = 100 - futures = [] - ref_exception_msg = 'arbitrary exception' - - for _ in range(num_transfers): - self.stubber.add_response('delete_object', {}) - - manager = TransferManager( - self.client, - TransferConfig( - max_request_concurrency=1, max_submission_concurrency=1 - ), - ) - try: - with manager: - for i in range(num_transfers): - futures.append(manager.delete('mybucket', 'mykey')) - raise ArbitraryException(ref_exception_msg) - except ArbitraryException: - # At least one of the submitted futures should have been - # cancelled. - with self.assertRaisesRegex(FatalError, ref_exception_msg): - for future in futures: - future.result() - - @skip_if_using_serial_implementation( - 'Exception is thrown once all transfers are submitted. ' - 'However for the serial implementation, transfers are performed ' - 'in main thread meaning all transfers will complete before the ' - 'exception being thrown.' - ) - def test_cntrl_c_in_context_manager_cancels_incomplete_transfers(self): - # The purpose of this test is to make sure if an error is raised - # in the body of the context manager, incomplete transfers will - # be cancelled with value of the exception wrapped by a CancelledError - - # NOTE: The fact that delete() was chosen to test this is arbitrary - # other than it is the easiet to set up for the stubber. - # The specific operation is not important to the purpose of this test. 
- num_transfers = 100 - futures = [] - - for _ in range(num_transfers): - self.stubber.add_response('delete_object', {}) - - manager = TransferManager( - self.client, - TransferConfig( - max_request_concurrency=1, max_submission_concurrency=1 - ), - ) - try: - with manager: - for i in range(num_transfers): - futures.append(manager.delete('mybucket', 'mykey')) - raise KeyboardInterrupt() - except KeyboardInterrupt: - # At least one of the submitted futures should have been - # cancelled. - with self.assertRaisesRegex(CancelledError, 'KeyboardInterrupt()'): - for future in futures: - future.result() - - def test_enable_disable_callbacks_only_ever_registered_once(self): - body = SignalTransferringBody() - request = create_request_object( - { - 'method': 'PUT', - 'url': 'https://s3.amazonaws.com', - 'body': body, - 'headers': {}, - 'context': {}, - } - ) - # Create two TransferManager's using the same client - TransferManager(self.client) - TransferManager(self.client) - self.client.meta.events.emit( - 'request-created.s3', request=request, operation_name='PutObject' - ) - # The client should have only have the enable/disable callback - # handlers registered once depite being used for two different - # TransferManagers. - self.assertEqual( - body.signal_transferring_call_count, - 1, - 'The enable_callback() should have only ever been registered once', - ) - self.assertEqual( - body.signal_not_transferring_call_count, - 1, - 'The disable_callback() should have only ever been registered ' - 'once', - ) - - def test_use_custom_executor_implementation(self): - mocked_executor_cls = mock.Mock(BaseExecutor) - transfer_manager = TransferManager( - self.client, executor_cls=mocked_executor_cls - ) - transfer_manager.delete('bucket', 'key') - self.assertTrue(mocked_executor_cls.return_value.submit.called) - - def test_unicode_exception_in_context_manager(self): - with self.assertRaises(ArbitraryException): - with TransferManager(self.client): - raise ArbitraryException('\u2713') - - def test_client_property(self): - manager = TransferManager(self.client) - self.assertIs(manager.client, self.client) - - def test_config_property(self): - config = TransferConfig() - manager = TransferManager(self.client, config) - self.assertIs(manager.config, config) - - def test_can_disable_bucket_validation(self): - s3_object_lambda_arn = ( - 'arn:aws:s3-object-lambda:us-west-2:123456789012:' - 'accesspoint:my-accesspoint' - ) - config = TransferConfig() - manager = TransferManager(self.client, config) - manager.VALIDATE_SUPPORTED_BUCKET_VALUES = False - manager.delete(s3_object_lambda_arn, 'my-key') +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from io import BytesIO
+
+from botocore.awsrequest import create_request_object
+
+from s3transfer.exceptions import CancelledError, FatalError
+from s3transfer.futures import BaseExecutor
+from s3transfer.manager import TransferConfig, TransferManager
+from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation
+
+
+class ArbitraryException(Exception):
+ pass
+
+
+class SignalTransferringBody(BytesIO):
+ """A mocked body with the ability to signal when transfers occur"""
+
+ def __init__(self):
+ super().__init__()
+ self.signal_transferring_call_count = 0
+ self.signal_not_transferring_call_count = 0
+
+ def signal_transferring(self):
+ self.signal_transferring_call_count += 1
+
+ def signal_not_transferring(self):
+ self.signal_not_transferring_call_count += 1
+
+ def seek(self, where, whence=0):
+ pass
+
+ def tell(self):
+ return 0
+
+ def read(self, amount=0):
+ return b''
+
+
+class TestTransferManager(StubbedClientTest):
+ @skip_if_using_serial_implementation(
+ 'Exception is thrown once all transfers are submitted. '
+ 'However, for the serial implementation, transfers are performed '
+ 'in the main thread, meaning all transfers will complete before the '
+ 'exception is thrown.'
+ )
+ def test_error_in_context_manager_cancels_incomplete_transfers(self):
+ # The purpose of this test is to make sure that if an error is raised
+ # in the body of the context manager, incomplete transfers will
+ # be cancelled with the value of the exception wrapped in a CancelledError
+
+ # NOTE: The fact that delete() was chosen to test this is arbitrary
+ # other than it being the easiest to set up for the stubber.
+ # The specific operation is not important to the purpose of this test.
+ num_transfers = 100
+ futures = []
+ ref_exception_msg = 'arbitrary exception'
+
+ for _ in range(num_transfers):
+ self.stubber.add_response('delete_object', {})
+
+ manager = TransferManager(
+ self.client,
+ TransferConfig(
+ max_request_concurrency=1, max_submission_concurrency=1
+ ),
+ )
+ try:
+ with manager:
+ for i in range(num_transfers):
+ futures.append(manager.delete('mybucket', 'mykey'))
+ raise ArbitraryException(ref_exception_msg)
+ except ArbitraryException:
+ # At least one of the submitted futures should have been
+ # cancelled.
+ with self.assertRaisesRegex(FatalError, ref_exception_msg):
+ for future in futures:
+ future.result()
+
+ @skip_if_using_serial_implementation(
+ 'Exception is thrown once all transfers are submitted. '
+ 'However, for the serial implementation, transfers are performed '
+ 'in the main thread, meaning all transfers will complete before the '
+ 'exception is thrown.'
+ )
+ def test_cntrl_c_in_context_manager_cancels_incomplete_transfers(self):
+ # The purpose of this test is to make sure that if an error is raised
+ # in the body of the context manager, incomplete transfers will
+ # be cancelled with the value of the exception wrapped in a CancelledError
+
+ # NOTE: The fact that delete() was chosen to test this is arbitrary
+ # other than it being the easiest to set up for the stubber.
+ # The specific operation is not important to the purpose of this test.
+ num_transfers = 100
+ futures = []
+
+ for _ in range(num_transfers):
+ self.stubber.add_response('delete_object', {})
+
+ manager = TransferManager(
+ self.client,
+ TransferConfig(
+ max_request_concurrency=1, max_submission_concurrency=1
+ ),
+ )
+ try:
+ with manager:
+ for i in range(num_transfers):
+ futures.append(manager.delete('mybucket', 'mykey'))
+ raise KeyboardInterrupt()
+ except KeyboardInterrupt:
+ # At least one of the submitted futures should have been
+ # cancelled.
+ with self.assertRaisesRegex(CancelledError, 'KeyboardInterrupt()'):
+ for future in futures:
+ future.result()
+
+ def test_enable_disable_callbacks_only_ever_registered_once(self):
+ body = SignalTransferringBody()
+ request = create_request_object(
+ {
+ 'method': 'PUT',
+ 'url': 'https://s3.amazonaws.com',
+ 'body': body,
+ 'headers': {},
+ 'context': {},
+ }
+ )
+ # Create two TransferManagers using the same client
+ TransferManager(self.client)
+ TransferManager(self.client)
+ self.client.meta.events.emit(
+ 'request-created.s3', request=request, operation_name='PutObject'
+ )
+ # The client should only have the enable/disable callback
+ # handlers registered once despite being used for two different
+ # TransferManagers.
+ self.assertEqual(
+ body.signal_transferring_call_count,
+ 1,
+ 'The enable_callback() should have only ever been registered once',
+ )
+ self.assertEqual(
+ body.signal_not_transferring_call_count,
+ 1,
+ 'The disable_callback() should have only ever been registered '
+ 'once',
+ )
+
+ def test_use_custom_executor_implementation(self):
+ mocked_executor_cls = mock.Mock(BaseExecutor)
+ transfer_manager = TransferManager(
+ self.client, executor_cls=mocked_executor_cls
+ )
+ transfer_manager.delete('bucket', 'key')
+ self.assertTrue(mocked_executor_cls.return_value.submit.called)
+
+ def test_unicode_exception_in_context_manager(self):
+ with self.assertRaises(ArbitraryException):
+ with TransferManager(self.client):
+ raise ArbitraryException('\u2713')
+
+ def test_client_property(self):
+ manager = TransferManager(self.client)
+ self.assertIs(manager.client, self.client)
+
+ def test_config_property(self):
+ config = TransferConfig()
+ manager = TransferManager(self.client, config)
+ self.assertIs(manager.config, config)
+
+ def test_can_disable_bucket_validation(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ config = TransferConfig()
+ manager = TransferManager(self.client, config)
+ manager.VALIDATE_SUPPORTED_BUCKET_VALUES = False
+ manager.delete(s3_object_lambda_arn, 'my-key')
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_processpool.py b/contrib/python/s3transfer/py3/tests/functional/test_processpool.py
index d347efa869..1396c919f2 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_processpool.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_processpool.py
@@ -1,281 +1,281 @@
-# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License. 
-import glob -import os -from io import BytesIO -from multiprocessing.managers import BaseManager - -import botocore.exceptions -import botocore.session -from botocore.stub import Stubber - -from s3transfer.exceptions import CancelledError -from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig -from __tests__ import FileCreator, mock, unittest - - -class StubbedClient: - def __init__(self): - self._client = botocore.session.get_session().create_client( - 's3', - 'us-west-2', - aws_access_key_id='foo', - aws_secret_access_key='bar', - ) - self._stubber = Stubber(self._client) - self._stubber.activate() - self._caught_stubber_errors = [] - - def get_object(self, **kwargs): - return self._client.get_object(**kwargs) - - def head_object(self, **kwargs): - return self._client.head_object(**kwargs) - - def add_response(self, *args, **kwargs): - self._stubber.add_response(*args, **kwargs) - - def add_client_error(self, *args, **kwargs): - self._stubber.add_client_error(*args, **kwargs) - - -class StubbedClientManager(BaseManager): - pass - - -StubbedClientManager.register('StubbedClient', StubbedClient) - - -# Ideally a Mock would be used here. However, they cannot be pickled -# for Windows. So instead we define a factory class at the module level that -# can return a stubbed client we initialized in the setUp. -class StubbedClientFactory: - def __init__(self, stubbed_client): - self._stubbed_client = stubbed_client - - def __call__(self, *args, **kwargs): - # The __call__ is defined so we can provide an instance of the - # StubbedClientFactory to mock.patch() and have the instance be - # returned when the patched class is instantiated. - return self - - def create_client(self): - return self._stubbed_client - - -class TestProcessPoolDownloader(unittest.TestCase): - def setUp(self): - # The stubbed client needs to run in a manager to be shared across - # processes and have it properly consume the stubbed response across - # processes. 
- self.manager = StubbedClientManager() - self.manager.start() - self.stubbed_client = self.manager.StubbedClient() - self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client) - - self.client_factory_patch = mock.patch( - 's3transfer.processpool.ClientFactory', self.stubbed_client_factory - ) - self.client_factory_patch.start() - self.files = FileCreator() - - self.config = ProcessTransferConfig(max_request_processes=1) - self.downloader = ProcessPoolDownloader(config=self.config) - self.bucket = 'mybucket' - self.key = 'mykey' - self.filename = self.files.full_path('filename') - self.remote_contents = b'my content' - self.stream = BytesIO(self.remote_contents) - - def tearDown(self): - self.manager.shutdown() - self.client_factory_patch.stop() - self.files.remove_all() - - def assert_contents(self, filename, expected_contents): - self.assertTrue(os.path.exists(filename)) - with open(filename, 'rb') as f: - self.assertEqual(f.read(), expected_contents) - - def test_download_file(self): - self.stubbed_client.add_response( - 'head_object', {'ContentLength': len(self.remote_contents)} - ) - self.stubbed_client.add_response('get_object', {'Body': self.stream}) - with self.downloader: - self.downloader.download_file(self.bucket, self.key, self.filename) - self.assert_contents(self.filename, self.remote_contents) - - def test_download_multiple_files(self): - self.stubbed_client.add_response('get_object', {'Body': self.stream}) - self.stubbed_client.add_response( - 'get_object', {'Body': BytesIO(self.remote_contents)} - ) - with self.downloader: - self.downloader.download_file( - self.bucket, - self.key, - self.filename, - expected_size=len(self.remote_contents), - ) - other_file = self.files.full_path('filename2') - self.downloader.download_file( - self.bucket, - self.key, - other_file, - expected_size=len(self.remote_contents), - ) - self.assert_contents(self.filename, self.remote_contents) - self.assert_contents(other_file, self.remote_contents) - - def test_download_file_ranged_download(self): - half_of_content_length = int(len(self.remote_contents) / 2) - self.stubbed_client.add_response( - 'head_object', {'ContentLength': len(self.remote_contents)} - ) - self.stubbed_client.add_response( - 'get_object', - {'Body': BytesIO(self.remote_contents[:half_of_content_length])}, - ) - self.stubbed_client.add_response( - 'get_object', - {'Body': BytesIO(self.remote_contents[half_of_content_length:])}, - ) - downloader = ProcessPoolDownloader( - config=ProcessTransferConfig( - multipart_chunksize=half_of_content_length, - multipart_threshold=half_of_content_length, - max_request_processes=1, - ) - ) - with downloader: - downloader.download_file(self.bucket, self.key, self.filename) - self.assert_contents(self.filename, self.remote_contents) - - def test_download_file_extra_args(self): - self.stubbed_client.add_response( - 'head_object', - {'ContentLength': len(self.remote_contents)}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'VersionId': 'versionid', - }, - ) - self.stubbed_client.add_response( - 'get_object', - {'Body': self.stream}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'VersionId': 'versionid', - }, - ) - with self.downloader: - self.downloader.download_file( - self.bucket, - self.key, - self.filename, - extra_args={'VersionId': 'versionid'}, - ) - self.assert_contents(self.filename, self.remote_contents) - - def test_download_file_expected_size(self): - self.stubbed_client.add_response('get_object', {'Body': self.stream}) - with 
self.downloader: - self.downloader.download_file( - self.bucket, - self.key, - self.filename, - expected_size=len(self.remote_contents), - ) - self.assert_contents(self.filename, self.remote_contents) - - def test_cleans_up_tempfile_on_failure(self): - self.stubbed_client.add_client_error('get_object', 'NoSuchKey') - with self.downloader: - self.downloader.download_file( - self.bucket, - self.key, - self.filename, - expected_size=len(self.remote_contents), - ) - self.assertFalse(os.path.exists(self.filename)) - # Any tempfile should have been erased as well - possible_matches = glob.glob('%s*' % self.filename + os.extsep) - self.assertEqual(possible_matches, []) - - def test_validates_extra_args(self): - with self.downloader: - with self.assertRaises(ValueError): - self.downloader.download_file( - self.bucket, - self.key, - self.filename, - extra_args={'NotSupported': 'NotSupported'}, - ) - - def test_result_with_success(self): - self.stubbed_client.add_response('get_object', {'Body': self.stream}) - with self.downloader: - future = self.downloader.download_file( - self.bucket, - self.key, - self.filename, - expected_size=len(self.remote_contents), - ) - self.assertIsNone(future.result()) - - def test_result_with_exception(self): - self.stubbed_client.add_client_error('get_object', 'NoSuchKey') - with self.downloader: - future = self.downloader.download_file( - self.bucket, - self.key, - self.filename, - expected_size=len(self.remote_contents), - ) - with self.assertRaises(botocore.exceptions.ClientError): - future.result() - - def test_result_with_cancel(self): - self.stubbed_client.add_response('get_object', {'Body': self.stream}) - with self.downloader: - future = self.downloader.download_file( - self.bucket, - self.key, - self.filename, - expected_size=len(self.remote_contents), - ) - future.cancel() - with self.assertRaises(CancelledError): - future.result() - - def test_shutdown_with_no_downloads(self): - downloader = ProcessPoolDownloader() - try: - downloader.shutdown() - except AttributeError: - self.fail( - 'The downloader should be able to be shutdown even though ' - 'the downloader was never started.' - ) - - def test_shutdown_with_no_downloads_and_ctrl_c(self): - # Special shutdown logic happens if a KeyboardInterrupt is raised in - # the context manager. However, this logic can not happen if the - # downloader was never started. So a KeyboardInterrupt should be - # the only exception propagated. - with self.assertRaises(KeyboardInterrupt): - with self.downloader: - raise KeyboardInterrupt() +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import glob +import os +from io import BytesIO +from multiprocessing.managers import BaseManager + +import botocore.exceptions +import botocore.session +from botocore.stub import Stubber + +from s3transfer.exceptions import CancelledError +from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig +from __tests__ import FileCreator, mock, unittest + + +class StubbedClient: + def __init__(self): + self._client = botocore.session.get_session().create_client( + 's3', + 'us-west-2', + aws_access_key_id='foo', + aws_secret_access_key='bar', + ) + self._stubber = Stubber(self._client) + self._stubber.activate() + self._caught_stubber_errors = [] + + def get_object(self, **kwargs): + return self._client.get_object(**kwargs) + + def head_object(self, **kwargs): + return self._client.head_object(**kwargs) + + def add_response(self, *args, **kwargs): + self._stubber.add_response(*args, **kwargs) + + def add_client_error(self, *args, **kwargs): + self._stubber.add_client_error(*args, **kwargs) + + +class StubbedClientManager(BaseManager): + pass + + +StubbedClientManager.register('StubbedClient', StubbedClient) + + +# Ideally a Mock would be used here. However, they cannot be pickled +# for Windows. So instead we define a factory class at the module level that +# can return a stubbed client we initialized in the setUp. +class StubbedClientFactory: + def __init__(self, stubbed_client): + self._stubbed_client = stubbed_client + + def __call__(self, *args, **kwargs): + # The __call__ is defined so we can provide an instance of the + # StubbedClientFactory to mock.patch() and have the instance be + # returned when the patched class is instantiated. + return self + + def create_client(self): + return self._stubbed_client + + +class TestProcessPoolDownloader(unittest.TestCase): + def setUp(self): + # The stubbed client needs to run in a manager to be shared across + # processes and have it properly consume the stubbed response across + # processes. 
+ self.manager = StubbedClientManager() + self.manager.start() + self.stubbed_client = self.manager.StubbedClient() + self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client) + + self.client_factory_patch = mock.patch( + 's3transfer.processpool.ClientFactory', self.stubbed_client_factory + ) + self.client_factory_patch.start() + self.files = FileCreator() + + self.config = ProcessTransferConfig(max_request_processes=1) + self.downloader = ProcessPoolDownloader(config=self.config) + self.bucket = 'mybucket' + self.key = 'mykey' + self.filename = self.files.full_path('filename') + self.remote_contents = b'my content' + self.stream = BytesIO(self.remote_contents) + + def tearDown(self): + self.manager.shutdown() + self.client_factory_patch.stop() + self.files.remove_all() + + def assert_contents(self, filename, expected_contents): + self.assertTrue(os.path.exists(filename)) + with open(filename, 'rb') as f: + self.assertEqual(f.read(), expected_contents) + + def test_download_file(self): + self.stubbed_client.add_response( + 'head_object', {'ContentLength': len(self.remote_contents)} + ) + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + with self.downloader: + self.downloader.download_file(self.bucket, self.key, self.filename) + self.assert_contents(self.filename, self.remote_contents) + + def test_download_multiple_files(self): + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + self.stubbed_client.add_response( + 'get_object', {'Body': BytesIO(self.remote_contents)} + ) + with self.downloader: + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + other_file = self.files.full_path('filename2') + self.downloader.download_file( + self.bucket, + self.key, + other_file, + expected_size=len(self.remote_contents), + ) + self.assert_contents(self.filename, self.remote_contents) + self.assert_contents(other_file, self.remote_contents) + + def test_download_file_ranged_download(self): + half_of_content_length = int(len(self.remote_contents) / 2) + self.stubbed_client.add_response( + 'head_object', {'ContentLength': len(self.remote_contents)} + ) + self.stubbed_client.add_response( + 'get_object', + {'Body': BytesIO(self.remote_contents[:half_of_content_length])}, + ) + self.stubbed_client.add_response( + 'get_object', + {'Body': BytesIO(self.remote_contents[half_of_content_length:])}, + ) + downloader = ProcessPoolDownloader( + config=ProcessTransferConfig( + multipart_chunksize=half_of_content_length, + multipart_threshold=half_of_content_length, + max_request_processes=1, + ) + ) + with downloader: + downloader.download_file(self.bucket, self.key, self.filename) + self.assert_contents(self.filename, self.remote_contents) + + def test_download_file_extra_args(self): + self.stubbed_client.add_response( + 'head_object', + {'ContentLength': len(self.remote_contents)}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + }, + ) + self.stubbed_client.add_response( + 'get_object', + {'Body': self.stream}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + }, + ) + with self.downloader: + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + extra_args={'VersionId': 'versionid'}, + ) + self.assert_contents(self.filename, self.remote_contents) + + def test_download_file_expected_size(self): + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + with 
self.downloader: + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + self.assert_contents(self.filename, self.remote_contents) + + def test_cleans_up_tempfile_on_failure(self): + self.stubbed_client.add_client_error('get_object', 'NoSuchKey') + with self.downloader: + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + self.assertFalse(os.path.exists(self.filename)) + # Any tempfile should have been erased as well + possible_matches = glob.glob('%s*' % self.filename + os.extsep) + self.assertEqual(possible_matches, []) + + def test_validates_extra_args(self): + with self.downloader: + with self.assertRaises(ValueError): + self.downloader.download_file( + self.bucket, + self.key, + self.filename, + extra_args={'NotSupported': 'NotSupported'}, + ) + + def test_result_with_success(self): + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + with self.downloader: + future = self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + self.assertIsNone(future.result()) + + def test_result_with_exception(self): + self.stubbed_client.add_client_error('get_object', 'NoSuchKey') + with self.downloader: + future = self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + with self.assertRaises(botocore.exceptions.ClientError): + future.result() + + def test_result_with_cancel(self): + self.stubbed_client.add_response('get_object', {'Body': self.stream}) + with self.downloader: + future = self.downloader.download_file( + self.bucket, + self.key, + self.filename, + expected_size=len(self.remote_contents), + ) + future.cancel() + with self.assertRaises(CancelledError): + future.result() + + def test_shutdown_with_no_downloads(self): + downloader = ProcessPoolDownloader() + try: + downloader.shutdown() + except AttributeError: + self.fail( + 'The downloader should be able to be shutdown even though ' + 'the downloader was never started.' + ) + + def test_shutdown_with_no_downloads_and_ctrl_c(self): + # Special shutdown logic happens if a KeyboardInterrupt is raised in + # the context manager. However, this logic can not happen if the + # downloader was never started. So a KeyboardInterrupt should be + # the only exception propagated. + with self.assertRaises(KeyboardInterrupt): + with self.downloader: + raise KeyboardInterrupt() diff --git a/contrib/python/s3transfer/py3/tests/functional/test_upload.py b/contrib/python/s3transfer/py3/tests/functional/test_upload.py index dcb2a65241..4f294e85ad 100644 --- a/contrib/python/s3transfer/py3/tests/functional/test_upload.py +++ b/contrib/python/s3transfer/py3/tests/functional/test_upload.py @@ -1,538 +1,538 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-import os -import shutil -import tempfile -import time -from io import BytesIO - -from botocore.awsrequest import AWSRequest -from botocore.client import Config -from botocore.exceptions import ClientError -from botocore.stub import ANY - -from s3transfer.manager import TransferConfig, TransferManager -from s3transfer.utils import ChunksizeAdjuster -from __tests__ import ( - BaseGeneralInterfaceTest, - NonSeekableReader, - RecordingOSUtils, - RecordingSubscriber, - mock, -) - - -class BaseUploadTest(BaseGeneralInterfaceTest): - def setUp(self): - super().setUp() - # TODO: We do not want to use the real MIN_UPLOAD_CHUNKSIZE - # when we're adjusting parts. - # This is really wasteful and fails CI builds because self.contents - # would normally use 10MB+ of memory. - # Until there's an API to configure this, we're patching this with - # a min size of 1. We can't patch MIN_UPLOAD_CHUNKSIZE directly - # because it's already bound to a default value in the - # chunksize adjuster. Instead we need to patch out the - # chunksize adjuster class. - self.adjuster_patch = mock.patch( - 's3transfer.upload.ChunksizeAdjuster', - lambda: ChunksizeAdjuster(min_size=1), - ) - self.adjuster_patch.start() - self.config = TransferConfig(max_request_concurrency=1) - self._manager = TransferManager(self.client, self.config) - - # Create a temporary directory with files to read from - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'myfile') - self.content = b'my content' - - with open(self.filename, 'wb') as f: - f.write(self.content) - - # Initialize some default arguments - self.bucket = 'mybucket' - self.key = 'mykey' - self.extra_args = {} - self.subscribers = [] - - # A list to keep track of all of the bodies sent over the wire - # and their order. - self.sent_bodies = [] - self.client.meta.events.register( - 'before-parameter-build.s3.*', self.collect_body - ) - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.tempdir) - self.adjuster_patch.stop() - - def collect_body(self, params, model, **kwargs): - # A handler to simulate the reading of the body including the - # request-created event that signals to simulate the progress - # callbacks - if 'Body' in params: - # TODO: This is not ideal. Need to figure out a better idea of - # simulating reading of the request across the wire to trigger - # progress callbacks - request = AWSRequest( - method='PUT', - url='https://s3.amazonaws.com', - data=params['Body'], - ) - self.client.meta.events.emit( - 'request-created.s3.%s' % model.name, - request=request, - operation_name=model.name, - ) - self.sent_bodies.append(self._stream_body(params['Body'])) - - def _stream_body(self, body): - read_amt = 8 * 1024 - data = body.read(read_amt) - collected_body = data - while data: - data = body.read(read_amt) - collected_body += data - return collected_body - - @property - def manager(self): - return self._manager - - @property - def method(self): - return self.manager.upload - - def create_call_kwargs(self): - return { - 'fileobj': self.filename, - 'bucket': self.bucket, - 'key': self.key, - } - - def create_invalid_extra_args(self): - return {'Foo': 'bar'} - - def create_stubbed_responses(self): - return [{'method': 'put_object', 'service_response': {}}] - - def create_expected_progress_callback_info(self): - return [{'bytes_transferred': 10}] - - def assert_expected_client_calls_were_correct(self): - # We assert that expected client calls were made by ensuring that - # there are no more pending responses. 
If there are no more pending - # responses, then all stubbed responses were consumed. - self.stubber.assert_no_pending_responses() - - -class TestNonMultipartUpload(BaseUploadTest): - __test__ = True - - def add_put_object_response_with_default_expected_params( - self, extra_expected_params=None - ): - expected_params = {'Body': ANY, 'Bucket': self.bucket, 'Key': self.key} - if extra_expected_params: - expected_params.update(extra_expected_params) - upload_response = self.create_stubbed_responses()[0] - upload_response['expected_params'] = expected_params - self.stubber.add_response(**upload_response) - - def assert_put_object_body_was_correct(self): - self.assertEqual(self.sent_bodies, [self.content]) - - def test_upload(self): - self.extra_args['RequestPayer'] = 'requester' - self.add_put_object_response_with_default_expected_params( - extra_expected_params={'RequestPayer': 'requester'} - ) - future = self.manager.upload( - self.filename, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assert_put_object_body_was_correct() - - def test_upload_for_fileobj(self): - self.add_put_object_response_with_default_expected_params() - with open(self.filename, 'rb') as f: - future = self.manager.upload( - f, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assert_put_object_body_was_correct() - - def test_upload_for_seekable_filelike_obj(self): - self.add_put_object_response_with_default_expected_params() - bytes_io = BytesIO(self.content) - future = self.manager.upload( - bytes_io, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assert_put_object_body_was_correct() - - def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self): - self.add_put_object_response_with_default_expected_params() - bytes_io = BytesIO(self.content) - seek_pos = 5 - bytes_io.seek(seek_pos) - future = self.manager.upload( - bytes_io, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:]) - - def test_upload_for_non_seekable_filelike_obj(self): - self.add_put_object_response_with_default_expected_params() - body = NonSeekableReader(self.content) - future = self.manager.upload( - body, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assert_put_object_body_was_correct() - - def test_sigv4_progress_callbacks_invoked_once(self): - # Reset the client and manager to use sigv4 - self.reset_stubber_with_new_client( - {'config': Config(signature_version='s3v4')} - ) - self.client.meta.events.register( - 'before-parameter-build.s3.*', self.collect_body - ) - self._manager = TransferManager(self.client, self.config) - - # Add the stubbed response. 
- self.add_put_object_response_with_default_expected_params() - - subscriber = RecordingSubscriber() - future = self.manager.upload( - self.filename, self.bucket, self.key, subscribers=[subscriber] - ) - future.result() - self.assert_expected_client_calls_were_correct() - - # The amount of bytes seen should be the same as the file size - self.assertEqual(subscriber.calculate_bytes_seen(), len(self.content)) - - def test_uses_provided_osutil(self): - osutil = RecordingOSUtils() - # Use the recording os utility for the transfer manager - self._manager = TransferManager(self.client, self.config, osutil) - - self.add_put_object_response_with_default_expected_params() - - future = self.manager.upload(self.filename, self.bucket, self.key) - future.result() - - # The upload should have used the os utility. We check this by making - # sure that the recorded opens are as expected. - expected_opens = [(self.filename, 'rb')] - self.assertEqual(osutil.open_records, expected_opens) - - def test_allowed_upload_params_are_valid(self): - op_model = self.client.meta.service_model.operation_model('PutObject') - for allowed_upload_arg in self._manager.ALLOWED_UPLOAD_ARGS: - self.assertIn(allowed_upload_arg, op_model.input_shape.members) - - def test_upload_with_bandwidth_limiter(self): - self.content = b'a' * 1024 * 1024 - with open(self.filename, 'wb') as f: - f.write(self.content) - self.config = TransferConfig( - max_request_concurrency=1, max_bandwidth=len(self.content) / 2 - ) - self._manager = TransferManager(self.client, self.config) - - self.add_put_object_response_with_default_expected_params() - start = time.time() - future = self.manager.upload(self.filename, self.bucket, self.key) - future.result() - # This is just a smoke test to make sure that the limiter is - # being used and not necessary its exactness. So we set the maximum - # bandwidth to len(content)/2 per sec and make sure that it is - # noticeably slower. Ideally it will take more than two seconds, but - # given tracking at the beginning of transfers are not entirely - # accurate setting at the initial start of a transfer, we give us - # some flexibility by setting the expected time to half of the - # theoretical time to take. 
- self.assertGreaterEqual(time.time() - start, 1) - - self.assert_expected_client_calls_were_correct() - self.assert_put_object_body_was_correct() - - def test_raise_exception_on_s3_object_lambda_resource(self): - s3_object_lambda_arn = ( - 'arn:aws:s3-object-lambda:us-west-2:123456789012:' - 'accesspoint:my-accesspoint' - ) - with self.assertRaisesRegex(ValueError, 'methods do not support'): - self.manager.upload(self.filename, s3_object_lambda_arn, self.key) - - -class TestMultipartUpload(BaseUploadTest): - __test__ = True - - def setUp(self): - super().setUp() - self.chunksize = 4 - self.config = TransferConfig( - max_request_concurrency=1, - multipart_threshold=1, - multipart_chunksize=self.chunksize, - ) - self._manager = TransferManager(self.client, self.config) - self.multipart_id = 'my-upload-id' - - def create_stubbed_responses(self): - return [ - { - 'method': 'create_multipart_upload', - 'service_response': {'UploadId': self.multipart_id}, - }, - {'method': 'upload_part', 'service_response': {'ETag': 'etag-1'}}, - {'method': 'upload_part', 'service_response': {'ETag': 'etag-2'}}, - {'method': 'upload_part', 'service_response': {'ETag': 'etag-3'}}, - {'method': 'complete_multipart_upload', 'service_response': {}}, - ] - - def create_expected_progress_callback_info(self): - return [ - {'bytes_transferred': 4}, - {'bytes_transferred': 4}, - {'bytes_transferred': 2}, - ] - - def assert_upload_part_bodies_were_correct(self): - expected_contents = [] - for i in range(0, len(self.content), self.chunksize): - end_i = i + self.chunksize - if end_i > len(self.content): - expected_contents.append(self.content[i:]) - else: - expected_contents.append(self.content[i:end_i]) - self.assertEqual(self.sent_bodies, expected_contents) - - def add_create_multipart_response_with_default_expected_params( - self, extra_expected_params=None - ): - expected_params = {'Bucket': self.bucket, 'Key': self.key} - if extra_expected_params: - expected_params.update(extra_expected_params) - response = self.create_stubbed_responses()[0] - response['expected_params'] = expected_params - self.stubber.add_response(**response) - - def add_upload_part_responses_with_default_expected_params( - self, extra_expected_params=None - ): - num_parts = 3 - upload_part_responses = self.create_stubbed_responses()[1:-1] - for i in range(num_parts): - upload_part_response = upload_part_responses[i] - expected_params = { - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': self.multipart_id, - 'Body': ANY, - 'PartNumber': i + 1, - } - if extra_expected_params: - expected_params.update(extra_expected_params) - upload_part_response['expected_params'] = expected_params - self.stubber.add_response(**upload_part_response) - - def add_complete_multipart_response_with_default_expected_params( - self, extra_expected_params=None - ): - expected_params = { - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': self.multipart_id, - 'MultipartUpload': { - 'Parts': [ - {'ETag': 'etag-1', 'PartNumber': 1}, - {'ETag': 'etag-2', 'PartNumber': 2}, - {'ETag': 'etag-3', 'PartNumber': 3}, - ] - }, - } - if extra_expected_params: - expected_params.update(extra_expected_params) - response = self.create_stubbed_responses()[-1] - response['expected_params'] = expected_params - self.stubber.add_response(**response) - - def test_upload(self): - self.extra_args['RequestPayer'] = 'requester' - - # Add requester pays to the create multipart upload and upload parts. 
- self.add_create_multipart_response_with_default_expected_params( - extra_expected_params={'RequestPayer': 'requester'} - ) - self.add_upload_part_responses_with_default_expected_params( - extra_expected_params={'RequestPayer': 'requester'} - ) - self.add_complete_multipart_response_with_default_expected_params( - extra_expected_params={'RequestPayer': 'requester'} - ) - - future = self.manager.upload( - self.filename, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - - def test_upload_for_fileobj(self): - self.add_create_multipart_response_with_default_expected_params() - self.add_upload_part_responses_with_default_expected_params() - self.add_complete_multipart_response_with_default_expected_params() - with open(self.filename, 'rb') as f: - future = self.manager.upload( - f, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assert_upload_part_bodies_were_correct() - - def test_upload_for_seekable_filelike_obj(self): - self.add_create_multipart_response_with_default_expected_params() - self.add_upload_part_responses_with_default_expected_params() - self.add_complete_multipart_response_with_default_expected_params() - bytes_io = BytesIO(self.content) - future = self.manager.upload( - bytes_io, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assert_upload_part_bodies_were_correct() - - def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self): - self.add_create_multipart_response_with_default_expected_params() - self.add_upload_part_responses_with_default_expected_params() - self.add_complete_multipart_response_with_default_expected_params() - bytes_io = BytesIO(self.content) - seek_pos = 1 - bytes_io.seek(seek_pos) - future = self.manager.upload( - bytes_io, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:]) - - def test_upload_for_non_seekable_filelike_obj(self): - self.add_create_multipart_response_with_default_expected_params() - self.add_upload_part_responses_with_default_expected_params() - self.add_complete_multipart_response_with_default_expected_params() - stream = NonSeekableReader(self.content) - future = self.manager.upload( - stream, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() - self.assert_upload_part_bodies_were_correct() - - def test_limits_in_memory_chunks_for_fileobj(self): - # Limit the maximum in memory chunks to one but make number of - # threads more than one. This means that the upload will have to - # happen sequentially despite having many threads available because - # data is sequentially partitioned into chunks in memory and since - # there can only every be one in memory chunk, each upload part will - # have to happen one at a time. - self.config.max_request_concurrency = 10 - self.config.max_in_memory_upload_chunks = 1 - self._manager = TransferManager(self.client, self.config) - - # Add some default stubbed responses. - # These responses are added in order of part number so if the - # multipart upload is not done sequentially, which it should because - # we limit the in memory upload chunks to one, the stubber will - # raise exceptions for mismatching parameters for partNumber when - # once the upload() method is called on the transfer manager. 
- # If there is a mismatch, the stubber error will propagate on - # the future.result() - self.add_create_multipart_response_with_default_expected_params() - self.add_upload_part_responses_with_default_expected_params() - self.add_complete_multipart_response_with_default_expected_params() - with open(self.filename, 'rb') as f: - future = self.manager.upload( - f, self.bucket, self.key, self.extra_args - ) - future.result() - - # Make sure that the stubber had all of its stubbed responses consumed. - self.assert_expected_client_calls_were_correct() - # Ensure the contents were uploaded in sequentially order by checking - # the sent contents were in order. - self.assert_upload_part_bodies_were_correct() - - def test_upload_failure_invokes_abort(self): - self.stubber.add_response( - method='create_multipart_upload', - service_response={'UploadId': self.multipart_id}, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - self.stubber.add_response( - method='upload_part', - service_response={'ETag': 'etag-1'}, - expected_params={ - 'Bucket': self.bucket, - 'Body': ANY, - 'Key': self.key, - 'UploadId': self.multipart_id, - 'PartNumber': 1, - }, - ) - # With the upload part failing this should immediately initiate - # an abort multipart with no more upload parts called. - self.stubber.add_client_error(method='upload_part') - - self.stubber.add_response( - method='abort_multipart_upload', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': self.multipart_id, - }, - ) - - future = self.manager.upload(self.filename, self.bucket, self.key) - # The exception should get propagated to the future and not be - # a cancelled error or something. - with self.assertRaises(ClientError): - future.result() - self.assert_expected_client_calls_were_correct() - - def test_upload_passes_select_extra_args(self): - self.extra_args['Metadata'] = {'foo': 'bar'} - - # Add metadata to expected create multipart upload call - self.add_create_multipart_response_with_default_expected_params( - extra_expected_params={'Metadata': {'foo': 'bar'}} - ) - self.add_upload_part_responses_with_default_expected_params() - self.add_complete_multipart_response_with_default_expected_params() - - future = self.manager.upload( - self.filename, self.bucket, self.key, self.extra_args - ) - future.result() - self.assert_expected_client_calls_were_correct() +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import os +import shutil +import tempfile +import time +from io import BytesIO + +from botocore.awsrequest import AWSRequest +from botocore.client import Config +from botocore.exceptions import ClientError +from botocore.stub import ANY + +from s3transfer.manager import TransferConfig, TransferManager +from s3transfer.utils import ChunksizeAdjuster +from __tests__ import ( + BaseGeneralInterfaceTest, + NonSeekableReader, + RecordingOSUtils, + RecordingSubscriber, + mock, +) + + +class BaseUploadTest(BaseGeneralInterfaceTest): + def setUp(self): + super().setUp() + # TODO: We do not want to use the real MIN_UPLOAD_CHUNKSIZE + # when we're adjusting parts. + # This is really wasteful and fails CI builds because self.contents + # would normally use 10MB+ of memory. + # Until there's an API to configure this, we're patching this with + # a min size of 1. We can't patch MIN_UPLOAD_CHUNKSIZE directly + # because it's already bound to a default value in the + # chunksize adjuster. Instead we need to patch out the + # chunksize adjuster class. + self.adjuster_patch = mock.patch( + 's3transfer.upload.ChunksizeAdjuster', + lambda: ChunksizeAdjuster(min_size=1), + ) + self.adjuster_patch.start() + self.config = TransferConfig(max_request_concurrency=1) + self._manager = TransferManager(self.client, self.config) + + # Create a temporary directory with files to read from + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + self.content = b'my content' + + with open(self.filename, 'wb') as f: + f.write(self.content) + + # Initialize some default arguments + self.bucket = 'mybucket' + self.key = 'mykey' + self.extra_args = {} + self.subscribers = [] + + # A list to keep track of all of the bodies sent over the wire + # and their order. + self.sent_bodies = [] + self.client.meta.events.register( + 'before-parameter-build.s3.*', self.collect_body + ) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + self.adjuster_patch.stop() + + def collect_body(self, params, model, **kwargs): + # A handler to simulate the reading of the body including the + # request-created event that signals to simulate the progress + # callbacks + if 'Body' in params: + # TODO: This is not ideal. Need to figure out a better idea of + # simulating reading of the request across the wire to trigger + # progress callbacks + request = AWSRequest( + method='PUT', + url='https://s3.amazonaws.com', + data=params['Body'], + ) + self.client.meta.events.emit( + 'request-created.s3.%s' % model.name, + request=request, + operation_name=model.name, + ) + self.sent_bodies.append(self._stream_body(params['Body'])) + + def _stream_body(self, body): + read_amt = 8 * 1024 + data = body.read(read_amt) + collected_body = data + while data: + data = body.read(read_amt) + collected_body += data + return collected_body + + @property + def manager(self): + return self._manager + + @property + def method(self): + return self.manager.upload + + def create_call_kwargs(self): + return { + 'fileobj': self.filename, + 'bucket': self.bucket, + 'key': self.key, + } + + def create_invalid_extra_args(self): + return {'Foo': 'bar'} + + def create_stubbed_responses(self): + return [{'method': 'put_object', 'service_response': {}}] + + def create_expected_progress_callback_info(self): + return [{'bytes_transferred': 10}] + + def assert_expected_client_calls_were_correct(self): + # We assert that expected client calls were made by ensuring that + # there are no more pending responses. 
If there are no more pending + # responses, then all stubbed responses were consumed. + self.stubber.assert_no_pending_responses() + + +class TestNonMultipartUpload(BaseUploadTest): + __test__ = True + + def add_put_object_response_with_default_expected_params( + self, extra_expected_params=None + ): + expected_params = {'Body': ANY, 'Bucket': self.bucket, 'Key': self.key} + if extra_expected_params: + expected_params.update(extra_expected_params) + upload_response = self.create_stubbed_responses()[0] + upload_response['expected_params'] = expected_params + self.stubber.add_response(**upload_response) + + def assert_put_object_body_was_correct(self): + self.assertEqual(self.sent_bodies, [self.content]) + + def test_upload(self): + self.extra_args['RequestPayer'] = 'requester' + self.add_put_object_response_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'} + ) + future = self.manager.upload( + self.filename, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_upload_for_fileobj(self): + self.add_put_object_response_with_default_expected_params() + with open(self.filename, 'rb') as f: + future = self.manager.upload( + f, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_upload_for_seekable_filelike_obj(self): + self.add_put_object_response_with_default_expected_params() + bytes_io = BytesIO(self.content) + future = self.manager.upload( + bytes_io, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self): + self.add_put_object_response_with_default_expected_params() + bytes_io = BytesIO(self.content) + seek_pos = 5 + bytes_io.seek(seek_pos) + future = self.manager.upload( + bytes_io, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:]) + + def test_upload_for_non_seekable_filelike_obj(self): + self.add_put_object_response_with_default_expected_params() + body = NonSeekableReader(self.content) + future = self.manager.upload( + body, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_sigv4_progress_callbacks_invoked_once(self): + # Reset the client and manager to use sigv4 + self.reset_stubber_with_new_client( + {'config': Config(signature_version='s3v4')} + ) + self.client.meta.events.register( + 'before-parameter-build.s3.*', self.collect_body + ) + self._manager = TransferManager(self.client, self.config) + + # Add the stubbed response. 
+ self.add_put_object_response_with_default_expected_params() + + subscriber = RecordingSubscriber() + future = self.manager.upload( + self.filename, self.bucket, self.key, subscribers=[subscriber] + ) + future.result() + self.assert_expected_client_calls_were_correct() + + # The amount of bytes seen should be the same as the file size + self.assertEqual(subscriber.calculate_bytes_seen(), len(self.content)) + + def test_uses_provided_osutil(self): + osutil = RecordingOSUtils() + # Use the recording os utility for the transfer manager + self._manager = TransferManager(self.client, self.config, osutil) + + self.add_put_object_response_with_default_expected_params() + + future = self.manager.upload(self.filename, self.bucket, self.key) + future.result() + + # The upload should have used the os utility. We check this by making + # sure that the recorded opens are as expected. + expected_opens = [(self.filename, 'rb')] + self.assertEqual(osutil.open_records, expected_opens) + + def test_allowed_upload_params_are_valid(self): + op_model = self.client.meta.service_model.operation_model('PutObject') + for allowed_upload_arg in self._manager.ALLOWED_UPLOAD_ARGS: + self.assertIn(allowed_upload_arg, op_model.input_shape.members) + + def test_upload_with_bandwidth_limiter(self): + self.content = b'a' * 1024 * 1024 + with open(self.filename, 'wb') as f: + f.write(self.content) + self.config = TransferConfig( + max_request_concurrency=1, max_bandwidth=len(self.content) / 2 + ) + self._manager = TransferManager(self.client, self.config) + + self.add_put_object_response_with_default_expected_params() + start = time.time() + future = self.manager.upload(self.filename, self.bucket, self.key) + future.result() + # This is just a smoke test to make sure that the limiter is + # being used and not necessary its exactness. So we set the maximum + # bandwidth to len(content)/2 per sec and make sure that it is + # noticeably slower. Ideally it will take more than two seconds, but + # given tracking at the beginning of transfers are not entirely + # accurate setting at the initial start of a transfer, we give us + # some flexibility by setting the expected time to half of the + # theoretical time to take. 
+ self.assertGreaterEqual(time.time() - start, 1) + + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + + def test_raise_exception_on_s3_object_lambda_resource(self): + s3_object_lambda_arn = ( + 'arn:aws:s3-object-lambda:us-west-2:123456789012:' + 'accesspoint:my-accesspoint' + ) + with self.assertRaisesRegex(ValueError, 'methods do not support'): + self.manager.upload(self.filename, s3_object_lambda_arn, self.key) + + +class TestMultipartUpload(BaseUploadTest): + __test__ = True + + def setUp(self): + super().setUp() + self.chunksize = 4 + self.config = TransferConfig( + max_request_concurrency=1, + multipart_threshold=1, + multipart_chunksize=self.chunksize, + ) + self._manager = TransferManager(self.client, self.config) + self.multipart_id = 'my-upload-id' + + def create_stubbed_responses(self): + return [ + { + 'method': 'create_multipart_upload', + 'service_response': {'UploadId': self.multipart_id}, + }, + {'method': 'upload_part', 'service_response': {'ETag': 'etag-1'}}, + {'method': 'upload_part', 'service_response': {'ETag': 'etag-2'}}, + {'method': 'upload_part', 'service_response': {'ETag': 'etag-3'}}, + {'method': 'complete_multipart_upload', 'service_response': {}}, + ] + + def create_expected_progress_callback_info(self): + return [ + {'bytes_transferred': 4}, + {'bytes_transferred': 4}, + {'bytes_transferred': 2}, + ] + + def assert_upload_part_bodies_were_correct(self): + expected_contents = [] + for i in range(0, len(self.content), self.chunksize): + end_i = i + self.chunksize + if end_i > len(self.content): + expected_contents.append(self.content[i:]) + else: + expected_contents.append(self.content[i:end_i]) + self.assertEqual(self.sent_bodies, expected_contents) + + def add_create_multipart_response_with_default_expected_params( + self, extra_expected_params=None + ): + expected_params = {'Bucket': self.bucket, 'Key': self.key} + if extra_expected_params: + expected_params.update(extra_expected_params) + response = self.create_stubbed_responses()[0] + response['expected_params'] = expected_params + self.stubber.add_response(**response) + + def add_upload_part_responses_with_default_expected_params( + self, extra_expected_params=None + ): + num_parts = 3 + upload_part_responses = self.create_stubbed_responses()[1:-1] + for i in range(num_parts): + upload_part_response = upload_part_responses[i] + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': self.multipart_id, + 'Body': ANY, + 'PartNumber': i + 1, + } + if extra_expected_params: + expected_params.update(extra_expected_params) + upload_part_response['expected_params'] = expected_params + self.stubber.add_response(**upload_part_response) + + def add_complete_multipart_response_with_default_expected_params( + self, extra_expected_params=None + ): + expected_params = { + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': self.multipart_id, + 'MultipartUpload': { + 'Parts': [ + {'ETag': 'etag-1', 'PartNumber': 1}, + {'ETag': 'etag-2', 'PartNumber': 2}, + {'ETag': 'etag-3', 'PartNumber': 3}, + ] + }, + } + if extra_expected_params: + expected_params.update(extra_expected_params) + response = self.create_stubbed_responses()[-1] + response['expected_params'] = expected_params + self.stubber.add_response(**response) + + def test_upload(self): + self.extra_args['RequestPayer'] = 'requester' + + # Add requester pays to the create multipart upload and upload parts. 
+ self.add_create_multipart_response_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'} + ) + self.add_upload_part_responses_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'} + ) + self.add_complete_multipart_response_with_default_expected_params( + extra_expected_params={'RequestPayer': 'requester'} + ) + + future = self.manager.upload( + self.filename, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + + def test_upload_for_fileobj(self): + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + with open(self.filename, 'rb') as f: + future = self.manager.upload( + f, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_upload_part_bodies_were_correct() + + def test_upload_for_seekable_filelike_obj(self): + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + bytes_io = BytesIO(self.content) + future = self.manager.upload( + bytes_io, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_upload_part_bodies_were_correct() + + def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self): + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + bytes_io = BytesIO(self.content) + seek_pos = 1 + bytes_io.seek(seek_pos) + future = self.manager.upload( + bytes_io, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:]) + + def test_upload_for_non_seekable_filelike_obj(self): + self.add_create_multipart_response_with_default_expected_params() + self.add_upload_part_responses_with_default_expected_params() + self.add_complete_multipart_response_with_default_expected_params() + stream = NonSeekableReader(self.content) + future = self.manager.upload( + stream, self.bucket, self.key, self.extra_args + ) + future.result() + self.assert_expected_client_calls_were_correct() + self.assert_upload_part_bodies_were_correct() + + def test_limits_in_memory_chunks_for_fileobj(self): + # Limit the maximum in memory chunks to one but make number of + # threads more than one. This means that the upload will have to + # happen sequentially despite having many threads available because + # data is sequentially partitioned into chunks in memory and since + # there can only every be one in memory chunk, each upload part will + # have to happen one at a time. + self.config.max_request_concurrency = 10 + self.config.max_in_memory_upload_chunks = 1 + self._manager = TransferManager(self.client, self.config) + + # Add some default stubbed responses. + # These responses are added in order of part number so if the + # multipart upload is not done sequentially, which it should because + # we limit the in memory upload chunks to one, the stubber will + # raise exceptions for mismatching parameters for partNumber when + # once the upload() method is called on the transfer manager. 
+        # If there is a mismatch, the stubber error will propagate on
+        # the future.result()
+        self.add_create_multipart_response_with_default_expected_params()
+        self.add_upload_part_responses_with_default_expected_params()
+        self.add_complete_multipart_response_with_default_expected_params()
+        with open(self.filename, 'rb') as f:
+            future = self.manager.upload(
+                f, self.bucket, self.key, self.extra_args
+            )
+            future.result()
+
+        # Make sure that the stubber had all of its stubbed responses consumed.
+        self.assert_expected_client_calls_were_correct()
+        # Ensure the contents were uploaded in sequential order by checking
+        # that the sent contents were in order.
+        self.assert_upload_part_bodies_were_correct()
+
+    def test_upload_failure_invokes_abort(self):
+        self.stubber.add_response(
+            method='create_multipart_upload',
+            service_response={'UploadId': self.multipart_id},
+            expected_params={'Bucket': self.bucket, 'Key': self.key},
+        )
+        self.stubber.add_response(
+            method='upload_part',
+            service_response={'ETag': 'etag-1'},
+            expected_params={
+                'Bucket': self.bucket,
+                'Body': ANY,
+                'Key': self.key,
+                'UploadId': self.multipart_id,
+                'PartNumber': 1,
+            },
+        )
+        # With the upload part failing, this should immediately initiate
+        # an abort of the multipart upload with no more upload parts called.
+        self.stubber.add_client_error(method='upload_part')
+
+        self.stubber.add_response(
+            method='abort_multipart_upload',
+            service_response={},
+            expected_params={
+                'Bucket': self.bucket,
+                'Key': self.key,
+                'UploadId': self.multipart_id,
+            },
+        )
+
+        future = self.manager.upload(self.filename, self.bucket, self.key)
+        # The exception should get propagated to the future and not be
+        # masked as a cancellation error.
+        with self.assertRaises(ClientError):
+            future.result()
+        self.assert_expected_client_calls_were_correct()
+
+    def test_upload_passes_select_extra_args(self):
+        self.extra_args['Metadata'] = {'foo': 'bar'}
+
+        # Add metadata to the expected create multipart upload call
+        self.add_create_multipart_response_with_default_expected_params(
+            extra_expected_params={'Metadata': {'foo': 'bar'}}
+        )
+        self.add_upload_part_responses_with_default_expected_params()
+        self.add_complete_multipart_response_with_default_expected_params()
+
+        future = self.manager.upload(
+            self.filename, self.bucket, self.key, self.extra_args
+        )
+        future.result()
+        self.assert_expected_client_calls_were_correct()
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_utils.py b/contrib/python/s3transfer/py3/tests/functional/test_utils.py
index cc51b955d9..fd4a232ecc 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_utils.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_utils.py
@@ -1,41 +1,41 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import os -import shutil -import socket -import tempfile - -from s3transfer.utils import OSUtils -from __tests__ import skip_if_windows, unittest - - -@skip_if_windows('Windows does not support UNIX special files') -class TestOSUtilsSpecialFiles(unittest.TestCase): - def setUp(self): - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'myfile') - - def tearDown(self): - shutil.rmtree(self.tempdir) - - def test_character_device(self): - self.assertTrue(OSUtils().is_special_file('/dev/null')) - - def test_fifo(self): - os.mkfifo(self.filename) - self.assertTrue(OSUtils().is_special_file(self.filename)) - - def test_socket(self): - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.bind(self.filename) - self.assertTrue(OSUtils().is_special_file(self.filename)) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import shutil +import socket +import tempfile + +from s3transfer.utils import OSUtils +from __tests__ import skip_if_windows, unittest + + +@skip_if_windows('Windows does not support UNIX special files') +class TestOSUtilsSpecialFiles(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_character_device(self): + self.assertTrue(OSUtils().is_special_file('/dev/null')) + + def test_fifo(self): + os.mkfifo(self.filename) + self.assertTrue(OSUtils().is_special_file(self.filename)) + + def test_socket(self): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.bind(self.filename) + self.assertTrue(OSUtils().is_special_file(self.filename)) diff --git a/contrib/python/s3transfer/py3/tests/unit/__init__.py b/contrib/python/s3transfer/py3/tests/unit/__init__.py index 5e791f8b92..79ef91c6a2 100644 --- a/contrib/python/s3transfer/py3/tests/unit/__init__.py +++ b/contrib/python/s3transfer/py3/tests/unit/__init__.py @@ -1,12 +1,12 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the 'license' file accompanying this file. This file is -# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. diff --git a/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py index 0e53ec157a..b796f8f24c 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py @@ -1,452 +1,452 @@ -# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -import os -import shutil -import tempfile - -from s3transfer.bandwidth import ( - BandwidthLimitedStream, - BandwidthLimiter, - BandwidthRateTracker, - ConsumptionScheduler, - LeakyBucket, - RequestExceededException, - RequestToken, - TimeUtils, -) -from s3transfer.futures import TransferCoordinator -from __tests__ import mock, unittest - - -class FixedIncrementalTickTimeUtils(TimeUtils): - def __init__(self, seconds_per_tick=1.0): - self._count = 0 - self._seconds_per_tick = seconds_per_tick - - def time(self): - current_count = self._count - self._count += self._seconds_per_tick - return current_count - - -class TestTimeUtils(unittest.TestCase): - @mock.patch('time.time') - def test_time(self, mock_time): - mock_return_val = 1 - mock_time.return_value = mock_return_val - time_utils = TimeUtils() - self.assertEqual(time_utils.time(), mock_return_val) - - @mock.patch('time.sleep') - def test_sleep(self, mock_sleep): - time_utils = TimeUtils() - time_utils.sleep(1) - self.assertEqual(mock_sleep.call_args_list, [mock.call(1)]) - - -class BaseBandwidthLimitTest(unittest.TestCase): - def setUp(self): - self.leaky_bucket = mock.Mock(LeakyBucket) - self.time_utils = mock.Mock(TimeUtils) - self.tempdir = tempfile.mkdtemp() - self.content = b'a' * 1024 * 1024 - self.filename = os.path.join(self.tempdir, 'myfile') - with open(self.filename, 'wb') as f: - f.write(self.content) - self.coordinator = TransferCoordinator() - - def tearDown(self): - shutil.rmtree(self.tempdir) - - def assert_consume_calls(self, amts): - expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts] - self.assertEqual( - self.leaky_bucket.consume.call_args_list, expected_consume_args - ) - - -class TestBandwidthLimiter(BaseBandwidthLimitTest): - def setUp(self): - super().setUp() - self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket) - - def test_get_bandwidth_limited_stream(self): - with open(self.filename, 'rb') as f: - stream = self.bandwidth_limiter.get_bandwith_limited_stream( - f, self.coordinator - ) - self.assertIsInstance(stream, BandwidthLimitedStream) - self.assertEqual(stream.read(len(self.content)), self.content) - self.assert_consume_calls(amts=[len(self.content)]) - - def test_get_disabled_bandwidth_limited_stream(self): - with open(self.filename, 'rb') as f: - stream = self.bandwidth_limiter.get_bandwith_limited_stream( - f, self.coordinator, enabled=False - ) - self.assertIsInstance(stream, BandwidthLimitedStream) - self.assertEqual(stream.read(len(self.content)), self.content) - 
self.leaky_bucket.consume.assert_not_called() - - -class TestBandwidthLimitedStream(BaseBandwidthLimitTest): - def setUp(self): - super().setUp() - self.bytes_threshold = 1 - - def tearDown(self): - shutil.rmtree(self.tempdir) - - def get_bandwidth_limited_stream(self, f): - return BandwidthLimitedStream( - f, - self.leaky_bucket, - self.coordinator, - self.time_utils, - self.bytes_threshold, - ) - - def assert_sleep_calls(self, amts): - expected_sleep_args_list = [mock.call(amt) for amt in amts] - self.assertEqual( - self.time_utils.sleep.call_args_list, expected_sleep_args_list - ) - - def get_unique_consume_request_tokens(self): - return { - call_args[0][1] - for call_args in self.leaky_bucket.consume.call_args_list - } - - def test_read(self): - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - data = stream.read(len(self.content)) - self.assertEqual(self.content, data) - self.assert_consume_calls(amts=[len(self.content)]) - self.assert_sleep_calls(amts=[]) - - def test_retries_on_request_exceeded(self): - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - retry_time = 1 - amt_requested = len(self.content) - self.leaky_bucket.consume.side_effect = [ - RequestExceededException(amt_requested, retry_time), - len(self.content), - ] - data = stream.read(len(self.content)) - self.assertEqual(self.content, data) - self.assert_consume_calls(amts=[amt_requested, amt_requested]) - self.assert_sleep_calls(amts=[retry_time]) - - def test_with_transfer_coordinator_exception(self): - self.coordinator.set_exception(ValueError()) - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - with self.assertRaises(ValueError): - stream.read(len(self.content)) - - def test_read_when_bandwidth_limiting_disabled(self): - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - stream.disable_bandwidth_limiting() - data = stream.read(len(self.content)) - self.assertEqual(self.content, data) - self.assertFalse(self.leaky_bucket.consume.called) - - def test_read_toggle_disable_enable_bandwidth_limiting(self): - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - stream.disable_bandwidth_limiting() - data = stream.read(1) - self.assertEqual(self.content[:1], data) - self.assert_consume_calls(amts=[]) - stream.enable_bandwidth_limiting() - data = stream.read(len(self.content) - 1) - self.assertEqual(self.content[1:], data) - self.assert_consume_calls(amts=[len(self.content) - 1]) - - def test_seek(self): - mock_fileobj = mock.Mock() - stream = self.get_bandwidth_limited_stream(mock_fileobj) - stream.seek(1) - self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)]) - - def test_tell(self): - mock_fileobj = mock.Mock() - stream = self.get_bandwidth_limited_stream(mock_fileobj) - stream.tell() - self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()]) - - def test_close(self): - mock_fileobj = mock.Mock() - stream = self.get_bandwidth_limited_stream(mock_fileobj) - stream.close() - self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()]) - - def test_context_manager(self): - mock_fileobj = mock.Mock() - stream = self.get_bandwidth_limited_stream(mock_fileobj) - with stream as stream_handle: - self.assertIs(stream_handle, stream) - self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()]) - - def test_reuses_request_token(self): - with open(self.filename, 'rb') as f: - stream = 
self.get_bandwidth_limited_stream(f) - stream.read(1) - stream.read(1) - self.assertEqual(len(self.get_unique_consume_request_tokens()), 1) - - def test_request_tokens_unique_per_stream(self): - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - stream.read(1) - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - stream.read(1) - self.assertEqual(len(self.get_unique_consume_request_tokens()), 2) - - def test_call_consume_after_reaching_threshold(self): - self.bytes_threshold = 2 - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - self.assertEqual(stream.read(1), self.content[:1]) - self.assert_consume_calls(amts=[]) - self.assertEqual(stream.read(1), self.content[1:2]) - self.assert_consume_calls(amts=[2]) - - def test_resets_after_reaching_threshold(self): - self.bytes_threshold = 2 - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - self.assertEqual(stream.read(2), self.content[:2]) - self.assert_consume_calls(amts=[2]) - self.assertEqual(stream.read(1), self.content[2:3]) - self.assert_consume_calls(amts=[2]) - - def test_pending_bytes_seen_on_close(self): - self.bytes_threshold = 2 - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - self.assertEqual(stream.read(1), self.content[:1]) - self.assert_consume_calls(amts=[]) - stream.close() - self.assert_consume_calls(amts=[1]) - - def test_no_bytes_remaining_on(self): - self.bytes_threshold = 2 - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - self.assertEqual(stream.read(2), self.content[:2]) - self.assert_consume_calls(amts=[2]) - stream.close() - # There should have been no more consume() calls made - # as all bytes have been accounted for in the previous - # consume() call. 
- self.assert_consume_calls(amts=[2]) - - def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self): - self.bytes_threshold = 2 - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - self.assertEqual(stream.read(1), self.content[:1]) - self.assert_consume_calls(amts=[]) - stream.disable_bandwidth_limiting() - stream.close() - self.assert_consume_calls(amts=[]) - - def test_signal_transferring(self): - with open(self.filename, 'rb') as f: - stream = self.get_bandwidth_limited_stream(f) - stream.signal_not_transferring() - data = stream.read(1) - self.assertEqual(self.content[:1], data) - self.assert_consume_calls(amts=[]) - stream.signal_transferring() - data = stream.read(len(self.content) - 1) - self.assertEqual(self.content[1:], data) - self.assert_consume_calls(amts=[len(self.content) - 1]) - - -class TestLeakyBucket(unittest.TestCase): - def setUp(self): - self.max_rate = 1 - self.time_now = 1.0 - self.time_utils = mock.Mock(TimeUtils) - self.time_utils.time.return_value = self.time_now - self.scheduler = mock.Mock(ConsumptionScheduler) - self.scheduler.is_scheduled.return_value = False - self.rate_tracker = mock.Mock(BandwidthRateTracker) - self.leaky_bucket = LeakyBucket( - self.max_rate, self.time_utils, self.rate_tracker, self.scheduler - ) - - def set_projected_rate(self, rate): - self.rate_tracker.get_projected_rate.return_value = rate - - def set_retry_time(self, retry_time): - self.scheduler.schedule_consumption.return_value = retry_time - - def assert_recorded_consumed_amt(self, expected_amt): - self.assertEqual( - self.rate_tracker.record_consumption_rate.call_args, - mock.call(expected_amt, self.time_utils.time.return_value), - ) - - def assert_was_scheduled(self, amt, token): - self.assertEqual( - self.scheduler.schedule_consumption.call_args, - mock.call(amt, token, amt / (self.max_rate)), - ) - - def assert_nothing_scheduled(self): - self.assertFalse(self.scheduler.schedule_consumption.called) - - def assert_processed_request_token(self, request_token): - self.assertEqual( - self.scheduler.process_scheduled_consumption.call_args, - mock.call(request_token), - ) - - def test_consume_under_max_rate(self): - amt = 1 - self.set_projected_rate(self.max_rate / 2) - self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt) - self.assert_recorded_consumed_amt(amt) - self.assert_nothing_scheduled() - - def test_consume_at_max_rate(self): - amt = 1 - self.set_projected_rate(self.max_rate) - self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt) - self.assert_recorded_consumed_amt(amt) - self.assert_nothing_scheduled() - - def test_consume_over_max_rate(self): - amt = 1 - retry_time = 2.0 - self.set_projected_rate(self.max_rate + 1) - self.set_retry_time(retry_time) - request_token = RequestToken() - try: - self.leaky_bucket.consume(amt, request_token) - self.fail('A RequestExceededException should have been thrown') - except RequestExceededException as e: - self.assertEqual(e.requested_amt, amt) - self.assertEqual(e.retry_time, retry_time) - self.assert_was_scheduled(amt, request_token) - - def test_consume_with_scheduled_retry(self): - amt = 1 - self.set_projected_rate(self.max_rate + 1) - self.scheduler.is_scheduled.return_value = True - request_token = RequestToken() - self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt) - # Nothing new should have been scheduled but the request token - # should have been processed. 
- self.assert_nothing_scheduled() - self.assert_processed_request_token(request_token) - - -class TestConsumptionScheduler(unittest.TestCase): - def setUp(self): - self.scheduler = ConsumptionScheduler() - - def test_schedule_consumption(self): - token = RequestToken() - consume_time = 5 - actual_wait_time = self.scheduler.schedule_consumption( - 1, token, consume_time - ) - self.assertEqual(consume_time, actual_wait_time) - - def test_schedule_consumption_for_multiple_requests(self): - token = RequestToken() - consume_time = 5 - actual_wait_time = self.scheduler.schedule_consumption( - 1, token, consume_time - ) - self.assertEqual(consume_time, actual_wait_time) - - other_consume_time = 3 - other_token = RequestToken() - next_wait_time = self.scheduler.schedule_consumption( - 1, other_token, other_consume_time - ) - - # This wait time should be the previous time plus its desired - # wait time - self.assertEqual(next_wait_time, consume_time + other_consume_time) - - def test_is_scheduled(self): - token = RequestToken() - consume_time = 5 - self.scheduler.schedule_consumption(1, token, consume_time) - self.assertTrue(self.scheduler.is_scheduled(token)) - - def test_is_not_scheduled(self): - self.assertFalse(self.scheduler.is_scheduled(RequestToken())) - - def test_process_scheduled_consumption(self): - token = RequestToken() - consume_time = 5 - self.scheduler.schedule_consumption(1, token, consume_time) - self.scheduler.process_scheduled_consumption(token) - self.assertFalse(self.scheduler.is_scheduled(token)) - different_time = 7 - # The previous consume time should have no affect on the next wait tim - # as it has been completed. - self.assertEqual( - self.scheduler.schedule_consumption(1, token, different_time), - different_time, - ) - - -class TestBandwidthRateTracker(unittest.TestCase): - def setUp(self): - self.alpha = 0.8 - self.rate_tracker = BandwidthRateTracker(self.alpha) - - def test_current_rate_at_initilizations(self): - self.assertEqual(self.rate_tracker.current_rate, 0.0) - - def test_current_rate_after_one_recorded_point(self): - self.rate_tracker.record_consumption_rate(1, 1) - # There is no last time point to do a diff against so return a - # current rate of 0.0 - self.assertEqual(self.rate_tracker.current_rate, 0.0) - - def test_current_rate(self): - self.rate_tracker.record_consumption_rate(1, 1) - self.rate_tracker.record_consumption_rate(1, 2) - self.rate_tracker.record_consumption_rate(1, 3) - self.assertEqual(self.rate_tracker.current_rate, 0.96) - - def test_get_projected_rate_at_initilizations(self): - self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0) - - def test_get_projected_rate(self): - self.rate_tracker.record_consumption_rate(1, 1) - self.rate_tracker.record_consumption_rate(1, 2) - projected_rate = self.rate_tracker.get_projected_rate(1, 3) - self.assertEqual(projected_rate, 0.96) - self.rate_tracker.record_consumption_rate(1, 3) - self.assertEqual(self.rate_tracker.current_rate, projected_rate) - - def test_get_projected_rate_for_same_timestamp(self): - self.rate_tracker.record_consumption_rate(1, 1) - self.assertEqual( - self.rate_tracker.get_projected_rate(1, 1), float('inf') - ) +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import shutil +import tempfile + +from s3transfer.bandwidth import ( + BandwidthLimitedStream, + BandwidthLimiter, + BandwidthRateTracker, + ConsumptionScheduler, + LeakyBucket, + RequestExceededException, + RequestToken, + TimeUtils, +) +from s3transfer.futures import TransferCoordinator +from __tests__ import mock, unittest + + +class FixedIncrementalTickTimeUtils(TimeUtils): + def __init__(self, seconds_per_tick=1.0): + self._count = 0 + self._seconds_per_tick = seconds_per_tick + + def time(self): + current_count = self._count + self._count += self._seconds_per_tick + return current_count + + +class TestTimeUtils(unittest.TestCase): + @mock.patch('time.time') + def test_time(self, mock_time): + mock_return_val = 1 + mock_time.return_value = mock_return_val + time_utils = TimeUtils() + self.assertEqual(time_utils.time(), mock_return_val) + + @mock.patch('time.sleep') + def test_sleep(self, mock_sleep): + time_utils = TimeUtils() + time_utils.sleep(1) + self.assertEqual(mock_sleep.call_args_list, [mock.call(1)]) + + +class BaseBandwidthLimitTest(unittest.TestCase): + def setUp(self): + self.leaky_bucket = mock.Mock(LeakyBucket) + self.time_utils = mock.Mock(TimeUtils) + self.tempdir = tempfile.mkdtemp() + self.content = b'a' * 1024 * 1024 + self.filename = os.path.join(self.tempdir, 'myfile') + with open(self.filename, 'wb') as f: + f.write(self.content) + self.coordinator = TransferCoordinator() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def assert_consume_calls(self, amts): + expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts] + self.assertEqual( + self.leaky_bucket.consume.call_args_list, expected_consume_args + ) + + +class TestBandwidthLimiter(BaseBandwidthLimitTest): + def setUp(self): + super().setUp() + self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket) + + def test_get_bandwidth_limited_stream(self): + with open(self.filename, 'rb') as f: + stream = self.bandwidth_limiter.get_bandwith_limited_stream( + f, self.coordinator + ) + self.assertIsInstance(stream, BandwidthLimitedStream) + self.assertEqual(stream.read(len(self.content)), self.content) + self.assert_consume_calls(amts=[len(self.content)]) + + def test_get_disabled_bandwidth_limited_stream(self): + with open(self.filename, 'rb') as f: + stream = self.bandwidth_limiter.get_bandwith_limited_stream( + f, self.coordinator, enabled=False + ) + self.assertIsInstance(stream, BandwidthLimitedStream) + self.assertEqual(stream.read(len(self.content)), self.content) + self.leaky_bucket.consume.assert_not_called() + + +class TestBandwidthLimitedStream(BaseBandwidthLimitTest): + def setUp(self): + super().setUp() + self.bytes_threshold = 1 + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def get_bandwidth_limited_stream(self, f): + return BandwidthLimitedStream( + f, + self.leaky_bucket, + self.coordinator, + self.time_utils, + self.bytes_threshold, + ) + + def assert_sleep_calls(self, amts): + expected_sleep_args_list = [mock.call(amt) for amt in amts] + self.assertEqual( + self.time_utils.sleep.call_args_list, expected_sleep_args_list + ) + + def get_unique_consume_request_tokens(self): + return { + call_args[0][1] + for call_args in self.leaky_bucket.consume.call_args_list + } + + def test_read(self): + with open(self.filename, 'rb') as f: + stream 
= self.get_bandwidth_limited_stream(f) + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assert_consume_calls(amts=[len(self.content)]) + self.assert_sleep_calls(amts=[]) + + def test_retries_on_request_exceeded(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + retry_time = 1 + amt_requested = len(self.content) + self.leaky_bucket.consume.side_effect = [ + RequestExceededException(amt_requested, retry_time), + len(self.content), + ] + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assert_consume_calls(amts=[amt_requested, amt_requested]) + self.assert_sleep_calls(amts=[retry_time]) + + def test_with_transfer_coordinator_exception(self): + self.coordinator.set_exception(ValueError()) + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + with self.assertRaises(ValueError): + stream.read(len(self.content)) + + def test_read_when_bandwidth_limiting_disabled(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.disable_bandwidth_limiting() + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assertFalse(self.leaky_bucket.consume.called) + + def test_read_toggle_disable_enable_bandwidth_limiting(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.disable_bandwidth_limiting() + data = stream.read(1) + self.assertEqual(self.content[:1], data) + self.assert_consume_calls(amts=[]) + stream.enable_bandwidth_limiting() + data = stream.read(len(self.content) - 1) + self.assertEqual(self.content[1:], data) + self.assert_consume_calls(amts=[len(self.content) - 1]) + + def test_seek(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.seek(1) + self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)]) + + def test_tell(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.tell() + self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()]) + + def test_close(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.close() + self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()]) + + def test_context_manager(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + with stream as stream_handle: + self.assertIs(stream_handle, stream) + self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()]) + + def test_reuses_request_token(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + stream.read(1) + self.assertEqual(len(self.get_unique_consume_request_tokens()), 1) + + def test_request_tokens_unique_per_stream(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + self.assertEqual(len(self.get_unique_consume_request_tokens()), 2) + + def test_call_consume_after_reaching_threshold(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + self.assertEqual(stream.read(1), self.content[1:2]) + self.assert_consume_calls(amts=[2]) + + def 
test_resets_after_reaching_threshold(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(2), self.content[:2]) + self.assert_consume_calls(amts=[2]) + self.assertEqual(stream.read(1), self.content[2:3]) + self.assert_consume_calls(amts=[2]) + + def test_pending_bytes_seen_on_close(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + stream.close() + self.assert_consume_calls(amts=[1]) + + def test_no_bytes_remaining_on(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(2), self.content[:2]) + self.assert_consume_calls(amts=[2]) + stream.close() + # There should have been no more consume() calls made + # as all bytes have been accounted for in the previous + # consume() call. + self.assert_consume_calls(amts=[2]) + + def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + stream.disable_bandwidth_limiting() + stream.close() + self.assert_consume_calls(amts=[]) + + def test_signal_transferring(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.signal_not_transferring() + data = stream.read(1) + self.assertEqual(self.content[:1], data) + self.assert_consume_calls(amts=[]) + stream.signal_transferring() + data = stream.read(len(self.content) - 1) + self.assertEqual(self.content[1:], data) + self.assert_consume_calls(amts=[len(self.content) - 1]) + + +class TestLeakyBucket(unittest.TestCase): + def setUp(self): + self.max_rate = 1 + self.time_now = 1.0 + self.time_utils = mock.Mock(TimeUtils) + self.time_utils.time.return_value = self.time_now + self.scheduler = mock.Mock(ConsumptionScheduler) + self.scheduler.is_scheduled.return_value = False + self.rate_tracker = mock.Mock(BandwidthRateTracker) + self.leaky_bucket = LeakyBucket( + self.max_rate, self.time_utils, self.rate_tracker, self.scheduler + ) + + def set_projected_rate(self, rate): + self.rate_tracker.get_projected_rate.return_value = rate + + def set_retry_time(self, retry_time): + self.scheduler.schedule_consumption.return_value = retry_time + + def assert_recorded_consumed_amt(self, expected_amt): + self.assertEqual( + self.rate_tracker.record_consumption_rate.call_args, + mock.call(expected_amt, self.time_utils.time.return_value), + ) + + def assert_was_scheduled(self, amt, token): + self.assertEqual( + self.scheduler.schedule_consumption.call_args, + mock.call(amt, token, amt / (self.max_rate)), + ) + + def assert_nothing_scheduled(self): + self.assertFalse(self.scheduler.schedule_consumption.called) + + def assert_processed_request_token(self, request_token): + self.assertEqual( + self.scheduler.process_scheduled_consumption.call_args, + mock.call(request_token), + ) + + def test_consume_under_max_rate(self): + amt = 1 + self.set_projected_rate(self.max_rate / 2) + self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt) + self.assert_recorded_consumed_amt(amt) + self.assert_nothing_scheduled() + + def test_consume_at_max_rate(self): + amt = 1 + self.set_projected_rate(self.max_rate) + 
self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt) + self.assert_recorded_consumed_amt(amt) + self.assert_nothing_scheduled() + + def test_consume_over_max_rate(self): + amt = 1 + retry_time = 2.0 + self.set_projected_rate(self.max_rate + 1) + self.set_retry_time(retry_time) + request_token = RequestToken() + try: + self.leaky_bucket.consume(amt, request_token) + self.fail('A RequestExceededException should have been thrown') + except RequestExceededException as e: + self.assertEqual(e.requested_amt, amt) + self.assertEqual(e.retry_time, retry_time) + self.assert_was_scheduled(amt, request_token) + + def test_consume_with_scheduled_retry(self): + amt = 1 + self.set_projected_rate(self.max_rate + 1) + self.scheduler.is_scheduled.return_value = True + request_token = RequestToken() + self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt) + # Nothing new should have been scheduled but the request token + # should have been processed. + self.assert_nothing_scheduled() + self.assert_processed_request_token(request_token) + + +class TestConsumptionScheduler(unittest.TestCase): + def setUp(self): + self.scheduler = ConsumptionScheduler() + + def test_schedule_consumption(self): + token = RequestToken() + consume_time = 5 + actual_wait_time = self.scheduler.schedule_consumption( + 1, token, consume_time + ) + self.assertEqual(consume_time, actual_wait_time) + + def test_schedule_consumption_for_multiple_requests(self): + token = RequestToken() + consume_time = 5 + actual_wait_time = self.scheduler.schedule_consumption( + 1, token, consume_time + ) + self.assertEqual(consume_time, actual_wait_time) + + other_consume_time = 3 + other_token = RequestToken() + next_wait_time = self.scheduler.schedule_consumption( + 1, other_token, other_consume_time + ) + + # This wait time should be the previous time plus its desired + # wait time + self.assertEqual(next_wait_time, consume_time + other_consume_time) + + def test_is_scheduled(self): + token = RequestToken() + consume_time = 5 + self.scheduler.schedule_consumption(1, token, consume_time) + self.assertTrue(self.scheduler.is_scheduled(token)) + + def test_is_not_scheduled(self): + self.assertFalse(self.scheduler.is_scheduled(RequestToken())) + + def test_process_scheduled_consumption(self): + token = RequestToken() + consume_time = 5 + self.scheduler.schedule_consumption(1, token, consume_time) + self.scheduler.process_scheduled_consumption(token) + self.assertFalse(self.scheduler.is_scheduled(token)) + different_time = 7 + # The previous consume time should have no affect on the next wait tim + # as it has been completed. 
+ self.assertEqual( + self.scheduler.schedule_consumption(1, token, different_time), + different_time, + ) + + +class TestBandwidthRateTracker(unittest.TestCase): + def setUp(self): + self.alpha = 0.8 + self.rate_tracker = BandwidthRateTracker(self.alpha) + + def test_current_rate_at_initilizations(self): + self.assertEqual(self.rate_tracker.current_rate, 0.0) + + def test_current_rate_after_one_recorded_point(self): + self.rate_tracker.record_consumption_rate(1, 1) + # There is no last time point to do a diff against so return a + # current rate of 0.0 + self.assertEqual(self.rate_tracker.current_rate, 0.0) + + def test_current_rate(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.rate_tracker.record_consumption_rate(1, 2) + self.rate_tracker.record_consumption_rate(1, 3) + self.assertEqual(self.rate_tracker.current_rate, 0.96) + + def test_get_projected_rate_at_initilizations(self): + self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0) + + def test_get_projected_rate(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.rate_tracker.record_consumption_rate(1, 2) + projected_rate = self.rate_tracker.get_projected_rate(1, 3) + self.assertEqual(projected_rate, 0.96) + self.rate_tracker.record_consumption_rate(1, 3) + self.assertEqual(self.rate_tracker.current_rate, projected_rate) + + def test_get_projected_rate_for_same_timestamp(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.assertEqual( + self.rate_tracker.get_projected_rate(1, 1), float('inf') + ) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_compat.py b/contrib/python/s3transfer/py3/tests/unit/test_compat.py index 8cd9ff1b9d..78fdc25845 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_compat.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_compat.py @@ -1,105 +1,105 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -import os -import shutil -import signal -import tempfile -from io import BytesIO - -from s3transfer.compat import BaseManager, readable, seekable -from __tests__ import skip_if_windows, unittest - - -class ErrorRaisingSeekWrapper: - """An object wrapper that throws an error when seeked on - - :param fileobj: The fileobj that it wraps - :param exception: The exception to raise when seeked on. 
- """ - - def __init__(self, fileobj, exception): - self._fileobj = fileobj - self._exception = exception - - def seek(self, offset, whence=0): - raise self._exception - - def tell(self): - return self._fileobj.tell() - - -class TestSeekable(unittest.TestCase): - def setUp(self): - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'foo') - - def tearDown(self): - shutil.rmtree(self.tempdir) - - def test_seekable_fileobj(self): - with open(self.filename, 'w') as f: - self.assertTrue(seekable(f)) - - def test_non_file_like_obj(self): - # Fails because there is no seekable(), seek(), nor tell() - self.assertFalse(seekable(object())) - - def test_non_seekable_ioerror(self): - # Should return False if IOError is thrown. - with open(self.filename, 'w') as f: - self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, IOError()))) - - def test_non_seekable_oserror(self): - # Should return False if OSError is thrown. - with open(self.filename, 'w') as f: - self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, OSError()))) - - -class TestReadable(unittest.TestCase): - def test_readable_fileobj(self): - with tempfile.TemporaryFile() as f: - self.assertTrue(readable(f)) - - def test_readable_file_like_obj(self): - self.assertTrue(readable(BytesIO())) - - def test_non_file_like_obj(self): - self.assertFalse(readable(object())) - - -class TestBaseManager(unittest.TestCase): - def create_pid_manager(self): - class PIDManager(BaseManager): - pass - - PIDManager.register('getpid', os.getpid) - return PIDManager() - - def get_pid(self, pid_manager): - pid = pid_manager.getpid() - # A proxy object is returned back. The needed value can be acquired - # from the repr and converting that to an integer - return int(str(pid)) - - @skip_if_windows('os.kill() with SIGINT not supported on Windows') - def test_can_provide_signal_handler_initializers_to_start(self): - manager = self.create_pid_manager() - manager.start(signal.signal, (signal.SIGINT, signal.SIG_IGN)) - pid = self.get_pid(manager) - try: - os.kill(pid, signal.SIGINT) - except KeyboardInterrupt: - pass - # Try using the manager after the os.kill on the parent process. The - # manager should not have died and should still be usable. - self.assertEqual(pid, self.get_pid(manager)) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import shutil +import signal +import tempfile +from io import BytesIO + +from s3transfer.compat import BaseManager, readable, seekable +from __tests__ import skip_if_windows, unittest + + +class ErrorRaisingSeekWrapper: + """An object wrapper that throws an error when seeked on + + :param fileobj: The fileobj that it wraps + :param exception: The exception to raise when seeked on. 
+ """ + + def __init__(self, fileobj, exception): + self._fileobj = fileobj + self._exception = exception + + def seek(self, offset, whence=0): + raise self._exception + + def tell(self): + return self._fileobj.tell() + + +class TestSeekable(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'foo') + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_seekable_fileobj(self): + with open(self.filename, 'w') as f: + self.assertTrue(seekable(f)) + + def test_non_file_like_obj(self): + # Fails because there is no seekable(), seek(), nor tell() + self.assertFalse(seekable(object())) + + def test_non_seekable_ioerror(self): + # Should return False if IOError is thrown. + with open(self.filename, 'w') as f: + self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, IOError()))) + + def test_non_seekable_oserror(self): + # Should return False if OSError is thrown. + with open(self.filename, 'w') as f: + self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, OSError()))) + + +class TestReadable(unittest.TestCase): + def test_readable_fileobj(self): + with tempfile.TemporaryFile() as f: + self.assertTrue(readable(f)) + + def test_readable_file_like_obj(self): + self.assertTrue(readable(BytesIO())) + + def test_non_file_like_obj(self): + self.assertFalse(readable(object())) + + +class TestBaseManager(unittest.TestCase): + def create_pid_manager(self): + class PIDManager(BaseManager): + pass + + PIDManager.register('getpid', os.getpid) + return PIDManager() + + def get_pid(self, pid_manager): + pid = pid_manager.getpid() + # A proxy object is returned back. The needed value can be acquired + # from the repr and converting that to an integer + return int(str(pid)) + + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_can_provide_signal_handler_initializers_to_start(self): + manager = self.create_pid_manager() + manager.start(signal.signal, (signal.SIGINT, signal.SIG_IGN)) + pid = self.get_pid(manager) + try: + os.kill(pid, signal.SIGINT) + except KeyboardInterrupt: + pass + # Try using the manager after the os.kill on the parent process. The + # manager should not have died and should still be usable. + self.assertEqual(pid, self.get_pid(manager)) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_copies.py b/contrib/python/s3transfer/py3/tests/unit/test_copies.py index 9d43284382..3681f69b94 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_copies.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_copies.py @@ -1,177 +1,177 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-from s3transfer.copies import CopyObjectTask, CopyPartTask -from __tests__ import BaseTaskTest, RecordingSubscriber - - -class BaseCopyTaskTest(BaseTaskTest): - def setUp(self): - super().setUp() - self.bucket = 'mybucket' - self.key = 'mykey' - self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'} - self.extra_args = {} - self.callbacks = [] - self.size = 5 - - -class TestCopyObjectTask(BaseCopyTaskTest): - def get_copy_task(self, **kwargs): - default_kwargs = { - 'client': self.client, - 'copy_source': self.copy_source, - 'bucket': self.bucket, - 'key': self.key, - 'extra_args': self.extra_args, - 'callbacks': self.callbacks, - 'size': self.size, - } - default_kwargs.update(kwargs) - return self.get_task(CopyObjectTask, main_kwargs=default_kwargs) - - def test_main(self): - self.stubber.add_response( - 'copy_object', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - }, - ) - task = self.get_copy_task() - task() - - self.stubber.assert_no_pending_responses() - - def test_extra_args(self): - self.extra_args['ACL'] = 'private' - self.stubber.add_response( - 'copy_object', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - 'ACL': 'private', - }, - ) - task = self.get_copy_task() - task() - - self.stubber.assert_no_pending_responses() - - def test_callbacks_invoked(self): - subscriber = RecordingSubscriber() - self.callbacks.append(subscriber.on_progress) - self.stubber.add_response( - 'copy_object', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - }, - ) - task = self.get_copy_task() - task() - - self.stubber.assert_no_pending_responses() - self.assertEqual(subscriber.calculate_bytes_seen(), self.size) - - -class TestCopyPartTask(BaseCopyTaskTest): - def setUp(self): - super().setUp() - self.copy_source_range = 'bytes=5-9' - self.extra_args['CopySourceRange'] = self.copy_source_range - self.upload_id = 'myuploadid' - self.part_number = 1 - self.result_etag = 'my-etag' - - def get_copy_task(self, **kwargs): - default_kwargs = { - 'client': self.client, - 'copy_source': self.copy_source, - 'bucket': self.bucket, - 'key': self.key, - 'upload_id': self.upload_id, - 'part_number': self.part_number, - 'extra_args': self.extra_args, - 'callbacks': self.callbacks, - 'size': self.size, - } - default_kwargs.update(kwargs) - return self.get_task(CopyPartTask, main_kwargs=default_kwargs) - - def test_main(self): - self.stubber.add_response( - 'upload_part_copy', - service_response={'CopyPartResult': {'ETag': self.result_etag}}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - 'UploadId': self.upload_id, - 'PartNumber': self.part_number, - 'CopySourceRange': self.copy_source_range, - }, - ) - task = self.get_copy_task() - self.assertEqual( - task(), {'PartNumber': self.part_number, 'ETag': self.result_etag} - ) - self.stubber.assert_no_pending_responses() - - def test_extra_args(self): - self.extra_args['RequestPayer'] = 'requester' - self.stubber.add_response( - 'upload_part_copy', - service_response={'CopyPartResult': {'ETag': self.result_etag}}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - 'UploadId': self.upload_id, - 'PartNumber': self.part_number, - 'CopySourceRange': self.copy_source_range, - 'RequestPayer': 'requester', - }, - ) - task = self.get_copy_task() - 
self.assertEqual( - task(), {'PartNumber': self.part_number, 'ETag': self.result_etag} - ) - self.stubber.assert_no_pending_responses() - - def test_callbacks_invoked(self): - subscriber = RecordingSubscriber() - self.callbacks.append(subscriber.on_progress) - self.stubber.add_response( - 'upload_part_copy', - service_response={'CopyPartResult': {'ETag': self.result_etag}}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'CopySource': self.copy_source, - 'UploadId': self.upload_id, - 'PartNumber': self.part_number, - 'CopySourceRange': self.copy_source_range, - }, - ) - task = self.get_copy_task() - self.assertEqual( - task(), {'PartNumber': self.part_number, 'ETag': self.result_etag} - ) - self.stubber.assert_no_pending_responses() - self.assertEqual(subscriber.calculate_bytes_seen(), self.size) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from s3transfer.copies import CopyObjectTask, CopyPartTask +from __tests__ import BaseTaskTest, RecordingSubscriber + + +class BaseCopyTaskTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.bucket = 'mybucket' + self.key = 'mykey' + self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'} + self.extra_args = {} + self.callbacks = [] + self.size = 5 + + +class TestCopyObjectTask(BaseCopyTaskTest): + def get_copy_task(self, **kwargs): + default_kwargs = { + 'client': self.client, + 'copy_source': self.copy_source, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': self.extra_args, + 'callbacks': self.callbacks, + 'size': self.size, + } + default_kwargs.update(kwargs) + return self.get_task(CopyObjectTask, main_kwargs=default_kwargs) + + def test_main(self): + self.stubber.add_response( + 'copy_object', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + }, + ) + task = self.get_copy_task() + task() + + self.stubber.assert_no_pending_responses() + + def test_extra_args(self): + self.extra_args['ACL'] = 'private' + self.stubber.add_response( + 'copy_object', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + 'ACL': 'private', + }, + ) + task = self.get_copy_task() + task() + + self.stubber.assert_no_pending_responses() + + def test_callbacks_invoked(self): + subscriber = RecordingSubscriber() + self.callbacks.append(subscriber.on_progress) + self.stubber.add_response( + 'copy_object', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + }, + ) + task = self.get_copy_task() + task() + + self.stubber.assert_no_pending_responses() + self.assertEqual(subscriber.calculate_bytes_seen(), self.size) + + +class TestCopyPartTask(BaseCopyTaskTest): + def setUp(self): + super().setUp() + self.copy_source_range = 'bytes=5-9' + self.extra_args['CopySourceRange'] = self.copy_source_range + self.upload_id = 'myuploadid' + self.part_number = 1 + self.result_etag = 'my-etag' + + def 
get_copy_task(self, **kwargs): + default_kwargs = { + 'client': self.client, + 'copy_source': self.copy_source, + 'bucket': self.bucket, + 'key': self.key, + 'upload_id': self.upload_id, + 'part_number': self.part_number, + 'extra_args': self.extra_args, + 'callbacks': self.callbacks, + 'size': self.size, + } + default_kwargs.update(kwargs) + return self.get_task(CopyPartTask, main_kwargs=default_kwargs) + + def test_main(self): + self.stubber.add_response( + 'upload_part_copy', + service_response={'CopyPartResult': {'ETag': self.result_etag}}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + 'UploadId': self.upload_id, + 'PartNumber': self.part_number, + 'CopySourceRange': self.copy_source_range, + }, + ) + task = self.get_copy_task() + self.assertEqual( + task(), {'PartNumber': self.part_number, 'ETag': self.result_etag} + ) + self.stubber.assert_no_pending_responses() + + def test_extra_args(self): + self.extra_args['RequestPayer'] = 'requester' + self.stubber.add_response( + 'upload_part_copy', + service_response={'CopyPartResult': {'ETag': self.result_etag}}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + 'UploadId': self.upload_id, + 'PartNumber': self.part_number, + 'CopySourceRange': self.copy_source_range, + 'RequestPayer': 'requester', + }, + ) + task = self.get_copy_task() + self.assertEqual( + task(), {'PartNumber': self.part_number, 'ETag': self.result_etag} + ) + self.stubber.assert_no_pending_responses() + + def test_callbacks_invoked(self): + subscriber = RecordingSubscriber() + self.callbacks.append(subscriber.on_progress) + self.stubber.add_response( + 'upload_part_copy', + service_response={'CopyPartResult': {'ETag': self.result_etag}}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'CopySource': self.copy_source, + 'UploadId': self.upload_id, + 'PartNumber': self.part_number, + 'CopySourceRange': self.copy_source_range, + }, + ) + task = self.get_copy_task() + self.assertEqual( + task(), {'PartNumber': self.part_number, 'ETag': self.result_etag} + ) + self.stubber.assert_no_pending_responses() + self.assertEqual(subscriber.calculate_bytes_seen(), self.size) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_crt.py b/contrib/python/s3transfer/py3/tests/unit/test_crt.py index 923b0f4c66..8c32668eab 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_crt.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_crt.py @@ -1,173 +1,173 @@ -# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-from botocore.credentials import CredentialResolver, ReadOnlyCredentials -from botocore.session import Session - -from s3transfer.exceptions import TransferNotDoneError -from s3transfer.utils import CallArgs -from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest - -if HAS_CRT: - import awscrt.s3 - - import s3transfer.crt - - -class CustomFutureException(Exception): - pass - - -@requires_crt -class TestBotocoreCRTRequestSerializer(unittest.TestCase): - def setUp(self): - self.region = 'us-west-2' - self.session = Session() - self.session.set_config_variable('region', self.region) - self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( - self.session - ) - self.bucket = "test_bucket" - self.key = "test_key" - self.files = FileCreator() - self.filename = self.files.create_file('myfile', 'my content') - self.expected_path = "/" + self.bucket + "/" + self.key - self.expected_host = "s3.%s.amazonaws.com" % (self.region) - - def tearDown(self): - self.files.remove_all() - - def test_upload_request(self): - callargs = CallArgs( - bucket=self.bucket, - key=self.key, - fileobj=self.filename, - extra_args={}, - subscribers=[], - ) - coordinator = s3transfer.crt.CRTTransferCoordinator() - future = s3transfer.crt.CRTTransferFuture( - s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator - ) - crt_request = self.request_serializer.serialize_http_request( - "put_object", future - ) - self.assertEqual("PUT", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) - self.assertIsNone(crt_request.headers.get("Authorization")) - - def test_download_request(self): - callargs = CallArgs( - bucket=self.bucket, - key=self.key, - fileobj=self.filename, - extra_args={}, - subscribers=[], - ) - coordinator = s3transfer.crt.CRTTransferCoordinator() - future = s3transfer.crt.CRTTransferFuture( - s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator - ) - crt_request = self.request_serializer.serialize_http_request( - "get_object", future - ) - self.assertEqual("GET", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) - self.assertIsNone(crt_request.headers.get("Authorization")) - - def test_delete_request(self): - callargs = CallArgs( - bucket=self.bucket, key=self.key, extra_args={}, subscribers=[] - ) - coordinator = s3transfer.crt.CRTTransferCoordinator() - future = s3transfer.crt.CRTTransferFuture( - s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator - ) - crt_request = self.request_serializer.serialize_http_request( - "delete_object", future - ) - self.assertEqual("DELETE", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) - self.assertIsNone(crt_request.headers.get("Authorization")) - - -@requires_crt -class TestCRTCredentialProviderAdapter(unittest.TestCase): - def setUp(self): - self.botocore_credential_provider = mock.Mock(CredentialResolver) - self.access_key = "access_key" - self.secret_key = "secret_key" - self.token = "token" - self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials( - self.access_key, self.secret_key, self.token - ) - - def _call_adapter_and_check(self, credentails_provider_adapter): - credentials = credentails_provider_adapter() - 
self.assertEqual(credentials.access_key_id, self.access_key) - self.assertEqual(credentials.secret_access_key, self.secret_key) - self.assertEqual(credentials.session_token, self.token) - - def test_fetch_crt_credentials_successfully(self): - credentails_provider_adapter = ( - s3transfer.crt.CRTCredentialProviderAdapter( - self.botocore_credential_provider - ) - ) - self._call_adapter_and_check(credentails_provider_adapter) - - def test_load_credentials_once(self): - credentails_provider_adapter = ( - s3transfer.crt.CRTCredentialProviderAdapter( - self.botocore_credential_provider - ) - ) - called_times = 5 - for i in range(called_times): - self._call_adapter_and_check(credentails_provider_adapter) - # Assert that the load_credentails of botocore credential provider - # will only be called once - self.assertEqual( - self.botocore_credential_provider.load_credentials.call_count, 1 - ) - - -@requires_crt -class TestCRTTransferFuture(unittest.TestCase): - def setUp(self): - self.mock_s3_request = mock.Mock(awscrt.s3.S3RequestType) - self.mock_crt_future = mock.Mock(awscrt.s3.Future) - self.mock_s3_request.finished_future = self.mock_crt_future - self.coordinator = s3transfer.crt.CRTTransferCoordinator() - self.coordinator.set_s3_request(self.mock_s3_request) - self.future = s3transfer.crt.CRTTransferFuture( - coordinator=self.coordinator - ) - - def test_set_exception(self): - self.future.set_exception(CustomFutureException()) - with self.assertRaises(CustomFutureException): - self.future.result() - - def test_set_exception_raises_error_when_not_done(self): - self.mock_crt_future.done.return_value = False - with self.assertRaises(TransferNotDoneError): - self.future.set_exception(CustomFutureException()) - - def test_set_exception_can_override_previous_exception(self): - self.future.set_exception(Exception()) - self.future.set_exception(CustomFutureException()) - with self.assertRaises(CustomFutureException): - self.future.result() +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
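The TestCRTCredentialProviderAdapter cases above pin down one behaviour in particular: botocore's load_credentials is resolved exactly once, and the frozen credentials are reused on every later call. A rough standalone sketch of that caching shape (illustrative only, not the actual s3transfer.crt implementation, which additionally converts the result into awscrt credentials) could look like:

    class CachingCredentialAdapter:
        """Resolve botocore credentials once, then serve the frozen copy."""

        def __init__(self, botocore_credential_provider):
            self._provider = botocore_credential_provider
            self._frozen_credentials = None

        def __call__(self):
            # Only the first call hits load_credentials(); later calls reuse
            # the frozen snapshot, which is what test_load_credentials_once
            # asserts via call_count == 1.
            if self._frozen_credentials is None:
                creds = self._provider.load_credentials()
                self._frozen_credentials = creds.get_frozen_credentials()
            return self._frozen_credentials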
+from botocore.credentials import CredentialResolver, ReadOnlyCredentials +from botocore.session import Session + +from s3transfer.exceptions import TransferNotDoneError +from s3transfer.utils import CallArgs +from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest + +if HAS_CRT: + import awscrt.s3 + + import s3transfer.crt + + +class CustomFutureException(Exception): + pass + + +@requires_crt +class TestBotocoreCRTRequestSerializer(unittest.TestCase): + def setUp(self): + self.region = 'us-west-2' + self.session = Session() + self.session.set_config_variable('region', self.region) + self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( + self.session + ) + self.bucket = "test_bucket" + self.key = "test_key" + self.files = FileCreator() + self.filename = self.files.create_file('myfile', 'my content') + self.expected_path = "/" + self.bucket + "/" + self.key + self.expected_host = "s3.%s.amazonaws.com" % (self.region) + + def tearDown(self): + self.files.remove_all() + + def test_upload_request(self): + callargs = CallArgs( + bucket=self.bucket, + key=self.key, + fileobj=self.filename, + extra_args={}, + subscribers=[], + ) + coordinator = s3transfer.crt.CRTTransferCoordinator() + future = s3transfer.crt.CRTTransferFuture( + s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator + ) + crt_request = self.request_serializer.serialize_http_request( + "put_object", future + ) + self.assertEqual("PUT", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self.assertIsNone(crt_request.headers.get("Authorization")) + + def test_download_request(self): + callargs = CallArgs( + bucket=self.bucket, + key=self.key, + fileobj=self.filename, + extra_args={}, + subscribers=[], + ) + coordinator = s3transfer.crt.CRTTransferCoordinator() + future = s3transfer.crt.CRTTransferFuture( + s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator + ) + crt_request = self.request_serializer.serialize_http_request( + "get_object", future + ) + self.assertEqual("GET", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self.assertIsNone(crt_request.headers.get("Authorization")) + + def test_delete_request(self): + callargs = CallArgs( + bucket=self.bucket, key=self.key, extra_args={}, subscribers=[] + ) + coordinator = s3transfer.crt.CRTTransferCoordinator() + future = s3transfer.crt.CRTTransferFuture( + s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator + ) + crt_request = self.request_serializer.serialize_http_request( + "delete_object", future + ) + self.assertEqual("DELETE", crt_request.method) + self.assertEqual(self.expected_path, crt_request.path) + self.assertEqual(self.expected_host, crt_request.headers.get("host")) + self.assertIsNone(crt_request.headers.get("Authorization")) + + +@requires_crt +class TestCRTCredentialProviderAdapter(unittest.TestCase): + def setUp(self): + self.botocore_credential_provider = mock.Mock(CredentialResolver) + self.access_key = "access_key" + self.secret_key = "secret_key" + self.token = "token" + self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials( + self.access_key, self.secret_key, self.token + ) + + def _call_adapter_and_check(self, credentails_provider_adapter): + credentials = credentails_provider_adapter() + 
self.assertEqual(credentials.access_key_id, self.access_key) + self.assertEqual(credentials.secret_access_key, self.secret_key) + self.assertEqual(credentials.session_token, self.token) + + def test_fetch_crt_credentials_successfully(self): + credentails_provider_adapter = ( + s3transfer.crt.CRTCredentialProviderAdapter( + self.botocore_credential_provider + ) + ) + self._call_adapter_and_check(credentails_provider_adapter) + + def test_load_credentials_once(self): + credentails_provider_adapter = ( + s3transfer.crt.CRTCredentialProviderAdapter( + self.botocore_credential_provider + ) + ) + called_times = 5 + for i in range(called_times): + self._call_adapter_and_check(credentails_provider_adapter) + # Assert that the load_credentails of botocore credential provider + # will only be called once + self.assertEqual( + self.botocore_credential_provider.load_credentials.call_count, 1 + ) + + +@requires_crt +class TestCRTTransferFuture(unittest.TestCase): + def setUp(self): + self.mock_s3_request = mock.Mock(awscrt.s3.S3RequestType) + self.mock_crt_future = mock.Mock(awscrt.s3.Future) + self.mock_s3_request.finished_future = self.mock_crt_future + self.coordinator = s3transfer.crt.CRTTransferCoordinator() + self.coordinator.set_s3_request(self.mock_s3_request) + self.future = s3transfer.crt.CRTTransferFuture( + coordinator=self.coordinator + ) + + def test_set_exception(self): + self.future.set_exception(CustomFutureException()) + with self.assertRaises(CustomFutureException): + self.future.result() + + def test_set_exception_raises_error_when_not_done(self): + self.mock_crt_future.done.return_value = False + with self.assertRaises(TransferNotDoneError): + self.future.set_exception(CustomFutureException()) + + def test_set_exception_can_override_previous_exception(self): + self.future.set_exception(Exception()) + self.future.set_exception(CustomFutureException()) + with self.assertRaises(CustomFutureException): + self.future.result() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_delete.py b/contrib/python/s3transfer/py3/tests/unit/test_delete.py index 5f960adbc5..23b77112f2 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_delete.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_delete.py @@ -1,67 +1,67 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -from s3transfer.delete import DeleteObjectTask +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
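For orientation, the TestBotocoreCRTRequestSerializer cases above all assert the same request shape: path-style addressing against the regional S3 endpoint and no Authorization header, since signing is left to a later stage rather than done at serialization time. With the hypothetical values used in setUp():

    region, bucket, key = 'us-west-2', 'test_bucket', 'test_key'
    expected_path = '/' + bucket + '/' + key        # '/test_bucket/test_key'
    expected_host = 's3.%s.amazonaws.com' % region  # 's3.us-west-2.amazonaws.com'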
+from s3transfer.delete import DeleteObjectTask from __tests__ import BaseTaskTest - - -class TestDeleteObjectTask(BaseTaskTest): - def setUp(self): - super().setUp() - self.bucket = 'mybucket' - self.key = 'mykey' - self.extra_args = {} - self.callbacks = [] - - def get_delete_task(self, **kwargs): - default_kwargs = { - 'client': self.client, - 'bucket': self.bucket, - 'key': self.key, - 'extra_args': self.extra_args, - } - default_kwargs.update(kwargs) - return self.get_task(DeleteObjectTask, main_kwargs=default_kwargs) - - def test_main(self): - self.stubber.add_response( - 'delete_object', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - }, - ) - task = self.get_delete_task() - task() - - self.stubber.assert_no_pending_responses() - - def test_extra_args(self): - self.extra_args['MFA'] = 'mfa-code' - self.extra_args['VersionId'] = '12345' - self.stubber.add_response( - 'delete_object', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - # These extra_args should be injected into the - # expected params for the delete_object call. - 'MFA': 'mfa-code', - 'VersionId': '12345', - }, - ) - task = self.get_delete_task() - task() - - self.stubber.assert_no_pending_responses() + + +class TestDeleteObjectTask(BaseTaskTest): + def setUp(self): + super().setUp() + self.bucket = 'mybucket' + self.key = 'mykey' + self.extra_args = {} + self.callbacks = [] + + def get_delete_task(self, **kwargs): + default_kwargs = { + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': self.extra_args, + } + default_kwargs.update(kwargs) + return self.get_task(DeleteObjectTask, main_kwargs=default_kwargs) + + def test_main(self): + self.stubber.add_response( + 'delete_object', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + }, + ) + task = self.get_delete_task() + task() + + self.stubber.assert_no_pending_responses() + + def test_extra_args(self): + self.extra_args['MFA'] = 'mfa-code' + self.extra_args['VersionId'] = '12345' + self.stubber.add_response( + 'delete_object', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + # These extra_args should be injected into the + # expected params for the delete_object call. + 'MFA': 'mfa-code', + 'VersionId': '12345', + }, + ) + task = self.get_delete_task() + task() + + self.stubber.assert_no_pending_responses() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_download.py b/contrib/python/s3transfer/py3/tests/unit/test_download.py index 46f346347c..2bd095f867 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_download.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_download.py @@ -1,999 +1,999 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
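The DeleteObjectTask tests before this file boundary hinge on one detail: whatever is passed in extra_args must be forwarded verbatim into the delete_object call. Stripped of the test harness, the expectation they encode is roughly (hypothetical values):

    extra_args = {'MFA': 'mfa-code', 'VersionId': '12345'}
    call_kwargs = {'Bucket': 'mybucket', 'Key': 'mykey'}
    call_kwargs.update(extra_args)          # extra_args merged in untouched
    # client.delete_object(**call_kwargs)   # the call the stubber expects to see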
-import copy -import os -import shutil -import tempfile -from io import BytesIO - -from s3transfer.bandwidth import BandwidthLimiter -from s3transfer.compat import SOCKET_ERROR -from s3transfer.download import ( - CompleteDownloadNOOPTask, - DeferQueue, - DownloadChunkIterator, - DownloadFilenameOutputManager, - DownloadNonSeekableOutputManager, - DownloadSeekableOutputManager, - DownloadSpecialFilenameOutputManager, - DownloadSubmissionTask, - GetObjectTask, - ImmediatelyWriteIOGetObjectTask, - IOCloseTask, - IORenameFileTask, - IOStreamingWriteTask, - IOWriteTask, -) -from s3transfer.exceptions import RetriesExceededError -from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor -from s3transfer.utils import CallArgs, OSUtils -from __tests__ import ( - BaseSubmissionTaskTest, - BaseTaskTest, - FileCreator, - NonSeekableWriter, - RecordingExecutor, - StreamWithError, - mock, - unittest, -) - - -class DownloadException(Exception): - pass - - -class WriteCollector: - """A utility to collect information about writes and seeks""" - - def __init__(self): - self._pos = 0 - self.writes = [] - - def seek(self, pos, whence=0): - self._pos = pos - - def write(self, data): - self.writes.append((self._pos, data)) - self._pos += len(data) - - -class AlwaysIndicatesSpecialFileOSUtils(OSUtils): - """OSUtil that always returns True for is_special_file""" - - def is_special_file(self, filename): - return True - - -class CancelledStreamWrapper: - """A wrapper to trigger a cancellation while stream reading - - Forces the transfer coordinator to cancel after a certain amount of reads - :param stream: The underlying stream to read from - :param transfer_coordinator: The coordinator for the transfer - :param num_reads: On which read to signal a cancellation. 0 is the first - read. - """ - - def __init__(self, stream, transfer_coordinator, num_reads=0): - self._stream = stream - self._transfer_coordinator = transfer_coordinator - self._num_reads = num_reads - self._count = 0 - - def read(self, *args, **kwargs): - if self._num_reads == self._count: - self._transfer_coordinator.cancel() - self._stream.read(*args, **kwargs) - self._count += 1 - - -class BaseDownloadOutputManagerTest(BaseTaskTest): - def setUp(self): - super().setUp() - self.osutil = OSUtils() - - # Create a file to write to - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'myfile') - - self.call_args = CallArgs(fileobj=self.filename) - self.future = self.get_transfer_future(self.call_args) - self.io_executor = BoundedExecutor(1000, 1) - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.tempdir) - - -class TestDownloadFilenameOutputManager(BaseDownloadOutputManagerTest): - def setUp(self): - super().setUp() - self.download_output_manager = DownloadFilenameOutputManager( - self.osutil, - self.transfer_coordinator, - io_executor=self.io_executor, - ) - - def test_is_compatible(self): - self.assertTrue( - self.download_output_manager.is_compatible( - self.filename, self.osutil - ) - ) - - def test_get_download_task_tag(self): - self.assertIsNone(self.download_output_manager.get_download_task_tag()) - - def test_get_fileobj_for_io_writes(self): - with self.download_output_manager.get_fileobj_for_io_writes( - self.future - ) as f: - # Ensure it is a file like object returned - self.assertTrue(hasattr(f, 'read')) - self.assertTrue(hasattr(f, 'seek')) - # Make sure the name of the file returned is not the same as the - # final filename as we should be writing to a temporary file. 
- self.assertNotEqual(f.name, self.filename) - - def test_get_final_io_task(self): - ref_contents = b'my_contents' - with self.download_output_manager.get_fileobj_for_io_writes( - self.future - ) as f: - temp_filename = f.name - # Write some data to test that the data gets moved over to the - # final location. - f.write(ref_contents) - final_task = self.download_output_manager.get_final_io_task() - # Make sure it is the appropriate task. - self.assertIsInstance(final_task, IORenameFileTask) - final_task() - # Make sure the temp_file gets removed - self.assertFalse(os.path.exists(temp_filename)) - # Make sure what ever was written to the temp file got moved to - # the final filename - with open(self.filename, 'rb') as f: - self.assertEqual(f.read(), ref_contents) - - def test_can_queue_file_io_task(self): - fileobj = WriteCollector() - self.download_output_manager.queue_file_io_task( - fileobj=fileobj, data='foo', offset=0 - ) - self.download_output_manager.queue_file_io_task( - fileobj=fileobj, data='bar', offset=3 - ) - self.io_executor.shutdown() - self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) - - def test_get_file_io_write_task(self): - fileobj = WriteCollector() - io_write_task = self.download_output_manager.get_io_write_task( - fileobj=fileobj, data='foo', offset=3 - ) - self.assertIsInstance(io_write_task, IOWriteTask) - - io_write_task() - self.assertEqual(fileobj.writes, [(3, 'foo')]) - - -class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest): - def setUp(self): - super().setUp() - self.osutil = AlwaysIndicatesSpecialFileOSUtils() - self.download_output_manager = DownloadSpecialFilenameOutputManager( - self.osutil, - self.transfer_coordinator, - io_executor=self.io_executor, - ) - - def test_is_compatible_for_special_file(self): - self.assertTrue( - self.download_output_manager.is_compatible( - self.filename, AlwaysIndicatesSpecialFileOSUtils() - ) - ) - - def test_is_not_compatible_for_non_special_file(self): - self.assertFalse( - self.download_output_manager.is_compatible( - self.filename, OSUtils() - ) - ) - - def test_get_fileobj_for_io_writes(self): - with self.download_output_manager.get_fileobj_for_io_writes( - self.future - ) as f: - # Ensure it is a file like object returned - self.assertTrue(hasattr(f, 'read')) - # Make sure the name of the file returned is the same as the - # final filename as we should not be writing to a temporary file. 
- self.assertEqual(f.name, self.filename) - - def test_get_final_io_task(self): - self.assertIsInstance( - self.download_output_manager.get_final_io_task(), IOCloseTask - ) - - def test_can_queue_file_io_task(self): - fileobj = WriteCollector() - self.download_output_manager.queue_file_io_task( - fileobj=fileobj, data='foo', offset=0 - ) - self.download_output_manager.queue_file_io_task( - fileobj=fileobj, data='bar', offset=3 - ) - self.io_executor.shutdown() - self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) - - -class TestDownloadSeekableOutputManager(BaseDownloadOutputManagerTest): - def setUp(self): - super().setUp() - self.download_output_manager = DownloadSeekableOutputManager( - self.osutil, - self.transfer_coordinator, - io_executor=self.io_executor, - ) - - # Create a fileobj to write to - self.fileobj = open(self.filename, 'wb') - - self.call_args = CallArgs(fileobj=self.fileobj) - self.future = self.get_transfer_future(self.call_args) - - def tearDown(self): - self.fileobj.close() - super().tearDown() - - def test_is_compatible(self): - self.assertTrue( - self.download_output_manager.is_compatible( - self.fileobj, self.osutil - ) - ) - - def test_is_compatible_bytes_io(self): - self.assertTrue( - self.download_output_manager.is_compatible(BytesIO(), self.osutil) - ) - - def test_not_compatible_for_non_filelike_obj(self): - self.assertFalse( - self.download_output_manager.is_compatible(object(), self.osutil) - ) - - def test_get_download_task_tag(self): - self.assertIsNone(self.download_output_manager.get_download_task_tag()) - - def test_get_fileobj_for_io_writes(self): - self.assertIs( - self.download_output_manager.get_fileobj_for_io_writes( - self.future - ), - self.fileobj, - ) - - def test_get_final_io_task(self): - self.assertIsInstance( - self.download_output_manager.get_final_io_task(), - CompleteDownloadNOOPTask, - ) - - def test_can_queue_file_io_task(self): - fileobj = WriteCollector() - self.download_output_manager.queue_file_io_task( - fileobj=fileobj, data='foo', offset=0 - ) - self.download_output_manager.queue_file_io_task( - fileobj=fileobj, data='bar', offset=3 - ) - self.io_executor.shutdown() - self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) - - def test_get_file_io_write_task(self): - fileobj = WriteCollector() - io_write_task = self.download_output_manager.get_io_write_task( - fileobj=fileobj, data='foo', offset=3 - ) - self.assertIsInstance(io_write_task, IOWriteTask) - - io_write_task() - self.assertEqual(fileobj.writes, [(3, 'foo')]) - - -class TestDownloadNonSeekableOutputManager(BaseDownloadOutputManagerTest): - def setUp(self): - super().setUp() - self.download_output_manager = DownloadNonSeekableOutputManager( - self.osutil, self.transfer_coordinator, io_executor=None - ) - - def test_is_compatible_with_seekable_stream(self): - with open(self.filename, 'wb') as f: - self.assertTrue( - self.download_output_manager.is_compatible(f, self.osutil) - ) - - def test_not_compatible_with_filename(self): - self.assertFalse( - self.download_output_manager.is_compatible( - self.filename, self.osutil - ) - ) - - def test_compatible_with_non_seekable_stream(self): - class NonSeekable: - def write(self, data): - pass - - f = NonSeekable() - self.assertTrue( - self.download_output_manager.is_compatible(f, self.osutil) - ) - - def test_is_compatible_with_bytesio(self): - self.assertTrue( - self.download_output_manager.is_compatible(BytesIO(), self.osutil) - ) - - def test_get_download_task_tag(self): - self.assertIs( - 
self.download_output_manager.get_download_task_tag(), - IN_MEMORY_DOWNLOAD_TAG, - ) - - def test_submit_writes_from_internal_queue(self): - class FakeQueue: - def request_writes(self, offset, data): - return [ - {'offset': 0, 'data': 'foo'}, - {'offset': 3, 'data': 'bar'}, - ] - - q = FakeQueue() - io_executor = BoundedExecutor(1000, 1) - manager = DownloadNonSeekableOutputManager( - self.osutil, - self.transfer_coordinator, - io_executor=io_executor, - defer_queue=q, - ) - fileobj = WriteCollector() - manager.queue_file_io_task(fileobj=fileobj, data='foo', offset=1) - io_executor.shutdown() - self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) - - def test_get_file_io_write_task(self): - fileobj = WriteCollector() - io_write_task = self.download_output_manager.get_io_write_task( - fileobj=fileobj, data='foo', offset=1 - ) - self.assertIsInstance(io_write_task, IOStreamingWriteTask) - - io_write_task() - self.assertEqual(fileobj.writes, [(0, 'foo')]) - - -class TestDownloadSubmissionTask(BaseSubmissionTaskTest): - def setUp(self): - super().setUp() - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'myfile') - - self.bucket = 'mybucket' - self.key = 'mykey' - self.extra_args = {} - self.subscribers = [] - - # Create a stream to read from - self.content = b'my content' - self.stream = BytesIO(self.content) - - # A list to keep track of all of the bodies sent over the wire - # and their order. - - self.call_args = self.get_call_args() - self.transfer_future = self.get_transfer_future(self.call_args) - self.io_executor = BoundedExecutor(1000, 1) - self.submission_main_kwargs = { - 'client': self.client, - 'config': self.config, - 'osutil': self.osutil, - 'request_executor': self.executor, - 'io_executor': self.io_executor, - 'transfer_future': self.transfer_future, - } - self.submission_task = self.get_download_submission_task() - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.tempdir) - - def get_call_args(self, **kwargs): - default_call_args = { - 'fileobj': self.filename, - 'bucket': self.bucket, - 'key': self.key, - 'extra_args': self.extra_args, - 'subscribers': self.subscribers, - } - default_call_args.update(kwargs) - return CallArgs(**default_call_args) - - def wrap_executor_in_recorder(self): - self.executor = RecordingExecutor(self.executor) - self.submission_main_kwargs['request_executor'] = self.executor - - def use_fileobj_in_call_args(self, fileobj): - self.call_args = self.get_call_args(fileobj=fileobj) - self.transfer_future = self.get_transfer_future(self.call_args) - self.submission_main_kwargs['transfer_future'] = self.transfer_future - - def assert_tag_for_get_object(self, tag_value): - submissions_to_compare = self.executor.submissions - if len(submissions_to_compare) > 1: - # If it was ranged get, make sure we do not include the join task. 
- submissions_to_compare = submissions_to_compare[:-1] - for submission in submissions_to_compare: - self.assertEqual(submission['tag'], tag_value) - - def add_head_object_response(self): - self.stubber.add_response( - 'head_object', {'ContentLength': len(self.content)} - ) - - def add_get_responses(self): - chunksize = self.config.multipart_chunksize - for i in range(0, len(self.content), chunksize): - if i + chunksize > len(self.content): - stream = BytesIO(self.content[i:]) - self.stubber.add_response('get_object', {'Body': stream}) - else: - stream = BytesIO(self.content[i : i + chunksize]) - self.stubber.add_response('get_object', {'Body': stream}) - - def configure_for_ranged_get(self): - self.config.multipart_threshold = 1 - self.config.multipart_chunksize = 4 - - def get_download_submission_task(self): - return self.get_task( - DownloadSubmissionTask, main_kwargs=self.submission_main_kwargs - ) - - def wait_and_assert_completed_successfully(self, submission_task): - submission_task() - self.transfer_future.result() - self.stubber.assert_no_pending_responses() - - def test_submits_no_tag_for_get_object_filename(self): - self.wrap_executor_in_recorder() - self.add_head_object_response() - self.add_get_responses() - - self.submission_task = self.get_download_submission_task() - self.wait_and_assert_completed_successfully(self.submission_task) - - # Make sure no tag to limit that task specifically was not associated - # to that task submission. - self.assert_tag_for_get_object(None) - - def test_submits_no_tag_for_ranged_get_filename(self): - self.wrap_executor_in_recorder() - self.configure_for_ranged_get() - self.add_head_object_response() - self.add_get_responses() - - self.submission_task = self.get_download_submission_task() - self.wait_and_assert_completed_successfully(self.submission_task) - - # Make sure no tag to limit that task specifically was not associated - # to that task submission. - self.assert_tag_for_get_object(None) - - def test_submits_no_tag_for_get_object_fileobj(self): - self.wrap_executor_in_recorder() - self.add_head_object_response() - self.add_get_responses() - - with open(self.filename, 'wb') as f: - self.use_fileobj_in_call_args(f) - self.submission_task = self.get_download_submission_task() - self.wait_and_assert_completed_successfully(self.submission_task) - - # Make sure no tag to limit that task specifically was not associated - # to that task submission. - self.assert_tag_for_get_object(None) - - def test_submits_no_tag_for_ranged_get_object_fileobj(self): - self.wrap_executor_in_recorder() - self.configure_for_ranged_get() - self.add_head_object_response() - self.add_get_responses() - - with open(self.filename, 'wb') as f: - self.use_fileobj_in_call_args(f) - self.submission_task = self.get_download_submission_task() - self.wait_and_assert_completed_successfully(self.submission_task) - - # Make sure no tag to limit that task specifically was not associated - # to that task submission. - self.assert_tag_for_get_object(None) - - def tests_submits_tag_for_get_object_nonseekable_fileobj(self): - self.wrap_executor_in_recorder() - self.add_head_object_response() - self.add_get_responses() - - with open(self.filename, 'wb') as f: - self.use_fileobj_in_call_args(NonSeekableWriter(f)) - self.submission_task = self.get_download_submission_task() - self.wait_and_assert_completed_successfully(self.submission_task) - - # Make sure no tag to limit that task specifically was not associated - # to that task submission. 
- self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG) - - def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self): - self.wrap_executor_in_recorder() - self.configure_for_ranged_get() - self.add_head_object_response() - self.add_get_responses() - - with open(self.filename, 'wb') as f: - self.use_fileobj_in_call_args(NonSeekableWriter(f)) - self.submission_task = self.get_download_submission_task() - self.wait_and_assert_completed_successfully(self.submission_task) - - # Make sure no tag to limit that task specifically was not associated - # to that task submission. - self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG) - - -class TestGetObjectTask(BaseTaskTest): - def setUp(self): - super().setUp() - self.bucket = 'mybucket' - self.key = 'mykey' - self.extra_args = {} - self.callbacks = [] - self.max_attempts = 5 - self.io_executor = BoundedExecutor(1000, 1) - self.content = b'my content' - self.stream = BytesIO(self.content) - self.fileobj = WriteCollector() - self.osutil = OSUtils() - self.io_chunksize = 64 * (1024 ** 2) - self.task_cls = GetObjectTask - self.download_output_manager = DownloadSeekableOutputManager( - self.osutil, self.transfer_coordinator, self.io_executor - ) - - def get_download_task(self, **kwargs): - default_kwargs = { - 'client': self.client, - 'bucket': self.bucket, - 'key': self.key, - 'fileobj': self.fileobj, - 'extra_args': self.extra_args, - 'callbacks': self.callbacks, - 'max_attempts': self.max_attempts, - 'download_output_manager': self.download_output_manager, - 'io_chunksize': self.io_chunksize, - } - default_kwargs.update(kwargs) - self.transfer_coordinator.set_status_to_queued() - return self.get_task(self.task_cls, main_kwargs=default_kwargs) - - def assert_io_writes(self, expected_writes): - # Let the io executor process all of the writes before checking - # what writes were sent to it. 
- self.io_executor.shutdown() - self.assertEqual(self.fileobj.writes, expected_writes) - - def test_main(self): - self.stubber.add_response( - 'get_object', - service_response={'Body': self.stream}, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - task = self.get_download_task() - task() - - self.stubber.assert_no_pending_responses() - self.assert_io_writes([(0, self.content)]) - - def test_extra_args(self): - self.stubber.add_response( - 'get_object', - service_response={'Body': self.stream}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'Range': 'bytes=0-', - }, - ) - self.extra_args['Range'] = 'bytes=0-' - task = self.get_download_task() - task() - - self.stubber.assert_no_pending_responses() - self.assert_io_writes([(0, self.content)]) - - def test_control_chunk_size(self): - self.stubber.add_response( - 'get_object', - service_response={'Body': self.stream}, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - task = self.get_download_task(io_chunksize=1) - task() - - self.stubber.assert_no_pending_responses() - expected_contents = [] - for i in range(len(self.content)): - expected_contents.append((i, bytes(self.content[i : i + 1]))) - - self.assert_io_writes(expected_contents) - - def test_start_index(self): - self.stubber.add_response( - 'get_object', - service_response={'Body': self.stream}, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - task = self.get_download_task(start_index=5) - task() - - self.stubber.assert_no_pending_responses() - self.assert_io_writes([(5, self.content)]) - - def test_uses_bandwidth_limiter(self): - bandwidth_limiter = mock.Mock(BandwidthLimiter) - - self.stubber.add_response( - 'get_object', - service_response={'Body': self.stream}, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - task = self.get_download_task(bandwidth_limiter=bandwidth_limiter) - task() - - self.stubber.assert_no_pending_responses() - self.assertEqual( - bandwidth_limiter.get_bandwith_limited_stream.call_args_list, - [mock.call(mock.ANY, self.transfer_coordinator)], - ) - - def test_retries_succeeds(self): - self.stubber.add_response( - 'get_object', - service_response={ - 'Body': StreamWithError(self.stream, SOCKET_ERROR) - }, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - self.stubber.add_response( - 'get_object', - service_response={'Body': self.stream}, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - task = self.get_download_task() - task() - - # Retryable error should have not affected the bytes placed into - # the io queue. 
- self.stubber.assert_no_pending_responses() - self.assert_io_writes([(0, self.content)]) - - def test_retries_failure(self): - for _ in range(self.max_attempts): - self.stubber.add_response( - 'get_object', - service_response={ - 'Body': StreamWithError(self.stream, SOCKET_ERROR) - }, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - - task = self.get_download_task() - task() - self.transfer_coordinator.announce_done() - - # Should have failed out on a RetriesExceededError - with self.assertRaises(RetriesExceededError): - self.transfer_coordinator.result() - self.stubber.assert_no_pending_responses() - - def test_retries_in_middle_of_streaming(self): - # After the first read a retryable error will be thrown - self.stubber.add_response( - 'get_object', - service_response={ - 'Body': StreamWithError( - copy.deepcopy(self.stream), SOCKET_ERROR, 1 - ) - }, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - self.stubber.add_response( - 'get_object', - service_response={'Body': self.stream}, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - task = self.get_download_task(io_chunksize=1) - task() - - self.stubber.assert_no_pending_responses() - expected_contents = [] - # This is the content initially read in before the retry hit on the - # second read() - expected_contents.append((0, bytes(self.content[0:1]))) - - # The rest of the content should be the entire set of data partitioned - # out based on the one byte stream chunk size. Note the second - # element in the list should be a copy of the first element since - # a retryable exception happened in between. - for i in range(len(self.content)): - expected_contents.append((i, bytes(self.content[i : i + 1]))) - self.assert_io_writes(expected_contents) - - def test_cancels_out_of_queueing(self): - self.stubber.add_response( - 'get_object', - service_response={ - 'Body': CancelledStreamWrapper( - self.stream, self.transfer_coordinator - ) - }, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - task = self.get_download_task() - task() - - self.stubber.assert_no_pending_responses() - # Make sure that no contents were added to the queue because the task - # should have been canceled before trying to add the contents to the - # io queue. - self.assert_io_writes([]) - - def test_handles_callback_on_initial_error(self): - # We can't use the stubber for this because we need to raise - # a S3_RETRYABLE_DOWNLOAD_ERRORS, and the stubber only allows - # you to raise a ClientError. - self.client.get_object = mock.Mock(side_effect=SOCKET_ERROR()) - task = self.get_download_task() - task() - self.transfer_coordinator.announce_done() - # Should have failed out on a RetriesExceededError because - # get_object keeps raising a socket error. - with self.assertRaises(RetriesExceededError): - self.transfer_coordinator.result() - - -class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask): - def setUp(self): - super().setUp() - self.task_cls = ImmediatelyWriteIOGetObjectTask - # When data is written out, it should not use the io executor at all - # if it does use the io executor that is a deviation from expected - # behavior as the data should be written immediately to the file - # object once downloaded. 
- self.io_executor = None - self.download_output_manager = DownloadSeekableOutputManager( - self.osutil, self.transfer_coordinator, self.io_executor - ) - - def assert_io_writes(self, expected_writes): - self.assertEqual(self.fileobj.writes, expected_writes) - - -class BaseIOTaskTest(BaseTaskTest): - def setUp(self): - super().setUp() - self.files = FileCreator() - self.osutil = OSUtils() - self.temp_filename = os.path.join(self.files.rootdir, 'mytempfile') - self.final_filename = os.path.join(self.files.rootdir, 'myfile') - - def tearDown(self): - super().tearDown() - self.files.remove_all() - - -class TestIOStreamingWriteTask(BaseIOTaskTest): - def test_main(self): - with open(self.temp_filename, 'wb') as f: - task = self.get_task( - IOStreamingWriteTask, - main_kwargs={'fileobj': f, 'data': b'foobar'}, - ) - task() - task2 = self.get_task( - IOStreamingWriteTask, - main_kwargs={'fileobj': f, 'data': b'baz'}, - ) - task2() - with open(self.temp_filename, 'rb') as f: - # We should just have written to the file in the order - # the tasks were executed. - self.assertEqual(f.read(), b'foobarbaz') - - -class TestIOWriteTask(BaseIOTaskTest): - def test_main(self): - with open(self.temp_filename, 'wb') as f: - # Write once to the file - task = self.get_task( - IOWriteTask, - main_kwargs={'fileobj': f, 'data': b'foo', 'offset': 0}, - ) - task() - - # Write again to the file - task = self.get_task( - IOWriteTask, - main_kwargs={'fileobj': f, 'data': b'bar', 'offset': 3}, - ) - task() - - with open(self.temp_filename, 'rb') as f: - self.assertEqual(f.read(), b'foobar') - - -class TestIORenameFileTask(BaseIOTaskTest): - def test_main(self): - with open(self.temp_filename, 'wb') as f: - task = self.get_task( - IORenameFileTask, - main_kwargs={ - 'fileobj': f, - 'final_filename': self.final_filename, - 'osutil': self.osutil, - }, - ) - task() - self.assertTrue(os.path.exists(self.final_filename)) - self.assertFalse(os.path.exists(self.temp_filename)) - - -class TestIOCloseTask(BaseIOTaskTest): - def test_main(self): - with open(self.temp_filename, 'w') as f: - task = self.get_task(IOCloseTask, main_kwargs={'fileobj': f}) - task() - self.assertTrue(f.closed) - - -class TestDownloadChunkIterator(unittest.TestCase): - def test_iter(self): - content = b'my content' - body = BytesIO(content) - ref_chunks = [] - for chunk in DownloadChunkIterator(body, len(content)): - ref_chunks.append(chunk) - self.assertEqual(ref_chunks, [b'my content']) - - def test_iter_chunksize(self): - content = b'1234' - body = BytesIO(content) - ref_chunks = [] - for chunk in DownloadChunkIterator(body, 3): - ref_chunks.append(chunk) - self.assertEqual(ref_chunks, [b'123', b'4']) - - def test_empty_content(self): - body = BytesIO(b'') - ref_chunks = [] - for chunk in DownloadChunkIterator(body, 3): - ref_chunks.append(chunk) - self.assertEqual(ref_chunks, [b'']) - - -class TestDeferQueue(unittest.TestCase): - def setUp(self): - self.q = DeferQueue() - - def test_no_writes_when_not_lowest_block(self): - writes = self.q.request_writes(offset=1, data='bar') - self.assertEqual(writes, []) - - def test_writes_returned_in_order(self): - self.assertEqual(self.q.request_writes(offset=3, data='d'), []) - self.assertEqual(self.q.request_writes(offset=2, data='c'), []) - self.assertEqual(self.q.request_writes(offset=1, data='b'), []) - - # Everything at this point has been deferred, but as soon as we - # send offset=0, that will unlock offsets 0-3. 
- writes = self.q.request_writes(offset=0, data='a') - self.assertEqual( - writes, - [ - {'offset': 0, 'data': 'a'}, - {'offset': 1, 'data': 'b'}, - {'offset': 2, 'data': 'c'}, - {'offset': 3, 'data': 'd'}, - ], - ) - - def test_unlocks_partial_range(self): - self.assertEqual(self.q.request_writes(offset=5, data='f'), []) - self.assertEqual(self.q.request_writes(offset=1, data='b'), []) - - # offset=0 unlocks 0-1, but offset=5 still needs to see 2-4 first. - writes = self.q.request_writes(offset=0, data='a') - self.assertEqual( - writes, - [ - {'offset': 0, 'data': 'a'}, - {'offset': 1, 'data': 'b'}, - ], - ) - - def test_data_can_be_any_size(self): - self.q.request_writes(offset=5, data='hello world') - writes = self.q.request_writes(offset=0, data='abcde') - self.assertEqual( - writes, - [ - {'offset': 0, 'data': 'abcde'}, - {'offset': 5, 'data': 'hello world'}, - ], - ) - - def test_data_queued_in_order(self): - # This immediately gets returned because offset=0 is the - # next range we're waiting on. - writes = self.q.request_writes(offset=0, data='hello world') - self.assertEqual(writes, [{'offset': 0, 'data': 'hello world'}]) - # Same thing here but with offset - writes = self.q.request_writes(offset=11, data='hello again') - self.assertEqual(writes, [{'offset': 11, 'data': 'hello again'}]) - - def test_writes_below_min_offset_are_ignored(self): - self.q.request_writes(offset=0, data='a') - self.q.request_writes(offset=1, data='b') - self.q.request_writes(offset=2, data='c') - - # At this point we're expecting offset=3, so if a write - # comes in below 3, we ignore it. - self.assertEqual(self.q.request_writes(offset=0, data='a'), []) - self.assertEqual(self.q.request_writes(offset=1, data='b'), []) - - self.assertEqual( - self.q.request_writes(offset=3, data='d'), - [{'offset': 3, 'data': 'd'}], - ) - - def test_duplicate_writes_are_ignored(self): - self.q.request_writes(offset=2, data='c') - self.q.request_writes(offset=1, data='b') - - # We're still waiting for offset=0, but if - # a duplicate write comes in for offset=2/offset=1 - # it's ignored. This gives "first one wins" behavior. - self.assertEqual(self.q.request_writes(offset=2, data='X'), []) - self.assertEqual(self.q.request_writes(offset=1, data='Y'), []) - - self.assertEqual( - self.q.request_writes(offset=0, data='a'), - [ - {'offset': 0, 'data': 'a'}, - # Note we're seeing 'b' 'c', and not 'X', 'Y'. - {'offset': 1, 'data': 'b'}, - {'offset': 2, 'data': 'c'}, - ], - ) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
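The DeferQueue tests above describe a small piece of bookkeeping: nothing is written until the lowest outstanding offset arrives, releases then proceed in order, and stale or duplicate offsets are dropped ("first one wins"). A minimal sketch of that behaviour, written here purely as an illustration and not taken from s3transfer.download, could look like:

    class MiniDeferQueue:
        """Buffer out-of-order writes until the next expected offset arrives."""

        def __init__(self):
            self._next_offset = 0
            self._pending = {}

        def request_writes(self, offset, data):
            if offset < self._next_offset or offset in self._pending:
                # Already written, or a duplicate of a queued offset: ignore.
                return []
            self._pending[offset] = data
            ready = []
            # Release every contiguous chunk starting at the expected offset.
            while self._next_offset in self._pending:
                chunk = self._pending.pop(self._next_offset)
                ready.append({'offset': self._next_offset, 'data': chunk})
                self._next_offset += len(chunk)
            return ready

Under these assumptions, request_writes(offset=1, data='b') returns [] until request_writes(offset=0, data='a') arrives, which then releases both chunks in order, matching the ordering and "first one wins" cases exercised above.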
+import copy +import os +import shutil +import tempfile +from io import BytesIO + +from s3transfer.bandwidth import BandwidthLimiter +from s3transfer.compat import SOCKET_ERROR +from s3transfer.download import ( + CompleteDownloadNOOPTask, + DeferQueue, + DownloadChunkIterator, + DownloadFilenameOutputManager, + DownloadNonSeekableOutputManager, + DownloadSeekableOutputManager, + DownloadSpecialFilenameOutputManager, + DownloadSubmissionTask, + GetObjectTask, + ImmediatelyWriteIOGetObjectTask, + IOCloseTask, + IORenameFileTask, + IOStreamingWriteTask, + IOWriteTask, +) +from s3transfer.exceptions import RetriesExceededError +from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor +from s3transfer.utils import CallArgs, OSUtils +from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + FileCreator, + NonSeekableWriter, + RecordingExecutor, + StreamWithError, + mock, + unittest, +) + + +class DownloadException(Exception): + pass + + +class WriteCollector: + """A utility to collect information about writes and seeks""" + + def __init__(self): + self._pos = 0 + self.writes = [] + + def seek(self, pos, whence=0): + self._pos = pos + + def write(self, data): + self.writes.append((self._pos, data)) + self._pos += len(data) + + +class AlwaysIndicatesSpecialFileOSUtils(OSUtils): + """OSUtil that always returns True for is_special_file""" + + def is_special_file(self, filename): + return True + + +class CancelledStreamWrapper: + """A wrapper to trigger a cancellation while stream reading + + Forces the transfer coordinator to cancel after a certain amount of reads + :param stream: The underlying stream to read from + :param transfer_coordinator: The coordinator for the transfer + :param num_reads: On which read to signal a cancellation. 0 is the first + read. + """ + + def __init__(self, stream, transfer_coordinator, num_reads=0): + self._stream = stream + self._transfer_coordinator = transfer_coordinator + self._num_reads = num_reads + self._count = 0 + + def read(self, *args, **kwargs): + if self._num_reads == self._count: + self._transfer_coordinator.cancel() + self._stream.read(*args, **kwargs) + self._count += 1 + + +class BaseDownloadOutputManagerTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.osutil = OSUtils() + + # Create a file to write to + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + + self.call_args = CallArgs(fileobj=self.filename) + self.future = self.get_transfer_future(self.call_args) + self.io_executor = BoundedExecutor(1000, 1) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + + +class TestDownloadFilenameOutputManager(BaseDownloadOutputManagerTest): + def setUp(self): + super().setUp() + self.download_output_manager = DownloadFilenameOutputManager( + self.osutil, + self.transfer_coordinator, + io_executor=self.io_executor, + ) + + def test_is_compatible(self): + self.assertTrue( + self.download_output_manager.is_compatible( + self.filename, self.osutil + ) + ) + + def test_get_download_task_tag(self): + self.assertIsNone(self.download_output_manager.get_download_task_tag()) + + def test_get_fileobj_for_io_writes(self): + with self.download_output_manager.get_fileobj_for_io_writes( + self.future + ) as f: + # Ensure it is a file like object returned + self.assertTrue(hasattr(f, 'read')) + self.assertTrue(hasattr(f, 'seek')) + # Make sure the name of the file returned is not the same as the + # final filename as we should be writing to a temporary file. 
+ self.assertNotEqual(f.name, self.filename) + + def test_get_final_io_task(self): + ref_contents = b'my_contents' + with self.download_output_manager.get_fileobj_for_io_writes( + self.future + ) as f: + temp_filename = f.name + # Write some data to test that the data gets moved over to the + # final location. + f.write(ref_contents) + final_task = self.download_output_manager.get_final_io_task() + # Make sure it is the appropriate task. + self.assertIsInstance(final_task, IORenameFileTask) + final_task() + # Make sure the temp_file gets removed + self.assertFalse(os.path.exists(temp_filename)) + # Make sure what ever was written to the temp file got moved to + # the final filename + with open(self.filename, 'rb') as f: + self.assertEqual(f.read(), ref_contents) + + def test_can_queue_file_io_task(self): + fileobj = WriteCollector() + self.download_output_manager.queue_file_io_task( + fileobj=fileobj, data='foo', offset=0 + ) + self.download_output_manager.queue_file_io_task( + fileobj=fileobj, data='bar', offset=3 + ) + self.io_executor.shutdown() + self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) + + def test_get_file_io_write_task(self): + fileobj = WriteCollector() + io_write_task = self.download_output_manager.get_io_write_task( + fileobj=fileobj, data='foo', offset=3 + ) + self.assertIsInstance(io_write_task, IOWriteTask) + + io_write_task() + self.assertEqual(fileobj.writes, [(3, 'foo')]) + + +class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest): + def setUp(self): + super().setUp() + self.osutil = AlwaysIndicatesSpecialFileOSUtils() + self.download_output_manager = DownloadSpecialFilenameOutputManager( + self.osutil, + self.transfer_coordinator, + io_executor=self.io_executor, + ) + + def test_is_compatible_for_special_file(self): + self.assertTrue( + self.download_output_manager.is_compatible( + self.filename, AlwaysIndicatesSpecialFileOSUtils() + ) + ) + + def test_is_not_compatible_for_non_special_file(self): + self.assertFalse( + self.download_output_manager.is_compatible( + self.filename, OSUtils() + ) + ) + + def test_get_fileobj_for_io_writes(self): + with self.download_output_manager.get_fileobj_for_io_writes( + self.future + ) as f: + # Ensure it is a file like object returned + self.assertTrue(hasattr(f, 'read')) + # Make sure the name of the file returned is the same as the + # final filename as we should not be writing to a temporary file. 
+ self.assertEqual(f.name, self.filename) + + def test_get_final_io_task(self): + self.assertIsInstance( + self.download_output_manager.get_final_io_task(), IOCloseTask + ) + + def test_can_queue_file_io_task(self): + fileobj = WriteCollector() + self.download_output_manager.queue_file_io_task( + fileobj=fileobj, data='foo', offset=0 + ) + self.download_output_manager.queue_file_io_task( + fileobj=fileobj, data='bar', offset=3 + ) + self.io_executor.shutdown() + self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) + + +class TestDownloadSeekableOutputManager(BaseDownloadOutputManagerTest): + def setUp(self): + super().setUp() + self.download_output_manager = DownloadSeekableOutputManager( + self.osutil, + self.transfer_coordinator, + io_executor=self.io_executor, + ) + + # Create a fileobj to write to + self.fileobj = open(self.filename, 'wb') + + self.call_args = CallArgs(fileobj=self.fileobj) + self.future = self.get_transfer_future(self.call_args) + + def tearDown(self): + self.fileobj.close() + super().tearDown() + + def test_is_compatible(self): + self.assertTrue( + self.download_output_manager.is_compatible( + self.fileobj, self.osutil + ) + ) + + def test_is_compatible_bytes_io(self): + self.assertTrue( + self.download_output_manager.is_compatible(BytesIO(), self.osutil) + ) + + def test_not_compatible_for_non_filelike_obj(self): + self.assertFalse( + self.download_output_manager.is_compatible(object(), self.osutil) + ) + + def test_get_download_task_tag(self): + self.assertIsNone(self.download_output_manager.get_download_task_tag()) + + def test_get_fileobj_for_io_writes(self): + self.assertIs( + self.download_output_manager.get_fileobj_for_io_writes( + self.future + ), + self.fileobj, + ) + + def test_get_final_io_task(self): + self.assertIsInstance( + self.download_output_manager.get_final_io_task(), + CompleteDownloadNOOPTask, + ) + + def test_can_queue_file_io_task(self): + fileobj = WriteCollector() + self.download_output_manager.queue_file_io_task( + fileobj=fileobj, data='foo', offset=0 + ) + self.download_output_manager.queue_file_io_task( + fileobj=fileobj, data='bar', offset=3 + ) + self.io_executor.shutdown() + self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) + + def test_get_file_io_write_task(self): + fileobj = WriteCollector() + io_write_task = self.download_output_manager.get_io_write_task( + fileobj=fileobj, data='foo', offset=3 + ) + self.assertIsInstance(io_write_task, IOWriteTask) + + io_write_task() + self.assertEqual(fileobj.writes, [(3, 'foo')]) + + +class TestDownloadNonSeekableOutputManager(BaseDownloadOutputManagerTest): + def setUp(self): + super().setUp() + self.download_output_manager = DownloadNonSeekableOutputManager( + self.osutil, self.transfer_coordinator, io_executor=None + ) + + def test_is_compatible_with_seekable_stream(self): + with open(self.filename, 'wb') as f: + self.assertTrue( + self.download_output_manager.is_compatible(f, self.osutil) + ) + + def test_not_compatible_with_filename(self): + self.assertFalse( + self.download_output_manager.is_compatible( + self.filename, self.osutil + ) + ) + + def test_compatible_with_non_seekable_stream(self): + class NonSeekable: + def write(self, data): + pass + + f = NonSeekable() + self.assertTrue( + self.download_output_manager.is_compatible(f, self.osutil) + ) + + def test_is_compatible_with_bytesio(self): + self.assertTrue( + self.download_output_manager.is_compatible(BytesIO(), self.osutil) + ) + + def test_get_download_task_tag(self): + self.assertIs( + 
self.download_output_manager.get_download_task_tag(), + IN_MEMORY_DOWNLOAD_TAG, + ) + + def test_submit_writes_from_internal_queue(self): + class FakeQueue: + def request_writes(self, offset, data): + return [ + {'offset': 0, 'data': 'foo'}, + {'offset': 3, 'data': 'bar'}, + ] + + q = FakeQueue() + io_executor = BoundedExecutor(1000, 1) + manager = DownloadNonSeekableOutputManager( + self.osutil, + self.transfer_coordinator, + io_executor=io_executor, + defer_queue=q, + ) + fileobj = WriteCollector() + manager.queue_file_io_task(fileobj=fileobj, data='foo', offset=1) + io_executor.shutdown() + self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')]) + + def test_get_file_io_write_task(self): + fileobj = WriteCollector() + io_write_task = self.download_output_manager.get_io_write_task( + fileobj=fileobj, data='foo', offset=1 + ) + self.assertIsInstance(io_write_task, IOStreamingWriteTask) + + io_write_task() + self.assertEqual(fileobj.writes, [(0, 'foo')]) + + +class TestDownloadSubmissionTask(BaseSubmissionTaskTest): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + + self.bucket = 'mybucket' + self.key = 'mykey' + self.extra_args = {} + self.subscribers = [] + + # Create a stream to read from + self.content = b'my content' + self.stream = BytesIO(self.content) + + # A list to keep track of all of the bodies sent over the wire + # and their order. + + self.call_args = self.get_call_args() + self.transfer_future = self.get_transfer_future(self.call_args) + self.io_executor = BoundedExecutor(1000, 1) + self.submission_main_kwargs = { + 'client': self.client, + 'config': self.config, + 'osutil': self.osutil, + 'request_executor': self.executor, + 'io_executor': self.io_executor, + 'transfer_future': self.transfer_future, + } + self.submission_task = self.get_download_submission_task() + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + + def get_call_args(self, **kwargs): + default_call_args = { + 'fileobj': self.filename, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': self.extra_args, + 'subscribers': self.subscribers, + } + default_call_args.update(kwargs) + return CallArgs(**default_call_args) + + def wrap_executor_in_recorder(self): + self.executor = RecordingExecutor(self.executor) + self.submission_main_kwargs['request_executor'] = self.executor + + def use_fileobj_in_call_args(self, fileobj): + self.call_args = self.get_call_args(fileobj=fileobj) + self.transfer_future = self.get_transfer_future(self.call_args) + self.submission_main_kwargs['transfer_future'] = self.transfer_future + + def assert_tag_for_get_object(self, tag_value): + submissions_to_compare = self.executor.submissions + if len(submissions_to_compare) > 1: + # If it was ranged get, make sure we do not include the join task. 
+ submissions_to_compare = submissions_to_compare[:-1] + for submission in submissions_to_compare: + self.assertEqual(submission['tag'], tag_value) + + def add_head_object_response(self): + self.stubber.add_response( + 'head_object', {'ContentLength': len(self.content)} + ) + + def add_get_responses(self): + chunksize = self.config.multipart_chunksize + for i in range(0, len(self.content), chunksize): + if i + chunksize > len(self.content): + stream = BytesIO(self.content[i:]) + self.stubber.add_response('get_object', {'Body': stream}) + else: + stream = BytesIO(self.content[i : i + chunksize]) + self.stubber.add_response('get_object', {'Body': stream}) + + def configure_for_ranged_get(self): + self.config.multipart_threshold = 1 + self.config.multipart_chunksize = 4 + + def get_download_submission_task(self): + return self.get_task( + DownloadSubmissionTask, main_kwargs=self.submission_main_kwargs + ) + + def wait_and_assert_completed_successfully(self, submission_task): + submission_task() + self.transfer_future.result() + self.stubber.assert_no_pending_responses() + + def test_submits_no_tag_for_get_object_filename(self): + self.wrap_executor_in_recorder() + self.add_head_object_response() + self.add_get_responses() + + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + # Make sure no tag to limit that task specifically was not associated + # to that task submission. + self.assert_tag_for_get_object(None) + + def test_submits_no_tag_for_ranged_get_filename(self): + self.wrap_executor_in_recorder() + self.configure_for_ranged_get() + self.add_head_object_response() + self.add_get_responses() + + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + # Make sure no tag to limit that task specifically was not associated + # to that task submission. + self.assert_tag_for_get_object(None) + + def test_submits_no_tag_for_get_object_fileobj(self): + self.wrap_executor_in_recorder() + self.add_head_object_response() + self.add_get_responses() + + with open(self.filename, 'wb') as f: + self.use_fileobj_in_call_args(f) + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + # Make sure no tag to limit that task specifically was not associated + # to that task submission. + self.assert_tag_for_get_object(None) + + def test_submits_no_tag_for_ranged_get_object_fileobj(self): + self.wrap_executor_in_recorder() + self.configure_for_ranged_get() + self.add_head_object_response() + self.add_get_responses() + + with open(self.filename, 'wb') as f: + self.use_fileobj_in_call_args(f) + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + # Make sure no tag to limit that task specifically was not associated + # to that task submission. + self.assert_tag_for_get_object(None) + + def tests_submits_tag_for_get_object_nonseekable_fileobj(self): + self.wrap_executor_in_recorder() + self.add_head_object_response() + self.add_get_responses() + + with open(self.filename, 'wb') as f: + self.use_fileobj_in_call_args(NonSeekableWriter(f)) + self.submission_task = self.get_download_submission_task() + self.wait_and_assert_completed_successfully(self.submission_task) + + # Make sure no tag to limit that task specifically was not associated + # to that task submission. 
+ self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+ def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(NonSeekableWriter(f))
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+ # Make sure the in-memory download tag was associated with that
+ # task submission.
+ self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+
+class TestGetObjectTask(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.callbacks = []
+ self.max_attempts = 5
+ self.io_executor = BoundedExecutor(1000, 1)
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+ self.fileobj = WriteCollector()
+ self.osutil = OSUtils()
+ self.io_chunksize = 64 * (1024 ** 2)
+ self.task_cls = GetObjectTask
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, self.io_executor
+ )
+
+ def get_download_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'fileobj': self.fileobj,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'max_attempts': self.max_attempts,
+ 'download_output_manager': self.download_output_manager,
+ 'io_chunksize': self.io_chunksize,
+ }
+ default_kwargs.update(kwargs)
+ self.transfer_coordinator.set_status_to_queued()
+ return self.get_task(self.task_cls, main_kwargs=default_kwargs)
+
+ def assert_io_writes(self, expected_writes):
+ # Let the io executor process all of the writes before checking
+ # what writes were sent to it.
+ self.io_executor.shutdown() + self.assertEqual(self.fileobj.writes, expected_writes) + + def test_main(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task() + task() + + self.stubber.assert_no_pending_responses() + self.assert_io_writes([(0, self.content)]) + + def test_extra_args(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'Range': 'bytes=0-', + }, + ) + self.extra_args['Range'] = 'bytes=0-' + task = self.get_download_task() + task() + + self.stubber.assert_no_pending_responses() + self.assert_io_writes([(0, self.content)]) + + def test_control_chunk_size(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task(io_chunksize=1) + task() + + self.stubber.assert_no_pending_responses() + expected_contents = [] + for i in range(len(self.content)): + expected_contents.append((i, bytes(self.content[i : i + 1]))) + + self.assert_io_writes(expected_contents) + + def test_start_index(self): + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task(start_index=5) + task() + + self.stubber.assert_no_pending_responses() + self.assert_io_writes([(5, self.content)]) + + def test_uses_bandwidth_limiter(self): + bandwidth_limiter = mock.Mock(BandwidthLimiter) + + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task(bandwidth_limiter=bandwidth_limiter) + task() + + self.stubber.assert_no_pending_responses() + self.assertEqual( + bandwidth_limiter.get_bandwith_limited_stream.call_args_list, + [mock.call(mock.ANY, self.transfer_coordinator)], + ) + + def test_retries_succeeds(self): + self.stubber.add_response( + 'get_object', + service_response={ + 'Body': StreamWithError(self.stream, SOCKET_ERROR) + }, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task() + task() + + # Retryable error should have not affected the bytes placed into + # the io queue. 
+ self.stubber.assert_no_pending_responses() + self.assert_io_writes([(0, self.content)]) + + def test_retries_failure(self): + for _ in range(self.max_attempts): + self.stubber.add_response( + 'get_object', + service_response={ + 'Body': StreamWithError(self.stream, SOCKET_ERROR) + }, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + + task = self.get_download_task() + task() + self.transfer_coordinator.announce_done() + + # Should have failed out on a RetriesExceededError + with self.assertRaises(RetriesExceededError): + self.transfer_coordinator.result() + self.stubber.assert_no_pending_responses() + + def test_retries_in_middle_of_streaming(self): + # After the first read a retryable error will be thrown + self.stubber.add_response( + 'get_object', + service_response={ + 'Body': StreamWithError( + copy.deepcopy(self.stream), SOCKET_ERROR, 1 + ) + }, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + self.stubber.add_response( + 'get_object', + service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task(io_chunksize=1) + task() + + self.stubber.assert_no_pending_responses() + expected_contents = [] + # This is the content initially read in before the retry hit on the + # second read() + expected_contents.append((0, bytes(self.content[0:1]))) + + # The rest of the content should be the entire set of data partitioned + # out based on the one byte stream chunk size. Note the second + # element in the list should be a copy of the first element since + # a retryable exception happened in between. + for i in range(len(self.content)): + expected_contents.append((i, bytes(self.content[i : i + 1]))) + self.assert_io_writes(expected_contents) + + def test_cancels_out_of_queueing(self): + self.stubber.add_response( + 'get_object', + service_response={ + 'Body': CancelledStreamWrapper( + self.stream, self.transfer_coordinator + ) + }, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + task = self.get_download_task() + task() + + self.stubber.assert_no_pending_responses() + # Make sure that no contents were added to the queue because the task + # should have been canceled before trying to add the contents to the + # io queue. + self.assert_io_writes([]) + + def test_handles_callback_on_initial_error(self): + # We can't use the stubber for this because we need to raise + # a S3_RETRYABLE_DOWNLOAD_ERRORS, and the stubber only allows + # you to raise a ClientError. + self.client.get_object = mock.Mock(side_effect=SOCKET_ERROR()) + task = self.get_download_task() + task() + self.transfer_coordinator.announce_done() + # Should have failed out on a RetriesExceededError because + # get_object keeps raising a socket error. + with self.assertRaises(RetriesExceededError): + self.transfer_coordinator.result() + + +class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask): + def setUp(self): + super().setUp() + self.task_cls = ImmediatelyWriteIOGetObjectTask + # When data is written out, it should not use the io executor at all + # if it does use the io executor that is a deviation from expected + # behavior as the data should be written immediately to the file + # object once downloaded. 
+ self.io_executor = None + self.download_output_manager = DownloadSeekableOutputManager( + self.osutil, self.transfer_coordinator, self.io_executor + ) + + def assert_io_writes(self, expected_writes): + self.assertEqual(self.fileobj.writes, expected_writes) + + +class BaseIOTaskTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.files = FileCreator() + self.osutil = OSUtils() + self.temp_filename = os.path.join(self.files.rootdir, 'mytempfile') + self.final_filename = os.path.join(self.files.rootdir, 'myfile') + + def tearDown(self): + super().tearDown() + self.files.remove_all() + + +class TestIOStreamingWriteTask(BaseIOTaskTest): + def test_main(self): + with open(self.temp_filename, 'wb') as f: + task = self.get_task( + IOStreamingWriteTask, + main_kwargs={'fileobj': f, 'data': b'foobar'}, + ) + task() + task2 = self.get_task( + IOStreamingWriteTask, + main_kwargs={'fileobj': f, 'data': b'baz'}, + ) + task2() + with open(self.temp_filename, 'rb') as f: + # We should just have written to the file in the order + # the tasks were executed. + self.assertEqual(f.read(), b'foobarbaz') + + +class TestIOWriteTask(BaseIOTaskTest): + def test_main(self): + with open(self.temp_filename, 'wb') as f: + # Write once to the file + task = self.get_task( + IOWriteTask, + main_kwargs={'fileobj': f, 'data': b'foo', 'offset': 0}, + ) + task() + + # Write again to the file + task = self.get_task( + IOWriteTask, + main_kwargs={'fileobj': f, 'data': b'bar', 'offset': 3}, + ) + task() + + with open(self.temp_filename, 'rb') as f: + self.assertEqual(f.read(), b'foobar') + + +class TestIORenameFileTask(BaseIOTaskTest): + def test_main(self): + with open(self.temp_filename, 'wb') as f: + task = self.get_task( + IORenameFileTask, + main_kwargs={ + 'fileobj': f, + 'final_filename': self.final_filename, + 'osutil': self.osutil, + }, + ) + task() + self.assertTrue(os.path.exists(self.final_filename)) + self.assertFalse(os.path.exists(self.temp_filename)) + + +class TestIOCloseTask(BaseIOTaskTest): + def test_main(self): + with open(self.temp_filename, 'w') as f: + task = self.get_task(IOCloseTask, main_kwargs={'fileobj': f}) + task() + self.assertTrue(f.closed) + + +class TestDownloadChunkIterator(unittest.TestCase): + def test_iter(self): + content = b'my content' + body = BytesIO(content) + ref_chunks = [] + for chunk in DownloadChunkIterator(body, len(content)): + ref_chunks.append(chunk) + self.assertEqual(ref_chunks, [b'my content']) + + def test_iter_chunksize(self): + content = b'1234' + body = BytesIO(content) + ref_chunks = [] + for chunk in DownloadChunkIterator(body, 3): + ref_chunks.append(chunk) + self.assertEqual(ref_chunks, [b'123', b'4']) + + def test_empty_content(self): + body = BytesIO(b'') + ref_chunks = [] + for chunk in DownloadChunkIterator(body, 3): + ref_chunks.append(chunk) + self.assertEqual(ref_chunks, [b'']) + + +class TestDeferQueue(unittest.TestCase): + def setUp(self): + self.q = DeferQueue() + + def test_no_writes_when_not_lowest_block(self): + writes = self.q.request_writes(offset=1, data='bar') + self.assertEqual(writes, []) + + def test_writes_returned_in_order(self): + self.assertEqual(self.q.request_writes(offset=3, data='d'), []) + self.assertEqual(self.q.request_writes(offset=2, data='c'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + # Everything at this point has been deferred, but as soon as we + # send offset=0, that will unlock offsets 0-3. 
+ writes = self.q.request_writes(offset=0, data='a') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'a'}, + {'offset': 1, 'data': 'b'}, + {'offset': 2, 'data': 'c'}, + {'offset': 3, 'data': 'd'}, + ], + ) + + def test_unlocks_partial_range(self): + self.assertEqual(self.q.request_writes(offset=5, data='f'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + # offset=0 unlocks 0-1, but offset=5 still needs to see 2-4 first. + writes = self.q.request_writes(offset=0, data='a') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'a'}, + {'offset': 1, 'data': 'b'}, + ], + ) + + def test_data_can_be_any_size(self): + self.q.request_writes(offset=5, data='hello world') + writes = self.q.request_writes(offset=0, data='abcde') + self.assertEqual( + writes, + [ + {'offset': 0, 'data': 'abcde'}, + {'offset': 5, 'data': 'hello world'}, + ], + ) + + def test_data_queued_in_order(self): + # This immediately gets returned because offset=0 is the + # next range we're waiting on. + writes = self.q.request_writes(offset=0, data='hello world') + self.assertEqual(writes, [{'offset': 0, 'data': 'hello world'}]) + # Same thing here but with offset + writes = self.q.request_writes(offset=11, data='hello again') + self.assertEqual(writes, [{'offset': 11, 'data': 'hello again'}]) + + def test_writes_below_min_offset_are_ignored(self): + self.q.request_writes(offset=0, data='a') + self.q.request_writes(offset=1, data='b') + self.q.request_writes(offset=2, data='c') + + # At this point we're expecting offset=3, so if a write + # comes in below 3, we ignore it. + self.assertEqual(self.q.request_writes(offset=0, data='a'), []) + self.assertEqual(self.q.request_writes(offset=1, data='b'), []) + + self.assertEqual( + self.q.request_writes(offset=3, data='d'), + [{'offset': 3, 'data': 'd'}], + ) + + def test_duplicate_writes_are_ignored(self): + self.q.request_writes(offset=2, data='c') + self.q.request_writes(offset=1, data='b') + + # We're still waiting for offset=0, but if + # a duplicate write comes in for offset=2/offset=1 + # it's ignored. This gives "first one wins" behavior. + self.assertEqual(self.q.request_writes(offset=2, data='X'), []) + self.assertEqual(self.q.request_writes(offset=1, data='Y'), []) + + self.assertEqual( + self.q.request_writes(offset=0, data='a'), + [ + {'offset': 0, 'data': 'a'}, + # Note we're seeing 'b' 'c', and not 'X', 'Y'. + {'offset': 1, 'data': 'b'}, + {'offset': 2, 'data': 'c'}, + ], + ) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_futures.py b/contrib/python/s3transfer/py3/tests/unit/test_futures.py index 2f4028effe..ca2888a654 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_futures.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_futures.py @@ -1,696 +1,696 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-import os -import sys -import time -import traceback -from concurrent.futures import ThreadPoolExecutor - -from s3transfer.exceptions import ( - CancelledError, - FatalError, - TransferNotDoneError, -) -from s3transfer.futures import ( - BaseExecutor, - BoundedExecutor, - ExecutorFuture, - NonThreadedExecutor, - NonThreadedExecutorFuture, - TransferCoordinator, - TransferFuture, - TransferMeta, -) -from s3transfer.tasks import Task -from s3transfer.utils import ( - FunctionContainer, - NoResourcesAvailable, - TaskSemaphore, -) -from __tests__ import ( - RecordingExecutor, - TransferCoordinatorWithInterrupt, - mock, - unittest, -) - - -def return_call_args(*args, **kwargs): - return args, kwargs - - -def raise_exception(exception): - raise exception - - -def get_exc_info(exception): - try: - raise_exception(exception) - except Exception: - return sys.exc_info() - - -class RecordingTransferCoordinator(TransferCoordinator): - def __init__(self): - self.all_transfer_futures_ever_associated = set() - super().__init__() - - def add_associated_future(self, future): - self.all_transfer_futures_ever_associated.add(future) - super().add_associated_future(future) - - -class ReturnFooTask(Task): - def _main(self, **kwargs): - return 'foo' - - -class SleepTask(Task): - def _main(self, sleep_time, **kwargs): - time.sleep(sleep_time) - - -class TestTransferFuture(unittest.TestCase): - def setUp(self): - self.meta = TransferMeta() - self.coordinator = TransferCoordinator() - self.future = self._get_transfer_future() - - def _get_transfer_future(self, **kwargs): - components = { - 'meta': self.meta, - 'coordinator': self.coordinator, - } - for component_name, component in kwargs.items(): - components[component_name] = component - return TransferFuture(**components) - - def test_meta(self): - self.assertIs(self.future.meta, self.meta) - - def test_done(self): - self.assertFalse(self.future.done()) - self.coordinator.set_result(None) - self.assertTrue(self.future.done()) - - def test_result(self): - result = 'foo' - self.coordinator.set_result(result) - self.coordinator.announce_done() - self.assertEqual(self.future.result(), result) - - def test_keyboard_interrupt_on_result_does_not_block(self): - # This should raise a KeyboardInterrupt when result is called on it. - self.coordinator = TransferCoordinatorWithInterrupt() - self.future = self._get_transfer_future() - - # result() should not block and immediately raise the keyboard - # interrupt exception. 
- with self.assertRaises(KeyboardInterrupt): - self.future.result() - - def test_cancel(self): - self.future.cancel() - self.assertTrue(self.future.done()) - self.assertEqual(self.coordinator.status, 'cancelled') - - def test_set_exception(self): - # Set the result such that there is no exception - self.coordinator.set_result('result') - self.coordinator.announce_done() - self.assertEqual(self.future.result(), 'result') - - self.future.set_exception(ValueError()) - with self.assertRaises(ValueError): - self.future.result() - - def test_set_exception_only_after_done(self): - with self.assertRaises(TransferNotDoneError): - self.future.set_exception(ValueError()) - - self.coordinator.set_result('result') - self.coordinator.announce_done() - self.future.set_exception(ValueError()) - with self.assertRaises(ValueError): - self.future.result() - - -class TestTransferMeta(unittest.TestCase): - def setUp(self): - self.transfer_meta = TransferMeta() - - def test_size(self): - self.assertEqual(self.transfer_meta.size, None) - self.transfer_meta.provide_transfer_size(5) - self.assertEqual(self.transfer_meta.size, 5) - - def test_call_args(self): - call_args = object() - transfer_meta = TransferMeta(call_args) - # Assert the that call args provided is the same as is returned - self.assertIs(transfer_meta.call_args, call_args) - - def test_transfer_id(self): - transfer_meta = TransferMeta(transfer_id=1) - self.assertEqual(transfer_meta.transfer_id, 1) - - def test_user_context(self): - self.transfer_meta.user_context['foo'] = 'bar' - self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'}) - - -class TestTransferCoordinator(unittest.TestCase): - def setUp(self): - self.transfer_coordinator = TransferCoordinator() - - def test_transfer_id(self): - transfer_coordinator = TransferCoordinator(transfer_id=1) - self.assertEqual(transfer_coordinator.transfer_id, 1) - - def test_repr(self): - transfer_coordinator = TransferCoordinator(transfer_id=1) - self.assertEqual( - repr(transfer_coordinator), 'TransferCoordinator(transfer_id=1)' - ) - - def test_initial_status(self): - # A TransferCoordinator with no progress should have the status - # of not-started - self.assertEqual(self.transfer_coordinator.status, 'not-started') - - def test_set_status_to_queued(self): - self.transfer_coordinator.set_status_to_queued() - self.assertEqual(self.transfer_coordinator.status, 'queued') - - def test_cannot_set_status_to_queued_from_done_state(self): - self.transfer_coordinator.set_exception(RuntimeError) - with self.assertRaises(RuntimeError): - self.transfer_coordinator.set_status_to_queued() - - def test_status_running(self): - self.transfer_coordinator.set_status_to_running() - self.assertEqual(self.transfer_coordinator.status, 'running') - - def test_cannot_set_status_to_running_from_done_state(self): - self.transfer_coordinator.set_exception(RuntimeError) - with self.assertRaises(RuntimeError): - self.transfer_coordinator.set_status_to_running() - - def test_set_result(self): - success_result = 'foo' - self.transfer_coordinator.set_result(success_result) - self.transfer_coordinator.announce_done() - # Setting result should result in a success state and the return value - # that was set. 
- self.assertEqual(self.transfer_coordinator.status, 'success') - self.assertEqual(self.transfer_coordinator.result(), success_result) - - def test_set_exception(self): - exception_result = RuntimeError - self.transfer_coordinator.set_exception(exception_result) - self.transfer_coordinator.announce_done() - # Setting an exception should result in a failed state and the return - # value should be the raised exception - self.assertEqual(self.transfer_coordinator.status, 'failed') - self.assertEqual(self.transfer_coordinator.exception, exception_result) - with self.assertRaises(exception_result): - self.transfer_coordinator.result() - - def test_exception_cannot_override_done_state(self): - self.transfer_coordinator.set_result('foo') - self.transfer_coordinator.set_exception(RuntimeError) - # It status should be success even after the exception is set because - # success is a done state. - self.assertEqual(self.transfer_coordinator.status, 'success') - - def test_exception_can_override_done_state_with_override_flag(self): - self.transfer_coordinator.set_result('foo') - self.transfer_coordinator.set_exception(RuntimeError, override=True) - self.assertEqual(self.transfer_coordinator.status, 'failed') - - def test_cancel(self): - self.assertEqual(self.transfer_coordinator.status, 'not-started') - self.transfer_coordinator.cancel() - # This should set the state to cancelled and raise the CancelledError - # exception and should have also set the done event so that result() - # is no longer set. - self.assertEqual(self.transfer_coordinator.status, 'cancelled') - with self.assertRaises(CancelledError): - self.transfer_coordinator.result() - - def test_cancel_can_run_done_callbacks_that_uses_result(self): - exceptions = [] - - def capture_exception(transfer_coordinator, captured_exceptions): - try: - transfer_coordinator.result() - except Exception as e: - captured_exceptions.append(e) - - self.assertEqual(self.transfer_coordinator.status, 'not-started') - self.transfer_coordinator.add_done_callback( - capture_exception, self.transfer_coordinator, exceptions - ) - self.transfer_coordinator.cancel() - - self.assertEqual(len(exceptions), 1) - self.assertIsInstance(exceptions[0], CancelledError) - - def test_cancel_with_message(self): - message = 'my message' - self.transfer_coordinator.cancel(message) - self.transfer_coordinator.announce_done() - with self.assertRaisesRegex(CancelledError, message): - self.transfer_coordinator.result() - - def test_cancel_with_provided_exception(self): - message = 'my message' - self.transfer_coordinator.cancel(message, exc_type=FatalError) - self.transfer_coordinator.announce_done() - with self.assertRaisesRegex(FatalError, message): - self.transfer_coordinator.result() - - def test_cancel_cannot_override_done_state(self): - self.transfer_coordinator.set_result('foo') - self.transfer_coordinator.cancel() - # It status should be success even after cancel is called because - # success is a done state. - self.assertEqual(self.transfer_coordinator.status, 'success') - - def test_set_result_can_override_cancel(self): - self.transfer_coordinator.cancel() - # Result setting should override any cancel or set exception as this - # is always invoked by the final task. - self.transfer_coordinator.set_result('foo') - self.transfer_coordinator.announce_done() - self.assertEqual(self.transfer_coordinator.status, 'success') - - def test_submit(self): - # Submit a callable to the transfer coordinator. It should submit it - # to the executor. 
- executor = RecordingExecutor( - BoundedExecutor(1, 1, {'my-tag': TaskSemaphore(1)}) - ) - task = ReturnFooTask(self.transfer_coordinator) - future = self.transfer_coordinator.submit(executor, task, tag='my-tag') - executor.shutdown() - # Make sure the future got submit and executed as well by checking its - # result value which should include the provided future tag. - self.assertEqual( - executor.submissions, - [{'block': True, 'tag': 'my-tag', 'task': task}], - ) - self.assertEqual(future.result(), 'foo') - - def test_association_and_disassociation_on_submit(self): - self.transfer_coordinator = RecordingTransferCoordinator() - - # Submit a callable to the transfer coordinator. - executor = BoundedExecutor(1, 1) - task = ReturnFooTask(self.transfer_coordinator) - future = self.transfer_coordinator.submit(executor, task) - executor.shutdown() - - # Make sure the future that got submitted was associated to the - # transfer future at some point. - self.assertEqual( - self.transfer_coordinator.all_transfer_futures_ever_associated, - {future}, - ) - - # Make sure the future got disassociated once the future is now done - # by looking at the currently associated futures. - self.assertEqual(self.transfer_coordinator.associated_futures, set()) - - def test_done(self): - # These should result in not done state: - # queued - self.assertFalse(self.transfer_coordinator.done()) - # running - self.transfer_coordinator.set_status_to_running() - self.assertFalse(self.transfer_coordinator.done()) - - # These should result in done state: - # failed - self.transfer_coordinator.set_exception(Exception) - self.assertTrue(self.transfer_coordinator.done()) - - # success - self.transfer_coordinator.set_result('foo') - self.assertTrue(self.transfer_coordinator.done()) - - # cancelled - self.transfer_coordinator.cancel() - self.assertTrue(self.transfer_coordinator.done()) - - def test_result_waits_until_done(self): - execution_order = [] - - def sleep_then_set_result(transfer_coordinator, execution_order): - time.sleep(0.05) - execution_order.append('setting_result') - transfer_coordinator.set_result(None) - self.transfer_coordinator.announce_done() - - with ThreadPoolExecutor(max_workers=1) as executor: - executor.submit( - sleep_then_set_result, - self.transfer_coordinator, - execution_order, - ) - self.transfer_coordinator.result() - execution_order.append('after_result') - - # The result() call should have waited until the other thread set - # the result after sleeping for 0.05 seconds. - self.assertTrue(execution_order, ['setting_result', 'after_result']) - - def test_failure_cleanups(self): - args = (1, 2) - kwargs = {'foo': 'bar'} - - second_args = (2, 4) - second_kwargs = {'biz': 'baz'} - - self.transfer_coordinator.add_failure_cleanup( - return_call_args, *args, **kwargs - ) - self.transfer_coordinator.add_failure_cleanup( - return_call_args, *second_args, **second_kwargs - ) - - # Ensure the callbacks got added. - self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 2) - - result_list = [] - # Ensure they will get called in the correct order. 
- for cleanup in self.transfer_coordinator.failure_cleanups: - result_list.append(cleanup()) - self.assertEqual( - result_list, [(args, kwargs), (second_args, second_kwargs)] - ) - - def test_associated_futures(self): - first_future = object() - # Associate one future to the transfer - self.transfer_coordinator.add_associated_future(first_future) - associated_futures = self.transfer_coordinator.associated_futures - # The first future should be in the returned list of futures. - self.assertEqual(associated_futures, {first_future}) - - second_future = object() - # Associate another future to the transfer. - self.transfer_coordinator.add_associated_future(second_future) - # The association should not have mutated the returned list from - # before. - self.assertEqual(associated_futures, {first_future}) - - # Both futures should be in the returned list. - self.assertEqual( - self.transfer_coordinator.associated_futures, - {first_future, second_future}, - ) - - def test_done_callbacks_on_done(self): - done_callback_invocations = [] - callback = FunctionContainer( - done_callback_invocations.append, 'done callback called' - ) - - # Add the done callback to the transfer. - self.transfer_coordinator.add_done_callback(callback) - - # Announce that the transfer is done. This should invoke the done - # callback. - self.transfer_coordinator.announce_done() - self.assertEqual(done_callback_invocations, ['done callback called']) - - # If done is announced again, we should not invoke the callback again - # because done has already been announced and thus the callback has - # been ran as well. - self.transfer_coordinator.announce_done() - self.assertEqual(done_callback_invocations, ['done callback called']) - - def test_failure_cleanups_on_done(self): - cleanup_invocations = [] - callback = FunctionContainer( - cleanup_invocations.append, 'cleanup called' - ) - - # Add the failure cleanup to the transfer. - self.transfer_coordinator.add_failure_cleanup(callback) - - # Announce that the transfer is done. This should invoke the failure - # cleanup. - self.transfer_coordinator.announce_done() - self.assertEqual(cleanup_invocations, ['cleanup called']) - - # If done is announced again, we should not invoke the cleanup again - # because done has already been announced and thus the cleanup has - # been ran as well. - self.transfer_coordinator.announce_done() - self.assertEqual(cleanup_invocations, ['cleanup called']) - - -class TestBoundedExecutor(unittest.TestCase): - def setUp(self): - self.coordinator = TransferCoordinator() - self.tag_semaphores = {} - self.executor = self.get_executor() - - def get_executor(self, max_size=1, max_num_threads=1): - return BoundedExecutor(max_size, max_num_threads, self.tag_semaphores) - - def get_task(self, task_cls, main_kwargs=None): - return task_cls(self.coordinator, main_kwargs=main_kwargs) - - def get_sleep_task(self, sleep_time=0.01): - return self.get_task(SleepTask, main_kwargs={'sleep_time': sleep_time}) - - def add_semaphore(self, task_tag, count): - self.tag_semaphores[task_tag] = TaskSemaphore(count) - - def assert_submit_would_block(self, task, tag=None): - with self.assertRaises(NoResourcesAvailable): - self.executor.submit(task, tag=tag, block=False) - - def assert_submit_would_not_block(self, task, tag=None, **kwargs): - try: - self.executor.submit(task, tag=tag, block=False) - except NoResourcesAvailable: - self.fail( - 'Task {} should not have been blocked. 
Caused by:\n{}'.format( - task, traceback.format_exc() - ) - ) - - def add_done_callback_to_future(self, future, fn, *args, **kwargs): - callback_for_future = FunctionContainer(fn, *args, **kwargs) - future.add_done_callback(callback_for_future) - - def test_submit_single_task(self): - # Ensure we can submit a task to the executor - task = self.get_task(ReturnFooTask) - future = self.executor.submit(task) - - # Ensure what we get back is a Future - self.assertIsInstance(future, ExecutorFuture) - # Ensure the callable got executed. - self.assertEqual(future.result(), 'foo') - - @unittest.skipIf( - os.environ.get('USE_SERIAL_EXECUTOR'), - "Not supported with serial executor tests", - ) - def test_executor_blocks_on_full_capacity(self): - first_task = self.get_sleep_task() - second_task = self.get_sleep_task() - self.executor.submit(first_task) - # The first task should be sleeping for a substantial period of - # time such that on the submission of the second task, it will - # raise an error saying that it cannot be submitted as the max - # capacity of the semaphore is one. - self.assert_submit_would_block(second_task) - - def test_executor_clears_capacity_on_done_tasks(self): - first_task = self.get_sleep_task() - second_task = self.get_task(ReturnFooTask) - - # Submit a task. - future = self.executor.submit(first_task) - - # Submit a new task when the first task finishes. This should not get - # blocked because the first task should have finished clearing up - # capacity. - self.add_done_callback_to_future( - future, self.assert_submit_would_not_block, second_task - ) - - # Wait for it to complete. - self.executor.shutdown() - - @unittest.skipIf( - os.environ.get('USE_SERIAL_EXECUTOR'), - "Not supported with serial executor tests", - ) - def test_would_not_block_when_full_capacity_in_other_semaphore(self): - first_task = self.get_sleep_task() - - # Now let's create a new task with a tag and so it uses different - # semaphore. - task_tag = 'other' - other_task = self.get_sleep_task() - self.add_semaphore(task_tag, 1) - - # Submit the normal first task - self.executor.submit(first_task) - - # Even though The first task should be sleeping for a substantial - # period of time, the submission of the second task should not - # raise an error because it should use a different semaphore - self.assert_submit_would_not_block(other_task, task_tag) - - # Another submission of the other task though should raise - # an exception as the capacity is equal to one for that tag. - self.assert_submit_would_block(other_task, task_tag) - - def test_shutdown(self): - slow_task = self.get_sleep_task() - future = self.executor.submit(slow_task) - self.executor.shutdown() - # Ensure that the shutdown waits until the task is done - self.assertTrue(future.done()) - - @unittest.skipIf( - os.environ.get('USE_SERIAL_EXECUTOR'), - "Not supported with serial executor tests", - ) - def test_shutdown_no_wait(self): - slow_task = self.get_sleep_task() - future = self.executor.submit(slow_task) - self.executor.shutdown(False) - # Ensure that the shutdown returns immediately even if the task is - # not done, which it should not be because it it slow. 
- self.assertFalse(future.done()) - - def test_replace_underlying_executor(self): - mocked_executor_cls = mock.Mock(BaseExecutor) - executor = BoundedExecutor(10, 1, {}, mocked_executor_cls) - executor.submit(self.get_task(ReturnFooTask)) - self.assertTrue(mocked_executor_cls.return_value.submit.called) - - -class TestExecutorFuture(unittest.TestCase): - def test_result(self): - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(return_call_args, 'foo', biz='baz') - wrapped_future = ExecutorFuture(future) - self.assertEqual(wrapped_future.result(), (('foo',), {'biz': 'baz'})) - - def test_done(self): - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(return_call_args, 'foo', biz='baz') - wrapped_future = ExecutorFuture(future) - self.assertTrue(wrapped_future.done()) - - def test_add_done_callback(self): - done_callbacks = [] - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(return_call_args, 'foo', biz='baz') - wrapped_future = ExecutorFuture(future) - wrapped_future.add_done_callback( - FunctionContainer(done_callbacks.append, 'called') - ) - self.assertEqual(done_callbacks, ['called']) - - -class TestNonThreadedExecutor(unittest.TestCase): - def test_submit(self): - executor = NonThreadedExecutor() - future = executor.submit(return_call_args, 1, 2, foo='bar') - self.assertIsInstance(future, NonThreadedExecutorFuture) - self.assertEqual(future.result(), ((1, 2), {'foo': 'bar'})) - - def test_submit_with_exception(self): - executor = NonThreadedExecutor() - future = executor.submit(raise_exception, RuntimeError()) - self.assertIsInstance(future, NonThreadedExecutorFuture) - with self.assertRaises(RuntimeError): - future.result() - - def test_submit_with_exception_and_captures_info(self): - exception = ValueError('message') - tb = get_exc_info(exception)[2] - future = NonThreadedExecutor().submit(raise_exception, exception) - try: - future.result() - # An exception should have been raised - self.fail('Future should have raised a ValueError') - except ValueError: - actual_tb = sys.exc_info()[2] - last_frame = traceback.extract_tb(actual_tb)[-1] - last_expected_frame = traceback.extract_tb(tb)[-1] - self.assertEqual(last_frame, last_expected_frame) - - -class TestNonThreadedExecutorFuture(unittest.TestCase): - def setUp(self): - self.future = NonThreadedExecutorFuture() - - def test_done_starts_false(self): - self.assertFalse(self.future.done()) - - def test_done_after_setting_result(self): - self.future.set_result('result') - self.assertTrue(self.future.done()) - - def test_done_after_setting_exception(self): - self.future.set_exception_info(Exception(), None) - self.assertTrue(self.future.done()) - - def test_result(self): - self.future.set_result('result') - self.assertEqual(self.future.result(), 'result') - - def test_exception_result(self): - exception = ValueError('message') - self.future.set_exception_info(exception, None) - with self.assertRaisesRegex(ValueError, 'message'): - self.future.result() - - def test_exception_result_doesnt_modify_last_frame(self): - exception = ValueError('message') - tb = get_exc_info(exception)[2] - self.future.set_exception_info(exception, tb) - try: - self.future.result() - # An exception should have been raised - self.fail() - except ValueError: - actual_tb = sys.exc_info()[2] - last_frame = traceback.extract_tb(actual_tb)[-1] - last_expected_frame = traceback.extract_tb(tb)[-1] - self.assertEqual(last_frame, last_expected_frame) - - def test_done_callback(self): 
- done_futures = [] - self.future.add_done_callback(done_futures.append) - self.assertEqual(done_futures, []) - self.future.set_result('result') - self.assertEqual(done_futures, [self.future]) - - def test_done_callback_after_done(self): - self.future.set_result('result') - done_futures = [] - self.future.add_done_callback(done_futures.append) - self.assertEqual(done_futures, [self.future]) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import sys +import time +import traceback +from concurrent.futures import ThreadPoolExecutor + +from s3transfer.exceptions import ( + CancelledError, + FatalError, + TransferNotDoneError, +) +from s3transfer.futures import ( + BaseExecutor, + BoundedExecutor, + ExecutorFuture, + NonThreadedExecutor, + NonThreadedExecutorFuture, + TransferCoordinator, + TransferFuture, + TransferMeta, +) +from s3transfer.tasks import Task +from s3transfer.utils import ( + FunctionContainer, + NoResourcesAvailable, + TaskSemaphore, +) +from __tests__ import ( + RecordingExecutor, + TransferCoordinatorWithInterrupt, + mock, + unittest, +) + + +def return_call_args(*args, **kwargs): + return args, kwargs + + +def raise_exception(exception): + raise exception + + +def get_exc_info(exception): + try: + raise_exception(exception) + except Exception: + return sys.exc_info() + + +class RecordingTransferCoordinator(TransferCoordinator): + def __init__(self): + self.all_transfer_futures_ever_associated = set() + super().__init__() + + def add_associated_future(self, future): + self.all_transfer_futures_ever_associated.add(future) + super().add_associated_future(future) + + +class ReturnFooTask(Task): + def _main(self, **kwargs): + return 'foo' + + +class SleepTask(Task): + def _main(self, sleep_time, **kwargs): + time.sleep(sleep_time) + + +class TestTransferFuture(unittest.TestCase): + def setUp(self): + self.meta = TransferMeta() + self.coordinator = TransferCoordinator() + self.future = self._get_transfer_future() + + def _get_transfer_future(self, **kwargs): + components = { + 'meta': self.meta, + 'coordinator': self.coordinator, + } + for component_name, component in kwargs.items(): + components[component_name] = component + return TransferFuture(**components) + + def test_meta(self): + self.assertIs(self.future.meta, self.meta) + + def test_done(self): + self.assertFalse(self.future.done()) + self.coordinator.set_result(None) + self.assertTrue(self.future.done()) + + def test_result(self): + result = 'foo' + self.coordinator.set_result(result) + self.coordinator.announce_done() + self.assertEqual(self.future.result(), result) + + def test_keyboard_interrupt_on_result_does_not_block(self): + # This should raise a KeyboardInterrupt when result is called on it. + self.coordinator = TransferCoordinatorWithInterrupt() + self.future = self._get_transfer_future() + + # result() should not block and immediately raise the keyboard + # interrupt exception. 
+ with self.assertRaises(KeyboardInterrupt): + self.future.result() + + def test_cancel(self): + self.future.cancel() + self.assertTrue(self.future.done()) + self.assertEqual(self.coordinator.status, 'cancelled') + + def test_set_exception(self): + # Set the result such that there is no exception + self.coordinator.set_result('result') + self.coordinator.announce_done() + self.assertEqual(self.future.result(), 'result') + + self.future.set_exception(ValueError()) + with self.assertRaises(ValueError): + self.future.result() + + def test_set_exception_only_after_done(self): + with self.assertRaises(TransferNotDoneError): + self.future.set_exception(ValueError()) + + self.coordinator.set_result('result') + self.coordinator.announce_done() + self.future.set_exception(ValueError()) + with self.assertRaises(ValueError): + self.future.result() + + +class TestTransferMeta(unittest.TestCase): + def setUp(self): + self.transfer_meta = TransferMeta() + + def test_size(self): + self.assertEqual(self.transfer_meta.size, None) + self.transfer_meta.provide_transfer_size(5) + self.assertEqual(self.transfer_meta.size, 5) + + def test_call_args(self): + call_args = object() + transfer_meta = TransferMeta(call_args) + # Assert the that call args provided is the same as is returned + self.assertIs(transfer_meta.call_args, call_args) + + def test_transfer_id(self): + transfer_meta = TransferMeta(transfer_id=1) + self.assertEqual(transfer_meta.transfer_id, 1) + + def test_user_context(self): + self.transfer_meta.user_context['foo'] = 'bar' + self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'}) + + +class TestTransferCoordinator(unittest.TestCase): + def setUp(self): + self.transfer_coordinator = TransferCoordinator() + + def test_transfer_id(self): + transfer_coordinator = TransferCoordinator(transfer_id=1) + self.assertEqual(transfer_coordinator.transfer_id, 1) + + def test_repr(self): + transfer_coordinator = TransferCoordinator(transfer_id=1) + self.assertEqual( + repr(transfer_coordinator), 'TransferCoordinator(transfer_id=1)' + ) + + def test_initial_status(self): + # A TransferCoordinator with no progress should have the status + # of not-started + self.assertEqual(self.transfer_coordinator.status, 'not-started') + + def test_set_status_to_queued(self): + self.transfer_coordinator.set_status_to_queued() + self.assertEqual(self.transfer_coordinator.status, 'queued') + + def test_cannot_set_status_to_queued_from_done_state(self): + self.transfer_coordinator.set_exception(RuntimeError) + with self.assertRaises(RuntimeError): + self.transfer_coordinator.set_status_to_queued() + + def test_status_running(self): + self.transfer_coordinator.set_status_to_running() + self.assertEqual(self.transfer_coordinator.status, 'running') + + def test_cannot_set_status_to_running_from_done_state(self): + self.transfer_coordinator.set_exception(RuntimeError) + with self.assertRaises(RuntimeError): + self.transfer_coordinator.set_status_to_running() + + def test_set_result(self): + success_result = 'foo' + self.transfer_coordinator.set_result(success_result) + self.transfer_coordinator.announce_done() + # Setting result should result in a success state and the return value + # that was set. 
+ self.assertEqual(self.transfer_coordinator.status, 'success') + self.assertEqual(self.transfer_coordinator.result(), success_result) + + def test_set_exception(self): + exception_result = RuntimeError + self.transfer_coordinator.set_exception(exception_result) + self.transfer_coordinator.announce_done() + # Setting an exception should result in a failed state and the return + # value should be the raised exception + self.assertEqual(self.transfer_coordinator.status, 'failed') + self.assertEqual(self.transfer_coordinator.exception, exception_result) + with self.assertRaises(exception_result): + self.transfer_coordinator.result() + + def test_exception_cannot_override_done_state(self): + self.transfer_coordinator.set_result('foo') + self.transfer_coordinator.set_exception(RuntimeError) + # It status should be success even after the exception is set because + # success is a done state. + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_exception_can_override_done_state_with_override_flag(self): + self.transfer_coordinator.set_result('foo') + self.transfer_coordinator.set_exception(RuntimeError, override=True) + self.assertEqual(self.transfer_coordinator.status, 'failed') + + def test_cancel(self): + self.assertEqual(self.transfer_coordinator.status, 'not-started') + self.transfer_coordinator.cancel() + # This should set the state to cancelled and raise the CancelledError + # exception and should have also set the done event so that result() + # is no longer set. + self.assertEqual(self.transfer_coordinator.status, 'cancelled') + with self.assertRaises(CancelledError): + self.transfer_coordinator.result() + + def test_cancel_can_run_done_callbacks_that_uses_result(self): + exceptions = [] + + def capture_exception(transfer_coordinator, captured_exceptions): + try: + transfer_coordinator.result() + except Exception as e: + captured_exceptions.append(e) + + self.assertEqual(self.transfer_coordinator.status, 'not-started') + self.transfer_coordinator.add_done_callback( + capture_exception, self.transfer_coordinator, exceptions + ) + self.transfer_coordinator.cancel() + + self.assertEqual(len(exceptions), 1) + self.assertIsInstance(exceptions[0], CancelledError) + + def test_cancel_with_message(self): + message = 'my message' + self.transfer_coordinator.cancel(message) + self.transfer_coordinator.announce_done() + with self.assertRaisesRegex(CancelledError, message): + self.transfer_coordinator.result() + + def test_cancel_with_provided_exception(self): + message = 'my message' + self.transfer_coordinator.cancel(message, exc_type=FatalError) + self.transfer_coordinator.announce_done() + with self.assertRaisesRegex(FatalError, message): + self.transfer_coordinator.result() + + def test_cancel_cannot_override_done_state(self): + self.transfer_coordinator.set_result('foo') + self.transfer_coordinator.cancel() + # It status should be success even after cancel is called because + # success is a done state. + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_set_result_can_override_cancel(self): + self.transfer_coordinator.cancel() + # Result setting should override any cancel or set exception as this + # is always invoked by the final task. + self.transfer_coordinator.set_result('foo') + self.transfer_coordinator.announce_done() + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_submit(self): + # Submit a callable to the transfer coordinator. It should submit it + # to the executor. 
+ executor = RecordingExecutor( + BoundedExecutor(1, 1, {'my-tag': TaskSemaphore(1)}) + ) + task = ReturnFooTask(self.transfer_coordinator) + future = self.transfer_coordinator.submit(executor, task, tag='my-tag') + executor.shutdown() + # Make sure the future got submit and executed as well by checking its + # result value which should include the provided future tag. + self.assertEqual( + executor.submissions, + [{'block': True, 'tag': 'my-tag', 'task': task}], + ) + self.assertEqual(future.result(), 'foo') + + def test_association_and_disassociation_on_submit(self): + self.transfer_coordinator = RecordingTransferCoordinator() + + # Submit a callable to the transfer coordinator. + executor = BoundedExecutor(1, 1) + task = ReturnFooTask(self.transfer_coordinator) + future = self.transfer_coordinator.submit(executor, task) + executor.shutdown() + + # Make sure the future that got submitted was associated to the + # transfer future at some point. + self.assertEqual( + self.transfer_coordinator.all_transfer_futures_ever_associated, + {future}, + ) + + # Make sure the future got disassociated once the future is now done + # by looking at the currently associated futures. + self.assertEqual(self.transfer_coordinator.associated_futures, set()) + + def test_done(self): + # These should result in not done state: + # queued + self.assertFalse(self.transfer_coordinator.done()) + # running + self.transfer_coordinator.set_status_to_running() + self.assertFalse(self.transfer_coordinator.done()) + + # These should result in done state: + # failed + self.transfer_coordinator.set_exception(Exception) + self.assertTrue(self.transfer_coordinator.done()) + + # success + self.transfer_coordinator.set_result('foo') + self.assertTrue(self.transfer_coordinator.done()) + + # cancelled + self.transfer_coordinator.cancel() + self.assertTrue(self.transfer_coordinator.done()) + + def test_result_waits_until_done(self): + execution_order = [] + + def sleep_then_set_result(transfer_coordinator, execution_order): + time.sleep(0.05) + execution_order.append('setting_result') + transfer_coordinator.set_result(None) + self.transfer_coordinator.announce_done() + + with ThreadPoolExecutor(max_workers=1) as executor: + executor.submit( + sleep_then_set_result, + self.transfer_coordinator, + execution_order, + ) + self.transfer_coordinator.result() + execution_order.append('after_result') + + # The result() call should have waited until the other thread set + # the result after sleeping for 0.05 seconds. + self.assertTrue(execution_order, ['setting_result', 'after_result']) + + def test_failure_cleanups(self): + args = (1, 2) + kwargs = {'foo': 'bar'} + + second_args = (2, 4) + second_kwargs = {'biz': 'baz'} + + self.transfer_coordinator.add_failure_cleanup( + return_call_args, *args, **kwargs + ) + self.transfer_coordinator.add_failure_cleanup( + return_call_args, *second_args, **second_kwargs + ) + + # Ensure the callbacks got added. + self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 2) + + result_list = [] + # Ensure they will get called in the correct order. 
+ for cleanup in self.transfer_coordinator.failure_cleanups: + result_list.append(cleanup()) + self.assertEqual( + result_list, [(args, kwargs), (second_args, second_kwargs)] + ) + + def test_associated_futures(self): + first_future = object() + # Associate one future to the transfer + self.transfer_coordinator.add_associated_future(first_future) + associated_futures = self.transfer_coordinator.associated_futures + # The first future should be in the returned list of futures. + self.assertEqual(associated_futures, {first_future}) + + second_future = object() + # Associate another future to the transfer. + self.transfer_coordinator.add_associated_future(second_future) + # The association should not have mutated the returned list from + # before. + self.assertEqual(associated_futures, {first_future}) + + # Both futures should be in the returned list. + self.assertEqual( + self.transfer_coordinator.associated_futures, + {first_future, second_future}, + ) + + def test_done_callbacks_on_done(self): + done_callback_invocations = [] + callback = FunctionContainer( + done_callback_invocations.append, 'done callback called' + ) + + # Add the done callback to the transfer. + self.transfer_coordinator.add_done_callback(callback) + + # Announce that the transfer is done. This should invoke the done + # callback. + self.transfer_coordinator.announce_done() + self.assertEqual(done_callback_invocations, ['done callback called']) + + # If done is announced again, we should not invoke the callback again + # because done has already been announced and thus the callback has + # been ran as well. + self.transfer_coordinator.announce_done() + self.assertEqual(done_callback_invocations, ['done callback called']) + + def test_failure_cleanups_on_done(self): + cleanup_invocations = [] + callback = FunctionContainer( + cleanup_invocations.append, 'cleanup called' + ) + + # Add the failure cleanup to the transfer. + self.transfer_coordinator.add_failure_cleanup(callback) + + # Announce that the transfer is done. This should invoke the failure + # cleanup. + self.transfer_coordinator.announce_done() + self.assertEqual(cleanup_invocations, ['cleanup called']) + + # If done is announced again, we should not invoke the cleanup again + # because done has already been announced and thus the cleanup has + # been ran as well. + self.transfer_coordinator.announce_done() + self.assertEqual(cleanup_invocations, ['cleanup called']) + + +class TestBoundedExecutor(unittest.TestCase): + def setUp(self): + self.coordinator = TransferCoordinator() + self.tag_semaphores = {} + self.executor = self.get_executor() + + def get_executor(self, max_size=1, max_num_threads=1): + return BoundedExecutor(max_size, max_num_threads, self.tag_semaphores) + + def get_task(self, task_cls, main_kwargs=None): + return task_cls(self.coordinator, main_kwargs=main_kwargs) + + def get_sleep_task(self, sleep_time=0.01): + return self.get_task(SleepTask, main_kwargs={'sleep_time': sleep_time}) + + def add_semaphore(self, task_tag, count): + self.tag_semaphores[task_tag] = TaskSemaphore(count) + + def assert_submit_would_block(self, task, tag=None): + with self.assertRaises(NoResourcesAvailable): + self.executor.submit(task, tag=tag, block=False) + + def assert_submit_would_not_block(self, task, tag=None, **kwargs): + try: + self.executor.submit(task, tag=tag, block=False) + except NoResourcesAvailable: + self.fail( + 'Task {} should not have been blocked. 
Caused by:\n{}'.format( + task, traceback.format_exc() + ) + ) + + def add_done_callback_to_future(self, future, fn, *args, **kwargs): + callback_for_future = FunctionContainer(fn, *args, **kwargs) + future.add_done_callback(callback_for_future) + + def test_submit_single_task(self): + # Ensure we can submit a task to the executor + task = self.get_task(ReturnFooTask) + future = self.executor.submit(task) + + # Ensure what we get back is a Future + self.assertIsInstance(future, ExecutorFuture) + # Ensure the callable got executed. + self.assertEqual(future.result(), 'foo') + + @unittest.skipIf( + os.environ.get('USE_SERIAL_EXECUTOR'), + "Not supported with serial executor tests", + ) + def test_executor_blocks_on_full_capacity(self): + first_task = self.get_sleep_task() + second_task = self.get_sleep_task() + self.executor.submit(first_task) + # The first task should be sleeping for a substantial period of + # time such that on the submission of the second task, it will + # raise an error saying that it cannot be submitted as the max + # capacity of the semaphore is one. + self.assert_submit_would_block(second_task) + + def test_executor_clears_capacity_on_done_tasks(self): + first_task = self.get_sleep_task() + second_task = self.get_task(ReturnFooTask) + + # Submit a task. + future = self.executor.submit(first_task) + + # Submit a new task when the first task finishes. This should not get + # blocked because the first task should have finished clearing up + # capacity. + self.add_done_callback_to_future( + future, self.assert_submit_would_not_block, second_task + ) + + # Wait for it to complete. + self.executor.shutdown() + + @unittest.skipIf( + os.environ.get('USE_SERIAL_EXECUTOR'), + "Not supported with serial executor tests", + ) + def test_would_not_block_when_full_capacity_in_other_semaphore(self): + first_task = self.get_sleep_task() + + # Now let's create a new task with a tag and so it uses different + # semaphore. + task_tag = 'other' + other_task = self.get_sleep_task() + self.add_semaphore(task_tag, 1) + + # Submit the normal first task + self.executor.submit(first_task) + + # Even though The first task should be sleeping for a substantial + # period of time, the submission of the second task should not + # raise an error because it should use a different semaphore + self.assert_submit_would_not_block(other_task, task_tag) + + # Another submission of the other task though should raise + # an exception as the capacity is equal to one for that tag. + self.assert_submit_would_block(other_task, task_tag) + + def test_shutdown(self): + slow_task = self.get_sleep_task() + future = self.executor.submit(slow_task) + self.executor.shutdown() + # Ensure that the shutdown waits until the task is done + self.assertTrue(future.done()) + + @unittest.skipIf( + os.environ.get('USE_SERIAL_EXECUTOR'), + "Not supported with serial executor tests", + ) + def test_shutdown_no_wait(self): + slow_task = self.get_sleep_task() + future = self.executor.submit(slow_task) + self.executor.shutdown(False) + # Ensure that the shutdown returns immediately even if the task is + # not done, which it should not be because it it slow. 
+ self.assertFalse(future.done()) + + def test_replace_underlying_executor(self): + mocked_executor_cls = mock.Mock(BaseExecutor) + executor = BoundedExecutor(10, 1, {}, mocked_executor_cls) + executor.submit(self.get_task(ReturnFooTask)) + self.assertTrue(mocked_executor_cls.return_value.submit.called) + + +class TestExecutorFuture(unittest.TestCase): + def test_result(self): + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + self.assertEqual(wrapped_future.result(), (('foo',), {'biz': 'baz'})) + + def test_done(self): + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + self.assertTrue(wrapped_future.done()) + + def test_add_done_callback(self): + done_callbacks = [] + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(return_call_args, 'foo', biz='baz') + wrapped_future = ExecutorFuture(future) + wrapped_future.add_done_callback( + FunctionContainer(done_callbacks.append, 'called') + ) + self.assertEqual(done_callbacks, ['called']) + + +class TestNonThreadedExecutor(unittest.TestCase): + def test_submit(self): + executor = NonThreadedExecutor() + future = executor.submit(return_call_args, 1, 2, foo='bar') + self.assertIsInstance(future, NonThreadedExecutorFuture) + self.assertEqual(future.result(), ((1, 2), {'foo': 'bar'})) + + def test_submit_with_exception(self): + executor = NonThreadedExecutor() + future = executor.submit(raise_exception, RuntimeError()) + self.assertIsInstance(future, NonThreadedExecutorFuture) + with self.assertRaises(RuntimeError): + future.result() + + def test_submit_with_exception_and_captures_info(self): + exception = ValueError('message') + tb = get_exc_info(exception)[2] + future = NonThreadedExecutor().submit(raise_exception, exception) + try: + future.result() + # An exception should have been raised + self.fail('Future should have raised a ValueError') + except ValueError: + actual_tb = sys.exc_info()[2] + last_frame = traceback.extract_tb(actual_tb)[-1] + last_expected_frame = traceback.extract_tb(tb)[-1] + self.assertEqual(last_frame, last_expected_frame) + + +class TestNonThreadedExecutorFuture(unittest.TestCase): + def setUp(self): + self.future = NonThreadedExecutorFuture() + + def test_done_starts_false(self): + self.assertFalse(self.future.done()) + + def test_done_after_setting_result(self): + self.future.set_result('result') + self.assertTrue(self.future.done()) + + def test_done_after_setting_exception(self): + self.future.set_exception_info(Exception(), None) + self.assertTrue(self.future.done()) + + def test_result(self): + self.future.set_result('result') + self.assertEqual(self.future.result(), 'result') + + def test_exception_result(self): + exception = ValueError('message') + self.future.set_exception_info(exception, None) + with self.assertRaisesRegex(ValueError, 'message'): + self.future.result() + + def test_exception_result_doesnt_modify_last_frame(self): + exception = ValueError('message') + tb = get_exc_info(exception)[2] + self.future.set_exception_info(exception, tb) + try: + self.future.result() + # An exception should have been raised + self.fail() + except ValueError: + actual_tb = sys.exc_info()[2] + last_frame = traceback.extract_tb(actual_tb)[-1] + last_expected_frame = traceback.extract_tb(tb)[-1] + self.assertEqual(last_frame, last_expected_frame) + + def test_done_callback(self): 
+ done_futures = [] + self.future.add_done_callback(done_futures.append) + self.assertEqual(done_futures, []) + self.future.set_result('result') + self.assertEqual(done_futures, [self.future]) + + def test_done_callback_after_done(self): + self.future.set_result('result') + done_futures = [] + self.future.add_done_callback(done_futures.append) + self.assertEqual(done_futures, [self.future]) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_manager.py b/contrib/python/s3transfer/py3/tests/unit/test_manager.py index 2676357927..fc3caa843f 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_manager.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_manager.py @@ -1,143 +1,143 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -import time -from concurrent.futures import ThreadPoolExecutor - -from s3transfer.exceptions import CancelledError, FatalError -from s3transfer.futures import TransferCoordinator -from s3transfer.manager import TransferConfig, TransferCoordinatorController -from __tests__ import TransferCoordinatorWithInterrupt, unittest - - -class FutureResultException(Exception): - pass - - -class TestTransferConfig(unittest.TestCase): - def test_exception_on_zero_attr_value(self): - with self.assertRaises(ValueError): - TransferConfig(max_request_queue_size=0) - - -class TestTransferCoordinatorController(unittest.TestCase): - def setUp(self): - self.coordinator_controller = TransferCoordinatorController() - - def sleep_then_announce_done(self, transfer_coordinator, sleep_time): - time.sleep(sleep_time) - transfer_coordinator.set_result('done') - transfer_coordinator.announce_done() - - def assert_coordinator_is_cancelled(self, transfer_coordinator): - self.assertEqual(transfer_coordinator.status, 'cancelled') - - def test_add_transfer_coordinator(self): - transfer_coordinator = TransferCoordinator() - # Add the transfer coordinator - self.coordinator_controller.add_transfer_coordinator( - transfer_coordinator - ) - # Ensure that is tracked. - self.assertEqual( - self.coordinator_controller.tracked_transfer_coordinators, - {transfer_coordinator}, - ) - - def test_remove_transfer_coordinator(self): - transfer_coordinator = TransferCoordinator() - # Add the coordinator - self.coordinator_controller.add_transfer_coordinator( - transfer_coordinator - ) - # Now remove the coordinator - self.coordinator_controller.remove_transfer_coordinator( - transfer_coordinator - ) - # Make sure that it is no longer getting tracked. 
- self.assertEqual( - self.coordinator_controller.tracked_transfer_coordinators, set() - ) - - def test_cancel(self): - transfer_coordinator = TransferCoordinator() - # Add the transfer coordinator - self.coordinator_controller.add_transfer_coordinator( - transfer_coordinator - ) - # Cancel with the canceler - self.coordinator_controller.cancel() - # Check that coordinator got canceled - self.assert_coordinator_is_cancelled(transfer_coordinator) - - def test_cancel_with_message(self): - message = 'my cancel message' - transfer_coordinator = TransferCoordinator() - self.coordinator_controller.add_transfer_coordinator( - transfer_coordinator - ) - self.coordinator_controller.cancel(message) - transfer_coordinator.announce_done() - with self.assertRaisesRegex(CancelledError, message): - transfer_coordinator.result() - - def test_cancel_with_provided_exception(self): - message = 'my cancel message' - transfer_coordinator = TransferCoordinator() - self.coordinator_controller.add_transfer_coordinator( - transfer_coordinator - ) - self.coordinator_controller.cancel(message, exc_type=FatalError) - transfer_coordinator.announce_done() - with self.assertRaisesRegex(FatalError, message): - transfer_coordinator.result() - - def test_wait_for_done_transfer_coordinators(self): - # Create a coordinator and add it to the canceler - transfer_coordinator = TransferCoordinator() - self.coordinator_controller.add_transfer_coordinator( - transfer_coordinator - ) - - sleep_time = 0.02 - with ThreadPoolExecutor(max_workers=1) as executor: - # In a separate thread sleep and then set the transfer coordinator - # to done after sleeping. - start_time = time.time() - executor.submit( - self.sleep_then_announce_done, transfer_coordinator, sleep_time - ) - # Now call wait to wait for the transfer coordinator to be done. - self.coordinator_controller.wait() - end_time = time.time() - wait_time = end_time - start_time - # The time waited should not be less than the time it took to sleep in - # the separate thread because the wait ending should be dependent on - # the sleeping thread announcing that the transfer coordinator is done. - self.assertTrue(sleep_time <= wait_time) - - def test_wait_does_not_propogate_exceptions_from_result(self): - transfer_coordinator = TransferCoordinator() - transfer_coordinator.set_exception(FutureResultException()) - transfer_coordinator.announce_done() - try: - self.coordinator_controller.wait() - except FutureResultException as e: - self.fail('%s should not have been raised.' % e) - - def test_wait_can_be_interrupted(self): - inject_interrupt_coordinator = TransferCoordinatorWithInterrupt() - self.coordinator_controller.add_transfer_coordinator( - inject_interrupt_coordinator - ) - with self.assertRaises(KeyboardInterrupt): - self.coordinator_controller.wait() +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import time +from concurrent.futures import ThreadPoolExecutor + +from s3transfer.exceptions import CancelledError, FatalError +from s3transfer.futures import TransferCoordinator +from s3transfer.manager import TransferConfig, TransferCoordinatorController +from __tests__ import TransferCoordinatorWithInterrupt, unittest + + +class FutureResultException(Exception): + pass + + +class TestTransferConfig(unittest.TestCase): + def test_exception_on_zero_attr_value(self): + with self.assertRaises(ValueError): + TransferConfig(max_request_queue_size=0) + + +class TestTransferCoordinatorController(unittest.TestCase): + def setUp(self): + self.coordinator_controller = TransferCoordinatorController() + + def sleep_then_announce_done(self, transfer_coordinator, sleep_time): + time.sleep(sleep_time) + transfer_coordinator.set_result('done') + transfer_coordinator.announce_done() + + def assert_coordinator_is_cancelled(self, transfer_coordinator): + self.assertEqual(transfer_coordinator.status, 'cancelled') + + def test_add_transfer_coordinator(self): + transfer_coordinator = TransferCoordinator() + # Add the transfer coordinator + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Ensure that is tracked. + self.assertEqual( + self.coordinator_controller.tracked_transfer_coordinators, + {transfer_coordinator}, + ) + + def test_remove_transfer_coordinator(self): + transfer_coordinator = TransferCoordinator() + # Add the coordinator + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Now remove the coordinator + self.coordinator_controller.remove_transfer_coordinator( + transfer_coordinator + ) + # Make sure that it is no longer getting tracked. + self.assertEqual( + self.coordinator_controller.tracked_transfer_coordinators, set() + ) + + def test_cancel(self): + transfer_coordinator = TransferCoordinator() + # Add the transfer coordinator + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + # Cancel with the canceler + self.coordinator_controller.cancel() + # Check that coordinator got canceled + self.assert_coordinator_is_cancelled(transfer_coordinator) + + def test_cancel_with_message(self): + message = 'my cancel message' + transfer_coordinator = TransferCoordinator() + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + self.coordinator_controller.cancel(message) + transfer_coordinator.announce_done() + with self.assertRaisesRegex(CancelledError, message): + transfer_coordinator.result() + + def test_cancel_with_provided_exception(self): + message = 'my cancel message' + transfer_coordinator = TransferCoordinator() + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + self.coordinator_controller.cancel(message, exc_type=FatalError) + transfer_coordinator.announce_done() + with self.assertRaisesRegex(FatalError, message): + transfer_coordinator.result() + + def test_wait_for_done_transfer_coordinators(self): + # Create a coordinator and add it to the canceler + transfer_coordinator = TransferCoordinator() + self.coordinator_controller.add_transfer_coordinator( + transfer_coordinator + ) + + sleep_time = 0.02 + with ThreadPoolExecutor(max_workers=1) as executor: + # In a separate thread sleep and then set the transfer coordinator + # to done after sleeping. + start_time = time.time() + executor.submit( + self.sleep_then_announce_done, transfer_coordinator, sleep_time + ) + # Now call wait to wait for the transfer coordinator to be done. 
+ self.coordinator_controller.wait() + end_time = time.time() + wait_time = end_time - start_time + # The time waited should not be less than the time it took to sleep in + # the separate thread because the wait ending should be dependent on + # the sleeping thread announcing that the transfer coordinator is done. + self.assertTrue(sleep_time <= wait_time) + + def test_wait_does_not_propogate_exceptions_from_result(self): + transfer_coordinator = TransferCoordinator() + transfer_coordinator.set_exception(FutureResultException()) + transfer_coordinator.announce_done() + try: + self.coordinator_controller.wait() + except FutureResultException as e: + self.fail('%s should not have been raised.' % e) + + def test_wait_can_be_interrupted(self): + inject_interrupt_coordinator = TransferCoordinatorWithInterrupt() + self.coordinator_controller.add_transfer_coordinator( + inject_interrupt_coordinator + ) + with self.assertRaises(KeyboardInterrupt): + self.coordinator_controller.wait() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_processpool.py b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py index 9a6f63b638..d77b5e0240 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_processpool.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py @@ -1,728 +1,728 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-import os -import queue -import signal -import threading -import time -from io import BytesIO - -from botocore.client import BaseClient -from botocore.config import Config -from botocore.exceptions import ClientError, ReadTimeoutError - -from s3transfer.constants import PROCESS_USER_AGENT -from s3transfer.exceptions import CancelledError, RetriesExceededError -from s3transfer.processpool import ( - SHUTDOWN_SIGNAL, - ClientFactory, - DownloadFileRequest, - GetObjectJob, - GetObjectSubmitter, - GetObjectWorker, - ProcessPoolDownloader, - ProcessPoolTransferFuture, - ProcessPoolTransferMeta, - ProcessTransferConfig, - TransferMonitor, - TransferState, - ignore_ctrl_c, -) -from s3transfer.utils import CallArgs, OSUtils -from __tests__ import ( - FileCreator, - StreamWithError, - StubbedClientTest, - mock, - skip_if_windows, - unittest, -) - - -class RenameFailingOSUtils(OSUtils): - def __init__(self, exception): - self.exception = exception - - def rename_file(self, current_filename, new_filename): - raise self.exception - - -class TestIgnoreCtrlC(unittest.TestCase): - @skip_if_windows('os.kill() with SIGINT not supported on Windows') - def test_ignore_ctrl_c(self): - with ignore_ctrl_c(): - try: - os.kill(os.getpid(), signal.SIGINT) - except KeyboardInterrupt: - self.fail( - 'The ignore_ctrl_c context manager should have ' - 'ignored the KeyboardInterrupt exception' - ) - - -class TestProcessPoolDownloader(unittest.TestCase): - def test_uses_client_kwargs(self): - with mock.patch('s3transfer.processpool.ClientFactory') as factory: - ProcessPoolDownloader(client_kwargs={'region_name': 'myregion'}) - self.assertEqual( - factory.call_args[0][0], {'region_name': 'myregion'} - ) - - -class TestProcessPoolTransferFuture(unittest.TestCase): - def setUp(self): - self.monitor = TransferMonitor() - self.transfer_id = self.monitor.notify_new_transfer() - self.meta = ProcessPoolTransferMeta( - transfer_id=self.transfer_id, call_args=CallArgs() - ) - self.future = ProcessPoolTransferFuture( - monitor=self.monitor, meta=self.meta - ) - - def test_meta(self): - self.assertEqual(self.future.meta, self.meta) - - def test_done(self): - self.assertFalse(self.future.done()) - self.monitor.notify_done(self.transfer_id) - self.assertTrue(self.future.done()) - - def test_result(self): - self.monitor.notify_done(self.transfer_id) - self.assertIsNone(self.future.result()) - - def test_result_with_exception(self): - self.monitor.notify_exception(self.transfer_id, RuntimeError()) - self.monitor.notify_done(self.transfer_id) - with self.assertRaises(RuntimeError): - self.future.result() - - def test_result_with_keyboard_interrupt(self): - mock_monitor = mock.Mock(TransferMonitor) - mock_monitor._connect = mock.Mock() - mock_monitor.poll_for_result.side_effect = KeyboardInterrupt() - future = ProcessPoolTransferFuture( - monitor=mock_monitor, meta=self.meta - ) - with self.assertRaises(KeyboardInterrupt): - future.result() - self.assertTrue(mock_monitor._connect.called) - self.assertTrue(mock_monitor.notify_exception.called) - call_args = mock_monitor.notify_exception.call_args[0] - self.assertEqual(call_args[0], self.transfer_id) - self.assertIsInstance(call_args[1], CancelledError) - - def test_cancel(self): - self.future.cancel() - self.monitor.notify_done(self.transfer_id) - with self.assertRaises(CancelledError): - self.future.result() - - -class TestProcessPoolTransferMeta(unittest.TestCase): - def test_transfer_id(self): - meta = ProcessPoolTransferMeta(1, CallArgs()) - self.assertEqual(meta.transfer_id, 1) - - 
def test_call_args(self): - call_args = CallArgs() - meta = ProcessPoolTransferMeta(1, call_args) - self.assertEqual(meta.call_args, call_args) - - def test_user_context(self): - meta = ProcessPoolTransferMeta(1, CallArgs()) - self.assertEqual(meta.user_context, {}) - meta.user_context['mykey'] = 'myvalue' - self.assertEqual(meta.user_context, {'mykey': 'myvalue'}) - - -class TestClientFactory(unittest.TestCase): - def test_create_client(self): - client = ClientFactory().create_client() - self.assertIsInstance(client, BaseClient) - self.assertEqual(client.meta.service_model.service_name, 's3') - self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) - - def test_create_client_with_client_kwargs(self): - client = ClientFactory({'region_name': 'myregion'}).create_client() - self.assertEqual(client.meta.region_name, 'myregion') - - def test_user_agent_with_config(self): - client = ClientFactory({'config': Config()}).create_client() - self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) - - def test_user_agent_with_existing_user_agent_extra(self): - config = Config(user_agent_extra='foo/1.0') - client = ClientFactory({'config': config}).create_client() - self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) - - def test_user_agent_with_existing_user_agent(self): - config = Config(user_agent='foo/1.0') - client = ClientFactory({'config': config}).create_client() - self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) - - -class TestTransferMonitor(unittest.TestCase): - def setUp(self): - self.monitor = TransferMonitor() - self.transfer_id = self.monitor.notify_new_transfer() - - def test_notify_new_transfer_creates_new_state(self): - monitor = TransferMonitor() - transfer_id = monitor.notify_new_transfer() - self.assertFalse(monitor.is_done(transfer_id)) - self.assertIsNone(monitor.get_exception(transfer_id)) - - def test_notify_new_transfer_increments_transfer_id(self): - monitor = TransferMonitor() - self.assertEqual(monitor.notify_new_transfer(), 0) - self.assertEqual(monitor.notify_new_transfer(), 1) - - def test_notify_get_exception(self): - exception = Exception() - self.monitor.notify_exception(self.transfer_id, exception) - self.assertEqual( - self.monitor.get_exception(self.transfer_id), exception - ) - - def test_get_no_exception(self): - self.assertIsNone(self.monitor.get_exception(self.transfer_id)) - - def test_notify_jobs(self): - self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2) - self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1) - self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 0) - - def test_notify_jobs_for_multiple_transfers(self): - self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2) - other_transfer_id = self.monitor.notify_new_transfer() - self.monitor.notify_expected_jobs_to_complete(other_transfer_id, 2) - self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1) - self.assertEqual( - self.monitor.notify_job_complete(other_transfer_id), 1 - ) - - def test_done(self): - self.assertFalse(self.monitor.is_done(self.transfer_id)) - self.monitor.notify_done(self.transfer_id) - self.assertTrue(self.monitor.is_done(self.transfer_id)) - - def test_poll_for_result(self): - self.monitor.notify_done(self.transfer_id) - self.assertIsNone(self.monitor.poll_for_result(self.transfer_id)) - - def test_poll_for_result_raises_error(self): - self.monitor.notify_exception(self.transfer_id, RuntimeError()) - self.monitor.notify_done(self.transfer_id) - with 
self.assertRaises(RuntimeError): - self.monitor.poll_for_result(self.transfer_id) - - def test_poll_for_result_waits_till_done(self): - event_order = [] - - def sleep_then_notify_done(): - time.sleep(0.05) - event_order.append('notify_done') - self.monitor.notify_done(self.transfer_id) - - t = threading.Thread(target=sleep_then_notify_done) - t.start() - - self.monitor.poll_for_result(self.transfer_id) - event_order.append('done_polling') - self.assertEqual(event_order, ['notify_done', 'done_polling']) - - def test_notify_cancel_all_in_progress(self): - monitor = TransferMonitor() - transfer_ids = [] - for _ in range(10): - transfer_ids.append(monitor.notify_new_transfer()) - monitor.notify_cancel_all_in_progress() - for transfer_id in transfer_ids: - self.assertIsInstance( - monitor.get_exception(transfer_id), CancelledError - ) - # Cancelling a transfer does not mean it is done as there may - # be cleanup work left to do. - self.assertFalse(monitor.is_done(transfer_id)) - - def test_notify_cancel_does_not_affect_done_transfers(self): - self.monitor.notify_done(self.transfer_id) - self.monitor.notify_cancel_all_in_progress() - self.assertTrue(self.monitor.is_done(self.transfer_id)) - self.assertIsNone(self.monitor.get_exception(self.transfer_id)) - - -class TestTransferState(unittest.TestCase): - def setUp(self): - self.state = TransferState() - - def test_done(self): - self.assertFalse(self.state.done) - self.state.set_done() - self.assertTrue(self.state.done) - - def test_waits_till_done_is_set(self): - event_order = [] - - def sleep_then_set_done(): - time.sleep(0.05) - event_order.append('set_done') - self.state.set_done() - - t = threading.Thread(target=sleep_then_set_done) - t.start() - - self.state.wait_till_done() - event_order.append('done_waiting') - self.assertEqual(event_order, ['set_done', 'done_waiting']) - - def test_exception(self): - exception = RuntimeError() - self.state.exception = exception - self.assertEqual(self.state.exception, exception) - - def test_jobs_to_complete(self): - self.state.jobs_to_complete = 5 - self.assertEqual(self.state.jobs_to_complete, 5) - - def test_decrement_jobs_to_complete(self): - self.state.jobs_to_complete = 5 - self.assertEqual(self.state.decrement_jobs_to_complete(), 4) - - -class TestGetObjectSubmitter(StubbedClientTest): - def setUp(self): - super().setUp() - self.transfer_config = ProcessTransferConfig() - self.client_factory = mock.Mock(ClientFactory) - self.client_factory.create_client.return_value = self.client - self.transfer_monitor = TransferMonitor() - self.osutil = mock.Mock(OSUtils) - self.download_request_queue = queue.Queue() - self.worker_queue = queue.Queue() - self.submitter = GetObjectSubmitter( - transfer_config=self.transfer_config, - client_factory=self.client_factory, - transfer_monitor=self.transfer_monitor, - osutil=self.osutil, - download_request_queue=self.download_request_queue, - worker_queue=self.worker_queue, - ) - self.transfer_id = self.transfer_monitor.notify_new_transfer() - self.bucket = 'bucket' - self.key = 'key' - self.filename = 'myfile' - self.temp_filename = 'myfile.temp' - self.osutil.get_temp_filename.return_value = self.temp_filename - self.extra_args = {} - self.expected_size = None - - def add_download_file_request(self, **override_kwargs): - kwargs = { - 'transfer_id': self.transfer_id, - 'bucket': self.bucket, - 'key': self.key, - 'filename': self.filename, - 'extra_args': self.extra_args, - 'expected_size': self.expected_size, - } - kwargs.update(override_kwargs) - 
self.download_request_queue.put(DownloadFileRequest(**kwargs)) - - def add_shutdown(self): - self.download_request_queue.put(SHUTDOWN_SIGNAL) - - def assert_submitted_get_object_jobs(self, expected_jobs): - actual_jobs = [] - while not self.worker_queue.empty(): - actual_jobs.append(self.worker_queue.get()) - self.assertEqual(actual_jobs, expected_jobs) - - def test_run_for_non_ranged_download(self): - self.add_download_file_request(expected_size=1) - self.add_shutdown() - self.submitter.run() - self.osutil.allocate.assert_called_with(self.temp_filename, 1) - self.assert_submitted_get_object_jobs( - [ - GetObjectJob( - transfer_id=self.transfer_id, - bucket=self.bucket, - key=self.key, - temp_filename=self.temp_filename, - offset=0, - extra_args={}, - filename=self.filename, - ) - ] - ) - - def test_run_for_ranged_download(self): - self.transfer_config.multipart_chunksize = 2 - self.transfer_config.multipart_threshold = 4 - self.add_download_file_request(expected_size=4) - self.add_shutdown() - self.submitter.run() - self.osutil.allocate.assert_called_with(self.temp_filename, 4) - self.assert_submitted_get_object_jobs( - [ - GetObjectJob( - transfer_id=self.transfer_id, - bucket=self.bucket, - key=self.key, - temp_filename=self.temp_filename, - offset=0, - extra_args={'Range': 'bytes=0-1'}, - filename=self.filename, - ), - GetObjectJob( - transfer_id=self.transfer_id, - bucket=self.bucket, - key=self.key, - temp_filename=self.temp_filename, - offset=2, - extra_args={'Range': 'bytes=2-'}, - filename=self.filename, - ), - ] - ) - - def test_run_when_expected_size_not_provided(self): - self.stubber.add_response( - 'head_object', - {'ContentLength': 1}, - expected_params={'Bucket': self.bucket, 'Key': self.key}, - ) - self.add_download_file_request(expected_size=None) - self.add_shutdown() - self.submitter.run() - self.stubber.assert_no_pending_responses() - self.osutil.allocate.assert_called_with(self.temp_filename, 1) - self.assert_submitted_get_object_jobs( - [ - GetObjectJob( - transfer_id=self.transfer_id, - bucket=self.bucket, - key=self.key, - temp_filename=self.temp_filename, - offset=0, - extra_args={}, - filename=self.filename, - ) - ] - ) - - def test_run_with_extra_args(self): - self.stubber.add_response( - 'head_object', - {'ContentLength': 1}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'VersionId': 'versionid', - }, - ) - self.add_download_file_request( - extra_args={'VersionId': 'versionid'}, expected_size=None - ) - self.add_shutdown() - self.submitter.run() - self.stubber.assert_no_pending_responses() - self.osutil.allocate.assert_called_with(self.temp_filename, 1) - self.assert_submitted_get_object_jobs( - [ - GetObjectJob( - transfer_id=self.transfer_id, - bucket=self.bucket, - key=self.key, - temp_filename=self.temp_filename, - offset=0, - extra_args={'VersionId': 'versionid'}, - filename=self.filename, - ) - ] - ) - - def test_run_with_exception(self): - self.stubber.add_client_error('head_object', 'NoSuchKey', 404) - self.add_download_file_request(expected_size=None) - self.add_shutdown() - self.submitter.run() - self.stubber.assert_no_pending_responses() - self.assert_submitted_get_object_jobs([]) - self.assertIsInstance( - self.transfer_monitor.get_exception(self.transfer_id), ClientError - ) - - def test_run_with_error_in_allocating_temp_file(self): - self.osutil.allocate.side_effect = OSError() - self.add_download_file_request(expected_size=1) - self.add_shutdown() - self.submitter.run() - self.assert_submitted_get_object_jobs([]) - 
self.assertIsInstance( - self.transfer_monitor.get_exception(self.transfer_id), OSError - ) - - @skip_if_windows('os.kill() with SIGINT not supported on Windows') - def test_submitter_cannot_be_killed(self): - self.add_download_file_request(expected_size=None) - self.add_shutdown() - - def raise_ctrl_c(**kwargs): - os.kill(os.getpid(), signal.SIGINT) - - mock_client = mock.Mock() - mock_client.head_object = raise_ctrl_c - self.client_factory.create_client.return_value = mock_client - - try: - self.submitter.run() - except KeyboardInterrupt: - self.fail( - 'The submitter should have not been killed by the ' - 'KeyboardInterrupt' - ) - - -class TestGetObjectWorker(StubbedClientTest): - def setUp(self): - super().setUp() - self.files = FileCreator() - self.queue = queue.Queue() - self.client_factory = mock.Mock(ClientFactory) - self.client_factory.create_client.return_value = self.client - self.transfer_monitor = TransferMonitor() - self.osutil = OSUtils() - self.worker = GetObjectWorker( - queue=self.queue, - client_factory=self.client_factory, - transfer_monitor=self.transfer_monitor, - osutil=self.osutil, - ) - self.transfer_id = self.transfer_monitor.notify_new_transfer() - self.bucket = 'bucket' - self.key = 'key' - self.remote_contents = b'my content' - self.temp_filename = self.files.create_file('tempfile', '') - self.extra_args = {} - self.offset = 0 - self.final_filename = self.files.full_path('final_filename') - self.stream = BytesIO(self.remote_contents) - self.transfer_monitor.notify_expected_jobs_to_complete( - self.transfer_id, 1000 - ) - - def tearDown(self): - super().tearDown() - self.files.remove_all() - - def add_get_object_job(self, **override_kwargs): - kwargs = { - 'transfer_id': self.transfer_id, - 'bucket': self.bucket, - 'key': self.key, - 'temp_filename': self.temp_filename, - 'extra_args': self.extra_args, - 'offset': self.offset, - 'filename': self.final_filename, - } - kwargs.update(override_kwargs) - self.queue.put(GetObjectJob(**kwargs)) - - def add_shutdown(self): - self.queue.put(SHUTDOWN_SIGNAL) - - def add_stubbed_get_object_response(self, body=None, expected_params=None): - if body is None: - body = self.stream - get_object_response = {'Body': body} - - if expected_params is None: - expected_params = {'Bucket': self.bucket, 'Key': self.key} - - self.stubber.add_response( - 'get_object', get_object_response, expected_params - ) - - def assert_contents(self, filename, contents): - self.assertTrue(os.path.exists(filename)) - with open(filename, 'rb') as f: - self.assertEqual(f.read(), contents) - - def assert_does_not_exist(self, filename): - self.assertFalse(os.path.exists(filename)) - - def test_run_is_final_job(self): - self.add_get_object_job() - self.add_shutdown() - self.add_stubbed_get_object_response() - self.transfer_monitor.notify_expected_jobs_to_complete( - self.transfer_id, 1 - ) - - self.worker.run() - self.stubber.assert_no_pending_responses() - self.assert_does_not_exist(self.temp_filename) - self.assert_contents(self.final_filename, self.remote_contents) - - def test_run_jobs_is_not_final_job(self): - self.add_get_object_job() - self.add_shutdown() - self.add_stubbed_get_object_response() - self.transfer_monitor.notify_expected_jobs_to_complete( - self.transfer_id, 1000 - ) - - self.worker.run() - self.stubber.assert_no_pending_responses() - self.assert_contents(self.temp_filename, self.remote_contents) - self.assert_does_not_exist(self.final_filename) - - def test_run_with_extra_args(self): - self.add_get_object_job(extra_args={'VersionId': 
'versionid'}) - self.add_shutdown() - self.add_stubbed_get_object_response( - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'VersionId': 'versionid', - } - ) - - self.worker.run() - self.stubber.assert_no_pending_responses() - - def test_run_with_offset(self): - offset = 1 - self.add_get_object_job(offset=offset) - self.add_shutdown() - self.add_stubbed_get_object_response() - - self.worker.run() - with open(self.temp_filename, 'rb') as f: - f.seek(offset) - self.assertEqual(f.read(), self.remote_contents) - - def test_run_error_in_get_object(self): - self.add_get_object_job() - self.add_shutdown() - self.stubber.add_client_error('get_object', 'NoSuchKey', 404) - self.add_stubbed_get_object_response() - - self.worker.run() - self.assertIsInstance( - self.transfer_monitor.get_exception(self.transfer_id), ClientError - ) - - def test_run_does_retries_for_get_object(self): - self.add_get_object_job() - self.add_shutdown() - self.add_stubbed_get_object_response( - body=StreamWithError( - self.stream, ReadTimeoutError(endpoint_url='') - ) - ) - self.add_stubbed_get_object_response() - - self.worker.run() - self.stubber.assert_no_pending_responses() - self.assert_contents(self.temp_filename, self.remote_contents) - - def test_run_can_exhaust_retries_for_get_object(self): - self.add_get_object_job() - self.add_shutdown() - # 5 is the current setting for max number of GetObject attempts - for _ in range(5): - self.add_stubbed_get_object_response( - body=StreamWithError( - self.stream, ReadTimeoutError(endpoint_url='') - ) - ) - - self.worker.run() - self.stubber.assert_no_pending_responses() - self.assertIsInstance( - self.transfer_monitor.get_exception(self.transfer_id), - RetriesExceededError, - ) - - def test_run_skips_get_object_on_previous_exception(self): - self.add_get_object_job() - self.add_shutdown() - self.transfer_monitor.notify_exception(self.transfer_id, Exception()) - - self.worker.run() - # Note we did not add a stubbed response for get_object - self.stubber.assert_no_pending_responses() - - def test_run_final_job_removes_file_on_previous_exception(self): - self.add_get_object_job() - self.add_shutdown() - self.transfer_monitor.notify_exception(self.transfer_id, Exception()) - self.transfer_monitor.notify_expected_jobs_to_complete( - self.transfer_id, 1 - ) - - self.worker.run() - self.stubber.assert_no_pending_responses() - self.assert_does_not_exist(self.temp_filename) - self.assert_does_not_exist(self.final_filename) - - def test_run_fails_to_rename_file(self): - exception = OSError() - osutil = RenameFailingOSUtils(exception) - self.worker = GetObjectWorker( - queue=self.queue, - client_factory=self.client_factory, - transfer_monitor=self.transfer_monitor, - osutil=osutil, - ) - self.add_get_object_job() - self.add_shutdown() - self.add_stubbed_get_object_response() - self.transfer_monitor.notify_expected_jobs_to_complete( - self.transfer_id, 1 - ) - - self.worker.run() - self.assertEqual( - self.transfer_monitor.get_exception(self.transfer_id), exception - ) - self.assert_does_not_exist(self.temp_filename) - self.assert_does_not_exist(self.final_filename) - - @skip_if_windows('os.kill() with SIGINT not supported on Windows') - def test_worker_cannot_be_killed(self): - self.add_get_object_job() - self.add_shutdown() - self.transfer_monitor.notify_expected_jobs_to_complete( - self.transfer_id, 1 - ) - - def raise_ctrl_c(**kwargs): - os.kill(os.getpid(), signal.SIGINT) - - mock_client = mock.Mock() - mock_client.get_object = raise_ctrl_c - 
self.client_factory.create_client.return_value = mock_client - - try: - self.worker.run() - except KeyboardInterrupt: - self.fail( - 'The worker should have not been killed by the ' - 'KeyboardInterrupt' - ) +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import os +import queue +import signal +import threading +import time +from io import BytesIO + +from botocore.client import BaseClient +from botocore.config import Config +from botocore.exceptions import ClientError, ReadTimeoutError + +from s3transfer.constants import PROCESS_USER_AGENT +from s3transfer.exceptions import CancelledError, RetriesExceededError +from s3transfer.processpool import ( + SHUTDOWN_SIGNAL, + ClientFactory, + DownloadFileRequest, + GetObjectJob, + GetObjectSubmitter, + GetObjectWorker, + ProcessPoolDownloader, + ProcessPoolTransferFuture, + ProcessPoolTransferMeta, + ProcessTransferConfig, + TransferMonitor, + TransferState, + ignore_ctrl_c, +) +from s3transfer.utils import CallArgs, OSUtils +from __tests__ import ( + FileCreator, + StreamWithError, + StubbedClientTest, + mock, + skip_if_windows, + unittest, +) + + +class RenameFailingOSUtils(OSUtils): + def __init__(self, exception): + self.exception = exception + + def rename_file(self, current_filename, new_filename): + raise self.exception + + +class TestIgnoreCtrlC(unittest.TestCase): + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_ignore_ctrl_c(self): + with ignore_ctrl_c(): + try: + os.kill(os.getpid(), signal.SIGINT) + except KeyboardInterrupt: + self.fail( + 'The ignore_ctrl_c context manager should have ' + 'ignored the KeyboardInterrupt exception' + ) + + +class TestProcessPoolDownloader(unittest.TestCase): + def test_uses_client_kwargs(self): + with mock.patch('s3transfer.processpool.ClientFactory') as factory: + ProcessPoolDownloader(client_kwargs={'region_name': 'myregion'}) + self.assertEqual( + factory.call_args[0][0], {'region_name': 'myregion'} + ) + + +class TestProcessPoolTransferFuture(unittest.TestCase): + def setUp(self): + self.monitor = TransferMonitor() + self.transfer_id = self.monitor.notify_new_transfer() + self.meta = ProcessPoolTransferMeta( + transfer_id=self.transfer_id, call_args=CallArgs() + ) + self.future = ProcessPoolTransferFuture( + monitor=self.monitor, meta=self.meta + ) + + def test_meta(self): + self.assertEqual(self.future.meta, self.meta) + + def test_done(self): + self.assertFalse(self.future.done()) + self.monitor.notify_done(self.transfer_id) + self.assertTrue(self.future.done()) + + def test_result(self): + self.monitor.notify_done(self.transfer_id) + self.assertIsNone(self.future.result()) + + def test_result_with_exception(self): + self.monitor.notify_exception(self.transfer_id, RuntimeError()) + self.monitor.notify_done(self.transfer_id) + with self.assertRaises(RuntimeError): + self.future.result() + + def test_result_with_keyboard_interrupt(self): + mock_monitor = mock.Mock(TransferMonitor) + mock_monitor._connect = mock.Mock() + 
mock_monitor.poll_for_result.side_effect = KeyboardInterrupt() + future = ProcessPoolTransferFuture( + monitor=mock_monitor, meta=self.meta + ) + with self.assertRaises(KeyboardInterrupt): + future.result() + self.assertTrue(mock_monitor._connect.called) + self.assertTrue(mock_monitor.notify_exception.called) + call_args = mock_monitor.notify_exception.call_args[0] + self.assertEqual(call_args[0], self.transfer_id) + self.assertIsInstance(call_args[1], CancelledError) + + def test_cancel(self): + self.future.cancel() + self.monitor.notify_done(self.transfer_id) + with self.assertRaises(CancelledError): + self.future.result() + + +class TestProcessPoolTransferMeta(unittest.TestCase): + def test_transfer_id(self): + meta = ProcessPoolTransferMeta(1, CallArgs()) + self.assertEqual(meta.transfer_id, 1) + + def test_call_args(self): + call_args = CallArgs() + meta = ProcessPoolTransferMeta(1, call_args) + self.assertEqual(meta.call_args, call_args) + + def test_user_context(self): + meta = ProcessPoolTransferMeta(1, CallArgs()) + self.assertEqual(meta.user_context, {}) + meta.user_context['mykey'] = 'myvalue' + self.assertEqual(meta.user_context, {'mykey': 'myvalue'}) + + +class TestClientFactory(unittest.TestCase): + def test_create_client(self): + client = ClientFactory().create_client() + self.assertIsInstance(client, BaseClient) + self.assertEqual(client.meta.service_model.service_name, 's3') + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_create_client_with_client_kwargs(self): + client = ClientFactory({'region_name': 'myregion'}).create_client() + self.assertEqual(client.meta.region_name, 'myregion') + + def test_user_agent_with_config(self): + client = ClientFactory({'config': Config()}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_user_agent_with_existing_user_agent_extra(self): + config = Config(user_agent_extra='foo/1.0') + client = ClientFactory({'config': config}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + def test_user_agent_with_existing_user_agent(self): + config = Config(user_agent='foo/1.0') + client = ClientFactory({'config': config}).create_client() + self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent) + + +class TestTransferMonitor(unittest.TestCase): + def setUp(self): + self.monitor = TransferMonitor() + self.transfer_id = self.monitor.notify_new_transfer() + + def test_notify_new_transfer_creates_new_state(self): + monitor = TransferMonitor() + transfer_id = monitor.notify_new_transfer() + self.assertFalse(monitor.is_done(transfer_id)) + self.assertIsNone(monitor.get_exception(transfer_id)) + + def test_notify_new_transfer_increments_transfer_id(self): + monitor = TransferMonitor() + self.assertEqual(monitor.notify_new_transfer(), 0) + self.assertEqual(monitor.notify_new_transfer(), 1) + + def test_notify_get_exception(self): + exception = Exception() + self.monitor.notify_exception(self.transfer_id, exception) + self.assertEqual( + self.monitor.get_exception(self.transfer_id), exception + ) + + def test_get_no_exception(self): + self.assertIsNone(self.monitor.get_exception(self.transfer_id)) + + def test_notify_jobs(self): + self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 0) + + def test_notify_jobs_for_multiple_transfers(self): + 
self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2) + other_transfer_id = self.monitor.notify_new_transfer() + self.monitor.notify_expected_jobs_to_complete(other_transfer_id, 2) + self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1) + self.assertEqual( + self.monitor.notify_job_complete(other_transfer_id), 1 + ) + + def test_done(self): + self.assertFalse(self.monitor.is_done(self.transfer_id)) + self.monitor.notify_done(self.transfer_id) + self.assertTrue(self.monitor.is_done(self.transfer_id)) + + def test_poll_for_result(self): + self.monitor.notify_done(self.transfer_id) + self.assertIsNone(self.monitor.poll_for_result(self.transfer_id)) + + def test_poll_for_result_raises_error(self): + self.monitor.notify_exception(self.transfer_id, RuntimeError()) + self.monitor.notify_done(self.transfer_id) + with self.assertRaises(RuntimeError): + self.monitor.poll_for_result(self.transfer_id) + + def test_poll_for_result_waits_till_done(self): + event_order = [] + + def sleep_then_notify_done(): + time.sleep(0.05) + event_order.append('notify_done') + self.monitor.notify_done(self.transfer_id) + + t = threading.Thread(target=sleep_then_notify_done) + t.start() + + self.monitor.poll_for_result(self.transfer_id) + event_order.append('done_polling') + self.assertEqual(event_order, ['notify_done', 'done_polling']) + + def test_notify_cancel_all_in_progress(self): + monitor = TransferMonitor() + transfer_ids = [] + for _ in range(10): + transfer_ids.append(monitor.notify_new_transfer()) + monitor.notify_cancel_all_in_progress() + for transfer_id in transfer_ids: + self.assertIsInstance( + monitor.get_exception(transfer_id), CancelledError + ) + # Cancelling a transfer does not mean it is done as there may + # be cleanup work left to do. 
+ self.assertFalse(monitor.is_done(transfer_id)) + + def test_notify_cancel_does_not_affect_done_transfers(self): + self.monitor.notify_done(self.transfer_id) + self.monitor.notify_cancel_all_in_progress() + self.assertTrue(self.monitor.is_done(self.transfer_id)) + self.assertIsNone(self.monitor.get_exception(self.transfer_id)) + + +class TestTransferState(unittest.TestCase): + def setUp(self): + self.state = TransferState() + + def test_done(self): + self.assertFalse(self.state.done) + self.state.set_done() + self.assertTrue(self.state.done) + + def test_waits_till_done_is_set(self): + event_order = [] + + def sleep_then_set_done(): + time.sleep(0.05) + event_order.append('set_done') + self.state.set_done() + + t = threading.Thread(target=sleep_then_set_done) + t.start() + + self.state.wait_till_done() + event_order.append('done_waiting') + self.assertEqual(event_order, ['set_done', 'done_waiting']) + + def test_exception(self): + exception = RuntimeError() + self.state.exception = exception + self.assertEqual(self.state.exception, exception) + + def test_jobs_to_complete(self): + self.state.jobs_to_complete = 5 + self.assertEqual(self.state.jobs_to_complete, 5) + + def test_decrement_jobs_to_complete(self): + self.state.jobs_to_complete = 5 + self.assertEqual(self.state.decrement_jobs_to_complete(), 4) + + +class TestGetObjectSubmitter(StubbedClientTest): + def setUp(self): + super().setUp() + self.transfer_config = ProcessTransferConfig() + self.client_factory = mock.Mock(ClientFactory) + self.client_factory.create_client.return_value = self.client + self.transfer_monitor = TransferMonitor() + self.osutil = mock.Mock(OSUtils) + self.download_request_queue = queue.Queue() + self.worker_queue = queue.Queue() + self.submitter = GetObjectSubmitter( + transfer_config=self.transfer_config, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=self.osutil, + download_request_queue=self.download_request_queue, + worker_queue=self.worker_queue, + ) + self.transfer_id = self.transfer_monitor.notify_new_transfer() + self.bucket = 'bucket' + self.key = 'key' + self.filename = 'myfile' + self.temp_filename = 'myfile.temp' + self.osutil.get_temp_filename.return_value = self.temp_filename + self.extra_args = {} + self.expected_size = None + + def add_download_file_request(self, **override_kwargs): + kwargs = { + 'transfer_id': self.transfer_id, + 'bucket': self.bucket, + 'key': self.key, + 'filename': self.filename, + 'extra_args': self.extra_args, + 'expected_size': self.expected_size, + } + kwargs.update(override_kwargs) + self.download_request_queue.put(DownloadFileRequest(**kwargs)) + + def add_shutdown(self): + self.download_request_queue.put(SHUTDOWN_SIGNAL) + + def assert_submitted_get_object_jobs(self, expected_jobs): + actual_jobs = [] + while not self.worker_queue.empty(): + actual_jobs.append(self.worker_queue.get()) + self.assertEqual(actual_jobs, expected_jobs) + + def test_run_for_non_ranged_download(self): + self.add_download_file_request(expected_size=1) + self.add_shutdown() + self.submitter.run() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={}, + filename=self.filename, + ) + ] + ) + + def test_run_for_ranged_download(self): + self.transfer_config.multipart_chunksize = 2 + self.transfer_config.multipart_threshold = 4 + 
self.add_download_file_request(expected_size=4) + self.add_shutdown() + self.submitter.run() + self.osutil.allocate.assert_called_with(self.temp_filename, 4) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={'Range': 'bytes=0-1'}, + filename=self.filename, + ), + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=2, + extra_args={'Range': 'bytes=2-'}, + filename=self.filename, + ), + ] + ) + + def test_run_when_expected_size_not_provided(self): + self.stubber.add_response( + 'head_object', + {'ContentLength': 1}, + expected_params={'Bucket': self.bucket, 'Key': self.key}, + ) + self.add_download_file_request(expected_size=None) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={}, + filename=self.filename, + ) + ] + ) + + def test_run_with_extra_args(self): + self.stubber.add_response( + 'head_object', + {'ContentLength': 1}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + }, + ) + self.add_download_file_request( + extra_args={'VersionId': 'versionid'}, expected_size=None + ) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.osutil.allocate.assert_called_with(self.temp_filename, 1) + self.assert_submitted_get_object_jobs( + [ + GetObjectJob( + transfer_id=self.transfer_id, + bucket=self.bucket, + key=self.key, + temp_filename=self.temp_filename, + offset=0, + extra_args={'VersionId': 'versionid'}, + filename=self.filename, + ) + ] + ) + + def test_run_with_exception(self): + self.stubber.add_client_error('head_object', 'NoSuchKey', 404) + self.add_download_file_request(expected_size=None) + self.add_shutdown() + self.submitter.run() + self.stubber.assert_no_pending_responses() + self.assert_submitted_get_object_jobs([]) + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), ClientError + ) + + def test_run_with_error_in_allocating_temp_file(self): + self.osutil.allocate.side_effect = OSError() + self.add_download_file_request(expected_size=1) + self.add_shutdown() + self.submitter.run() + self.assert_submitted_get_object_jobs([]) + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), OSError + ) + + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_submitter_cannot_be_killed(self): + self.add_download_file_request(expected_size=None) + self.add_shutdown() + + def raise_ctrl_c(**kwargs): + os.kill(os.getpid(), signal.SIGINT) + + mock_client = mock.Mock() + mock_client.head_object = raise_ctrl_c + self.client_factory.create_client.return_value = mock_client + + try: + self.submitter.run() + except KeyboardInterrupt: + self.fail( + 'The submitter should have not been killed by the ' + 'KeyboardInterrupt' + ) + + +class TestGetObjectWorker(StubbedClientTest): + def setUp(self): + super().setUp() + self.files = FileCreator() + self.queue = queue.Queue() + self.client_factory = mock.Mock(ClientFactory) + self.client_factory.create_client.return_value = self.client + self.transfer_monitor = TransferMonitor() + 
self.osutil = OSUtils() + self.worker = GetObjectWorker( + queue=self.queue, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=self.osutil, + ) + self.transfer_id = self.transfer_monitor.notify_new_transfer() + self.bucket = 'bucket' + self.key = 'key' + self.remote_contents = b'my content' + self.temp_filename = self.files.create_file('tempfile', '') + self.extra_args = {} + self.offset = 0 + self.final_filename = self.files.full_path('final_filename') + self.stream = BytesIO(self.remote_contents) + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1000 + ) + + def tearDown(self): + super().tearDown() + self.files.remove_all() + + def add_get_object_job(self, **override_kwargs): + kwargs = { + 'transfer_id': self.transfer_id, + 'bucket': self.bucket, + 'key': self.key, + 'temp_filename': self.temp_filename, + 'extra_args': self.extra_args, + 'offset': self.offset, + 'filename': self.final_filename, + } + kwargs.update(override_kwargs) + self.queue.put(GetObjectJob(**kwargs)) + + def add_shutdown(self): + self.queue.put(SHUTDOWN_SIGNAL) + + def add_stubbed_get_object_response(self, body=None, expected_params=None): + if body is None: + body = self.stream + get_object_response = {'Body': body} + + if expected_params is None: + expected_params = {'Bucket': self.bucket, 'Key': self.key} + + self.stubber.add_response( + 'get_object', get_object_response, expected_params + ) + + def assert_contents(self, filename, contents): + self.assertTrue(os.path.exists(filename)) + with open(filename, 'rb') as f: + self.assertEqual(f.read(), contents) + + def assert_does_not_exist(self, filename): + self.assertFalse(os.path.exists(filename)) + + def test_run_is_final_job(self): + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_does_not_exist(self.temp_filename) + self.assert_contents(self.final_filename, self.remote_contents) + + def test_run_jobs_is_not_final_job(self): + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1000 + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_contents(self.temp_filename, self.remote_contents) + self.assert_does_not_exist(self.final_filename) + + def test_run_with_extra_args(self): + self.add_get_object_job(extra_args={'VersionId': 'versionid'}) + self.add_shutdown() + self.add_stubbed_get_object_response( + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'VersionId': 'versionid', + } + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + + def test_run_with_offset(self): + offset = 1 + self.add_get_object_job(offset=offset) + self.add_shutdown() + self.add_stubbed_get_object_response() + + self.worker.run() + with open(self.temp_filename, 'rb') as f: + f.seek(offset) + self.assertEqual(f.read(), self.remote_contents) + + def test_run_error_in_get_object(self): + self.add_get_object_job() + self.add_shutdown() + self.stubber.add_client_error('get_object', 'NoSuchKey', 404) + self.add_stubbed_get_object_response() + + self.worker.run() + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), ClientError + ) + + def test_run_does_retries_for_get_object(self): + self.add_get_object_job() + 
self.add_shutdown() + self.add_stubbed_get_object_response( + body=StreamWithError( + self.stream, ReadTimeoutError(endpoint_url='') + ) + ) + self.add_stubbed_get_object_response() + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_contents(self.temp_filename, self.remote_contents) + + def test_run_can_exhaust_retries_for_get_object(self): + self.add_get_object_job() + self.add_shutdown() + # 5 is the current setting for max number of GetObject attempts + for _ in range(5): + self.add_stubbed_get_object_response( + body=StreamWithError( + self.stream, ReadTimeoutError(endpoint_url='') + ) + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assertIsInstance( + self.transfer_monitor.get_exception(self.transfer_id), + RetriesExceededError, + ) + + def test_run_skips_get_object_on_previous_exception(self): + self.add_get_object_job() + self.add_shutdown() + self.transfer_monitor.notify_exception(self.transfer_id, Exception()) + + self.worker.run() + # Note we did not add a stubbed response for get_object + self.stubber.assert_no_pending_responses() + + def test_run_final_job_removes_file_on_previous_exception(self): + self.add_get_object_job() + self.add_shutdown() + self.transfer_monitor.notify_exception(self.transfer_id, Exception()) + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.stubber.assert_no_pending_responses() + self.assert_does_not_exist(self.temp_filename) + self.assert_does_not_exist(self.final_filename) + + def test_run_fails_to_rename_file(self): + exception = OSError() + osutil = RenameFailingOSUtils(exception) + self.worker = GetObjectWorker( + queue=self.queue, + client_factory=self.client_factory, + transfer_monitor=self.transfer_monitor, + osutil=osutil, + ) + self.add_get_object_job() + self.add_shutdown() + self.add_stubbed_get_object_response() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + self.worker.run() + self.assertEqual( + self.transfer_monitor.get_exception(self.transfer_id), exception + ) + self.assert_does_not_exist(self.temp_filename) + self.assert_does_not_exist(self.final_filename) + + @skip_if_windows('os.kill() with SIGINT not supported on Windows') + def test_worker_cannot_be_killed(self): + self.add_get_object_job() + self.add_shutdown() + self.transfer_monitor.notify_expected_jobs_to_complete( + self.transfer_id, 1 + ) + + def raise_ctrl_c(**kwargs): + os.kill(os.getpid(), signal.SIGINT) + + mock_client = mock.Mock() + mock_client.get_object = raise_ctrl_c + self.client_factory.create_client.return_value = mock_client + + try: + self.worker.run() + except KeyboardInterrupt: + self.fail( + 'The worker should have not been killed by the ' + 'KeyboardInterrupt' + ) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py index 1aababd760..35cf4a22dd 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py @@ -1,780 +1,780 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the 'license' file accompanying this file. 
This file is -# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -import os -import shutil -import socket -import tempfile -from concurrent import futures -from contextlib import closing -from io import BytesIO, StringIO - -from s3transfer import ( - MultipartDownloader, - MultipartUploader, - OSUtils, - QueueShutdownError, - ReadFileChunk, - S3Transfer, - ShutdownQueue, - StreamReaderProgress, - TransferConfig, - disable_upload_callbacks, - enable_upload_callbacks, - random_file_extension, -) -from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError -from __tests__ import mock, unittest - - -class InMemoryOSLayer(OSUtils): - def __init__(self, filemap): - self.filemap = filemap - - def get_file_size(self, filename): - return len(self.filemap[filename]) - - def open_file_chunk_reader(self, filename, start_byte, size, callback): - return closing(BytesIO(self.filemap[filename])) - - def open(self, filename, mode): - if 'wb' in mode: - fileobj = BytesIO() - self.filemap[filename] = fileobj - return closing(fileobj) - else: - return closing(self.filemap[filename]) - - def remove_file(self, filename): - if filename in self.filemap: - del self.filemap[filename] - - def rename_file(self, current_filename, new_filename): - if current_filename in self.filemap: - self.filemap[new_filename] = self.filemap.pop(current_filename) - - -class SequentialExecutor: - def __init__(self, max_workers): - pass - - def __enter__(self): - return self - - def __exit__(self, *args, **kwargs): - pass - - # The real map() interface actually takes *args, but we specifically do - # _not_ use this interface. 
- def map(self, function, args): - results = [] - for arg in args: - results.append(function(arg)) - return results - - def submit(self, function): - future = futures.Future() - future.set_result(function()) - return future - - -class TestOSUtils(unittest.TestCase): - def setUp(self): - self.tempdir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.tempdir) - - def test_get_file_size(self): - with mock.patch('os.path.getsize') as m: - OSUtils().get_file_size('myfile') - m.assert_called_with('myfile') - - def test_open_file_chunk_reader(self): - with mock.patch('s3transfer.ReadFileChunk') as m: - OSUtils().open_file_chunk_reader('myfile', 0, 100, None) - m.from_filename.assert_called_with( - 'myfile', 0, 100, None, enable_callback=False - ) - - def test_open_file(self): - fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w') - self.assertTrue(hasattr(fileobj, 'write')) - - def test_remove_file_ignores_errors(self): - with mock.patch('os.remove') as remove: - remove.side_effect = OSError('fake error') - OSUtils().remove_file('foo') - remove.assert_called_with('foo') - - def test_remove_file_proxies_remove_file(self): - with mock.patch('os.remove') as remove: - OSUtils().remove_file('foo') - remove.assert_called_with('foo') - - def test_rename_file(self): - with mock.patch('s3transfer.compat.rename_file') as rename_file: - OSUtils().rename_file('foo', 'newfoo') - rename_file.assert_called_with('foo', 'newfoo') - - -class TestReadFileChunk(unittest.TestCase): - def setUp(self): - self.tempdir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.tempdir) - - def test_read_entire_chunk(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=0, chunk_size=3 - ) - self.assertEqual(chunk.read(), b'one') - self.assertEqual(chunk.read(), b'') - - def test_read_with_amount_size(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=11, chunk_size=4 - ) - self.assertEqual(chunk.read(1), b'f') - self.assertEqual(chunk.read(1), b'o') - self.assertEqual(chunk.read(1), b'u') - self.assertEqual(chunk.read(1), b'r') - self.assertEqual(chunk.read(1), b'') - - def test_reset_stream_emulation(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=11, chunk_size=4 - ) - self.assertEqual(chunk.read(), b'four') - chunk.seek(0) - self.assertEqual(chunk.read(), b'four') - - def test_read_past_end_of_file(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=36, chunk_size=100000 - ) - self.assertEqual(chunk.read(), b'ten') - self.assertEqual(chunk.read(), b'') - self.assertEqual(len(chunk), 3) - - def test_tell_and_seek(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=36, chunk_size=100000 - ) - self.assertEqual(chunk.tell(), 0) - self.assertEqual(chunk.read(), b'ten') - self.assertEqual(chunk.tell(), 3) - chunk.seek(0) - 
self.assertEqual(chunk.tell(), 0) - - def test_file_chunk_supports_context_manager(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'abc') - with ReadFileChunk.from_filename( - filename, start_byte=0, chunk_size=2 - ) as chunk: - val = chunk.read() - self.assertEqual(val, b'ab') - - def test_iter_is_always_empty(self): - # This tests the workaround for the httplib bug (see - # the source for more info). - filename = os.path.join(self.tempdir, 'foo') - open(filename, 'wb').close() - chunk = ReadFileChunk.from_filename( - filename, start_byte=0, chunk_size=10 - ) - self.assertEqual(list(chunk), []) - - -class TestReadFileChunkWithCallback(TestReadFileChunk): - def setUp(self): - super().setUp() - self.filename = os.path.join(self.tempdir, 'foo') - with open(self.filename, 'wb') as f: - f.write(b'abc') - self.amounts_seen = [] - - def callback(self, amount): - self.amounts_seen.append(amount) - - def test_callback_is_invoked_on_read(self): - chunk = ReadFileChunk.from_filename( - self.filename, start_byte=0, chunk_size=3, callback=self.callback - ) - chunk.read(1) - chunk.read(1) - chunk.read(1) - self.assertEqual(self.amounts_seen, [1, 1, 1]) - - def test_callback_can_be_disabled(self): - chunk = ReadFileChunk.from_filename( - self.filename, start_byte=0, chunk_size=3, callback=self.callback - ) - chunk.disable_callback() - # Now reading from the ReadFileChunk should not invoke - # the callback. - chunk.read() - self.assertEqual(self.amounts_seen, []) - - def test_callback_will_also_be_triggered_by_seek(self): - chunk = ReadFileChunk.from_filename( - self.filename, start_byte=0, chunk_size=3, callback=self.callback - ) - chunk.read(2) - chunk.seek(0) - chunk.read(2) - chunk.seek(1) - chunk.read(2) - self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2]) - - -class TestStreamReaderProgress(unittest.TestCase): - def test_proxies_to_wrapped_stream(self): - original_stream = StringIO('foobarbaz') - wrapped = StreamReaderProgress(original_stream) - self.assertEqual(wrapped.read(), 'foobarbaz') - - def test_callback_invoked(self): - amounts_seen = [] - - def callback(amount): - amounts_seen.append(amount) - - original_stream = StringIO('foobarbaz') - wrapped = StreamReaderProgress(original_stream, callback) - self.assertEqual(wrapped.read(), 'foobarbaz') - self.assertEqual(amounts_seen, [9]) - - -class TestMultipartUploader(unittest.TestCase): - def test_multipart_upload_uses_correct_client_calls(self): - client = mock.Mock() - uploader = MultipartUploader( - client, - TransferConfig(), - InMemoryOSLayer({'filename': b'foobar'}), - SequentialExecutor, - ) - client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} - client.upload_part.return_value = {'ETag': 'first'} - - uploader.upload_file('filename', 'bucket', 'key', None, {}) - - # We need to check both the sequence of calls (create/upload/complete) - # as well as the params passed between the calls, including - # 1. The upload_id was plumbed through - # 2. The collected etags were added to the complete call. - client.create_multipart_upload.assert_called_with( - Bucket='bucket', Key='key' - ) - # Should be two parts. 
- client.upload_part.assert_called_with( - Body=mock.ANY, - Bucket='bucket', - UploadId='upload_id', - Key='key', - PartNumber=1, - ) - client.complete_multipart_upload.assert_called_with( - MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]}, - Bucket='bucket', - UploadId='upload_id', - Key='key', - ) - - def test_multipart_upload_injects_proper_kwargs(self): - client = mock.Mock() - uploader = MultipartUploader( - client, - TransferConfig(), - InMemoryOSLayer({'filename': b'foobar'}), - SequentialExecutor, - ) - client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} - client.upload_part.return_value = {'ETag': 'first'} - - extra_args = { - 'SSECustomerKey': 'fakekey', - 'SSECustomerAlgorithm': 'AES256', - 'StorageClass': 'REDUCED_REDUNDANCY', - } - uploader.upload_file('filename', 'bucket', 'key', None, extra_args) - - client.create_multipart_upload.assert_called_with( - Bucket='bucket', - Key='key', - # The initial call should inject all the storage class params. - SSECustomerKey='fakekey', - SSECustomerAlgorithm='AES256', - StorageClass='REDUCED_REDUNDANCY', - ) - client.upload_part.assert_called_with( - Body=mock.ANY, - Bucket='bucket', - UploadId='upload_id', - Key='key', - PartNumber=1, - # We only have to forward certain **extra_args in subsequent - # UploadPart calls. - SSECustomerKey='fakekey', - SSECustomerAlgorithm='AES256', - ) - client.complete_multipart_upload.assert_called_with( - MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]}, - Bucket='bucket', - UploadId='upload_id', - Key='key', - ) - - def test_multipart_upload_is_aborted_on_error(self): - # If the create_multipart_upload succeeds and any upload_part - # fails, then abort_multipart_upload will be called. - client = mock.Mock() - uploader = MultipartUploader( - client, - TransferConfig(), - InMemoryOSLayer({'filename': b'foobar'}), - SequentialExecutor, - ) - client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} - client.upload_part.side_effect = Exception( - "Some kind of error occurred." - ) - - with self.assertRaises(S3UploadFailedError): - uploader.upload_file('filename', 'bucket', 'key', None, {}) - - client.abort_multipart_upload.assert_called_with( - Bucket='bucket', Key='key', UploadId='upload_id' - ) - - -class TestMultipartDownloader(unittest.TestCase): - - maxDiff = None - - def test_multipart_download_uses_correct_client_calls(self): - client = mock.Mock() - response_body = b'foobarbaz' - client.get_object.return_value = {'Body': BytesIO(response_body)} - - downloader = MultipartDownloader( - client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor - ) - downloader.download_file( - 'bucket', 'key', 'filename', len(response_body), {} - ) - - client.get_object.assert_called_with( - Range='bytes=0-', Bucket='bucket', Key='key' - ) - - def test_multipart_download_with_multiple_parts(self): - client = mock.Mock() - response_body = b'foobarbaz' - client.get_object.return_value = {'Body': BytesIO(response_body)} - # For testing purposes, we're testing with a multipart threshold - # of 4 bytes and a chunksize of 4 bytes. Given b'foobarbaz', - # this should result in 3 calls. In python slices this would be: - # r[0:4], r[4:8], r[8:9]. But the Range param will be slightly - # different because they use inclusive ranges. 
- config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) - - downloader = MultipartDownloader( - client, config, InMemoryOSLayer({}), SequentialExecutor - ) - downloader.download_file( - 'bucket', 'key', 'filename', len(response_body), {} - ) - - # We're storing these in **extra because the assertEqual - # below is really about verifying we have the correct value - # for the Range param. - extra = {'Bucket': 'bucket', 'Key': 'key'} - self.assertEqual( - client.get_object.call_args_list, - # Note these are inclusive ranges. - [ - mock.call(Range='bytes=0-3', **extra), - mock.call(Range='bytes=4-7', **extra), - mock.call(Range='bytes=8-', **extra), - ], - ) - - def test_retry_on_failures_from_stream_reads(self): - # If we get an exception during a call to the response body's .read() - # method, we should retry the request. - client = mock.Mock() - response_body = b'foobarbaz' - stream_with_errors = mock.Mock() - stream_with_errors.read.side_effect = [ - socket.error("fake error"), - response_body, - ] - client.get_object.return_value = {'Body': stream_with_errors} - config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) - - downloader = MultipartDownloader( - client, config, InMemoryOSLayer({}), SequentialExecutor - ) - downloader.download_file( - 'bucket', 'key', 'filename', len(response_body), {} - ) - - # We're storing these in **extra because the assertEqual - # below is really about verifying we have the correct value - # for the Range param. - extra = {'Bucket': 'bucket', 'Key': 'key'} - self.assertEqual( - client.get_object.call_args_list, - # The first call to range=0-3 fails because of the - # side_effect above where we make the .read() raise a - # socket.error. - # The second call to range=0-3 then succeeds. - [ - mock.call(Range='bytes=0-3', **extra), - mock.call(Range='bytes=0-3', **extra), - mock.call(Range='bytes=4-7', **extra), - mock.call(Range='bytes=8-', **extra), - ], - ) - - def test_exception_raised_on_exceeded_retries(self): - client = mock.Mock() - response_body = b'foobarbaz' - stream_with_errors = mock.Mock() - stream_with_errors.read.side_effect = socket.error("fake error") - client.get_object.return_value = {'Body': stream_with_errors} - config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) - - downloader = MultipartDownloader( - client, config, InMemoryOSLayer({}), SequentialExecutor - ) - with self.assertRaises(RetriesExceededError): - downloader.download_file( - 'bucket', 'key', 'filename', len(response_body), {} - ) - - def test_io_thread_failure_triggers_shutdown(self): - client = mock.Mock() - response_body = b'foobarbaz' - client.get_object.return_value = {'Body': BytesIO(response_body)} - os_layer = mock.Mock() - mock_fileobj = mock.MagicMock() - mock_fileobj.__enter__.return_value = mock_fileobj - mock_fileobj.write.side_effect = Exception("fake IO error") - os_layer.open.return_value = mock_fileobj - - downloader = MultipartDownloader( - client, TransferConfig(), os_layer, SequentialExecutor - ) - # We're verifying that the exception raised from the IO future - # propagates back up via download_file(). 
- with self.assertRaisesRegex(Exception, "fake IO error"): - downloader.download_file( - 'bucket', 'key', 'filename', len(response_body), {} - ) - - def test_download_futures_fail_triggers_shutdown(self): - class FailedDownloadParts(SequentialExecutor): - def __init__(self, max_workers): - self.is_first = True - - def submit(self, function): - future = futures.Future() - if self.is_first: - # This is the download_parts_thread. - future.set_exception( - Exception("fake download parts error") - ) - self.is_first = False - return future - - client = mock.Mock() - response_body = b'foobarbaz' - client.get_object.return_value = {'Body': BytesIO(response_body)} - - downloader = MultipartDownloader( - client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts - ) - with self.assertRaisesRegex(Exception, "fake download parts error"): - downloader.download_file( - 'bucket', 'key', 'filename', len(response_body), {} - ) - - -class TestS3Transfer(unittest.TestCase): - def setUp(self): - self.client = mock.Mock() - self.random_file_patch = mock.patch('s3transfer.random_file_extension') - self.random_file = self.random_file_patch.start() - self.random_file.return_value = 'RANDOM' - - def tearDown(self): - self.random_file_patch.stop() - - def test_callback_handlers_register_on_put_item(self): - osutil = InMemoryOSLayer({'smallfile': b'foobar'}) - transfer = S3Transfer(self.client, osutil=osutil) - transfer.upload_file('smallfile', 'bucket', 'key') - events = self.client.meta.events - events.register_first.assert_called_with( - 'request-created.s3', - disable_upload_callbacks, - unique_id='s3upload-callback-disable', - ) - events.register_last.assert_called_with( - 'request-created.s3', - enable_upload_callbacks, - unique_id='s3upload-callback-enable', - ) - - def test_upload_below_multipart_threshold_uses_put_object(self): - fake_files = { - 'smallfile': b'foobar', - } - osutil = InMemoryOSLayer(fake_files) - transfer = S3Transfer(self.client, osutil=osutil) - transfer.upload_file('smallfile', 'bucket', 'key') - self.client.put_object.assert_called_with( - Bucket='bucket', Key='key', Body=mock.ANY - ) - - def test_extra_args_on_uploaded_passed_to_api_call(self): - extra_args = {'ACL': 'public-read'} - fake_files = {'smallfile': b'hello world'} - osutil = InMemoryOSLayer(fake_files) - transfer = S3Transfer(self.client, osutil=osutil) - transfer.upload_file( - 'smallfile', 'bucket', 'key', extra_args=extra_args - ) - self.client.put_object.assert_called_with( - Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read' - ) - - def test_uses_multipart_upload_when_over_threshold(self): - with mock.patch('s3transfer.MultipartUploader') as uploader: - fake_files = { - 'smallfile': b'foobar', - } - osutil = InMemoryOSLayer(fake_files) - config = TransferConfig( - multipart_threshold=2, multipart_chunksize=2 - ) - transfer = S3Transfer(self.client, osutil=osutil, config=config) - transfer.upload_file('smallfile', 'bucket', 'key') - - uploader.return_value.upload_file.assert_called_with( - 'smallfile', 'bucket', 'key', None, {} - ) - - def test_uses_multipart_download_when_over_threshold(self): - with mock.patch('s3transfer.MultipartDownloader') as downloader: - osutil = InMemoryOSLayer({}) - over_multipart_threshold = 100 * 1024 * 1024 - transfer = S3Transfer(self.client, osutil=osutil) - callback = mock.sentinel.CALLBACK - self.client.head_object.return_value = { - 'ContentLength': over_multipart_threshold, - } - transfer.download_file( - 'bucket', 'key', 'filename', callback=callback - ) - - 
downloader.return_value.download_file.assert_called_with( - # Note how we're downloading to a temporary random file. - 'bucket', - 'key', - 'filename.RANDOM', - over_multipart_threshold, - {}, - callback, - ) - - def test_download_file_with_invalid_extra_args(self): - below_threshold = 20 - osutil = InMemoryOSLayer({}) - transfer = S3Transfer(self.client, osutil=osutil) - self.client.head_object.return_value = { - 'ContentLength': below_threshold - } - with self.assertRaises(ValueError): - transfer.download_file( - 'bucket', - 'key', - '/tmp/smallfile', - extra_args={'BadValue': 'foo'}, - ) - - def test_upload_file_with_invalid_extra_args(self): - osutil = InMemoryOSLayer({}) - transfer = S3Transfer(self.client, osutil=osutil) - bad_args = {"WebsiteRedirectLocation": "/foo"} - with self.assertRaises(ValueError): - transfer.upload_file( - 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args - ) - - def test_download_file_fowards_extra_args(self): - extra_args = { - 'SSECustomerKey': 'foo', - 'SSECustomerAlgorithm': 'AES256', - } - below_threshold = 20 - osutil = InMemoryOSLayer({'smallfile': b'hello world'}) - transfer = S3Transfer(self.client, osutil=osutil) - self.client.head_object.return_value = { - 'ContentLength': below_threshold - } - self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} - transfer.download_file( - 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args - ) - - # Note that we need to invoke the HeadObject call - # and the PutObject call with the extra_args. - # This is necessary. Trying to HeadObject an SSE object - # will return a 400 if you don't provide the required - # params. - self.client.get_object.assert_called_with( - Bucket='bucket', - Key='key', - SSECustomerAlgorithm='AES256', - SSECustomerKey='foo', - ) - - def test_get_object_stream_is_retried_and_succeeds(self): - below_threshold = 20 - osutil = InMemoryOSLayer({'smallfile': b'hello world'}) - transfer = S3Transfer(self.client, osutil=osutil) - self.client.head_object.return_value = { - 'ContentLength': below_threshold - } - self.client.get_object.side_effect = [ - # First request fails. - socket.error("fake error"), - # Second succeeds. - {'Body': BytesIO(b'foobar')}, - ] - transfer.download_file('bucket', 'key', '/tmp/smallfile') - - self.assertEqual(self.client.get_object.call_count, 2) - - def test_get_object_stream_uses_all_retries_and_errors_out(self): - below_threshold = 20 - osutil = InMemoryOSLayer({}) - transfer = S3Transfer(self.client, osutil=osutil) - self.client.head_object.return_value = { - 'ContentLength': below_threshold - } - # Here we're raising an exception every single time, which - # will exhaust our retry count and propagate a - # RetriesExceededError. - self.client.get_object.side_effect = socket.error("fake error") - with self.assertRaises(RetriesExceededError): - transfer.download_file('bucket', 'key', 'smallfile') - - self.assertEqual(self.client.get_object.call_count, 5) - # We should have also cleaned up the in progress file - # we were downloading to. 
- self.assertEqual(osutil.filemap, {}) - - def test_download_below_multipart_threshold(self): - below_threshold = 20 - osutil = InMemoryOSLayer({'smallfile': b'hello world'}) - transfer = S3Transfer(self.client, osutil=osutil) - self.client.head_object.return_value = { - 'ContentLength': below_threshold - } - self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} - transfer.download_file('bucket', 'key', 'smallfile') - - self.client.get_object.assert_called_with(Bucket='bucket', Key='key') - - def test_can_create_with_just_client(self): - transfer = S3Transfer(client=mock.Mock()) - self.assertIsInstance(transfer, S3Transfer) - - -class TestShutdownQueue(unittest.TestCase): - def test_handles_normal_put_get_requests(self): - q = ShutdownQueue() - q.put('foo') - self.assertEqual(q.get(), 'foo') - - def test_put_raises_error_on_shutdown(self): - q = ShutdownQueue() - q.trigger_shutdown() - with self.assertRaises(QueueShutdownError): - q.put('foo') - - -class TestRandomFileExtension(unittest.TestCase): - def test_has_proper_length(self): - self.assertEqual(len(random_file_extension(num_digits=4)), 4) - - -class TestCallbackHandlers(unittest.TestCase): - def setUp(self): - self.request = mock.Mock() - - def test_disable_request_on_put_object(self): - disable_upload_callbacks(self.request, 'PutObject') - self.request.body.disable_callback.assert_called_with() - - def test_disable_request_on_upload_part(self): - disable_upload_callbacks(self.request, 'UploadPart') - self.request.body.disable_callback.assert_called_with() - - def test_enable_object_on_put_object(self): - enable_upload_callbacks(self.request, 'PutObject') - self.request.body.enable_callback.assert_called_with() - - def test_enable_object_on_upload_part(self): - enable_upload_callbacks(self.request, 'UploadPart') - self.request.body.enable_callback.assert_called_with() - - def test_dont_disable_if_missing_interface(self): - del self.request.body.disable_callback - disable_upload_callbacks(self.request, 'PutObject') - self.assertEqual(self.request.body.method_calls, []) - - def test_dont_enable_if_missing_interface(self): - del self.request.body.enable_callback - enable_upload_callbacks(self.request, 'PutObject') - self.assertEqual(self.request.body.method_calls, []) - - def test_dont_disable_if_wrong_operation(self): - disable_upload_callbacks(self.request, 'OtherOperation') - self.assertFalse(self.request.body.disable_callback.called) - - def test_dont_enable_if_wrong_operation(self): - enable_upload_callbacks(self.request, 'OtherOperation') - self.assertFalse(self.request.body.enable_callback.called) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import os +import shutil +import socket +import tempfile +from concurrent import futures +from contextlib import closing +from io import BytesIO, StringIO + +from s3transfer import ( + MultipartDownloader, + MultipartUploader, + OSUtils, + QueueShutdownError, + ReadFileChunk, + S3Transfer, + ShutdownQueue, + StreamReaderProgress, + TransferConfig, + disable_upload_callbacks, + enable_upload_callbacks, + random_file_extension, +) +from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError +from __tests__ import mock, unittest + + +class InMemoryOSLayer(OSUtils): + def __init__(self, filemap): + self.filemap = filemap + + def get_file_size(self, filename): + return len(self.filemap[filename]) + + def open_file_chunk_reader(self, filename, start_byte, size, callback): + return closing(BytesIO(self.filemap[filename])) + + def open(self, filename, mode): + if 'wb' in mode: + fileobj = BytesIO() + self.filemap[filename] = fileobj + return closing(fileobj) + else: + return closing(self.filemap[filename]) + + def remove_file(self, filename): + if filename in self.filemap: + del self.filemap[filename] + + def rename_file(self, current_filename, new_filename): + if current_filename in self.filemap: + self.filemap[new_filename] = self.filemap.pop(current_filename) + + +class SequentialExecutor: + def __init__(self, max_workers): + pass + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + pass + + # The real map() interface actually takes *args, but we specifically do + # _not_ use this interface. + def map(self, function, args): + results = [] + for arg in args: + results.append(function(arg)) + return results + + def submit(self, function): + future = futures.Future() + future.set_result(function()) + return future + + +class TestOSUtils(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_get_file_size(self): + with mock.patch('os.path.getsize') as m: + OSUtils().get_file_size('myfile') + m.assert_called_with('myfile') + + def test_open_file_chunk_reader(self): + with mock.patch('s3transfer.ReadFileChunk') as m: + OSUtils().open_file_chunk_reader('myfile', 0, 100, None) + m.from_filename.assert_called_with( + 'myfile', 0, 100, None, enable_callback=False + ) + + def test_open_file(self): + fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w') + self.assertTrue(hasattr(fileobj, 'write')) + + def test_remove_file_ignores_errors(self): + with mock.patch('os.remove') as remove: + remove.side_effect = OSError('fake error') + OSUtils().remove_file('foo') + remove.assert_called_with('foo') + + def test_remove_file_proxies_remove_file(self): + with mock.patch('os.remove') as remove: + OSUtils().remove_file('foo') + remove.assert_called_with('foo') + + def test_rename_file(self): + with mock.patch('s3transfer.compat.rename_file') as rename_file: + OSUtils().rename_file('foo', 'newfoo') + rename_file.assert_called_with('foo', 'newfoo') + + +class TestReadFileChunk(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def test_read_entire_chunk(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=3 + ) + self.assertEqual(chunk.read(), b'one') + self.assertEqual(chunk.read(), b'') + + def test_read_with_amount_size(self): + 
filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(1), b'f') + self.assertEqual(chunk.read(1), b'o') + self.assertEqual(chunk.read(1), b'u') + self.assertEqual(chunk.read(1), b'r') + self.assertEqual(chunk.read(1), b'') + + def test_reset_stream_emulation(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(), b'four') + chunk.seek(0) + self.assertEqual(chunk.read(), b'four') + + def test_read_past_end_of_file(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.read(), b'') + self.assertEqual(len(chunk), 3) + + def test_tell_and_seek(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.tell(), 0) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.tell(), 3) + chunk.seek(0) + self.assertEqual(chunk.tell(), 0) + + def test_file_chunk_supports_context_manager(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'abc') + with ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=2 + ) as chunk: + val = chunk.read() + self.assertEqual(val, b'ab') + + def test_iter_is_always_empty(self): + # This tests the workaround for the httplib bug (see + # the source for more info). + filename = os.path.join(self.tempdir, 'foo') + open(filename, 'wb').close() + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=10 + ) + self.assertEqual(list(chunk), []) + + +class TestReadFileChunkWithCallback(TestReadFileChunk): + def setUp(self): + super().setUp() + self.filename = os.path.join(self.tempdir, 'foo') + with open(self.filename, 'wb') as f: + f.write(b'abc') + self.amounts_seen = [] + + def callback(self, amount): + self.amounts_seen.append(amount) + + def test_callback_is_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + self.assertEqual(self.amounts_seen, [1, 1, 1]) + + def test_callback_can_be_disabled(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.disable_callback() + # Now reading from the ReadFileChunk should not invoke + # the callback. 
+ chunk.read() + self.assertEqual(self.amounts_seen, []) + + def test_callback_will_also_be_triggered_by_seek(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, callback=self.callback + ) + chunk.read(2) + chunk.seek(0) + chunk.read(2) + chunk.seek(1) + chunk.read(2) + self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2]) + + +class TestStreamReaderProgress(unittest.TestCase): + def test_proxies_to_wrapped_stream(self): + original_stream = StringIO('foobarbaz') + wrapped = StreamReaderProgress(original_stream) + self.assertEqual(wrapped.read(), 'foobarbaz') + + def test_callback_invoked(self): + amounts_seen = [] + + def callback(amount): + amounts_seen.append(amount) + + original_stream = StringIO('foobarbaz') + wrapped = StreamReaderProgress(original_stream, callback) + self.assertEqual(wrapped.read(), 'foobarbaz') + self.assertEqual(amounts_seen, [9]) + + +class TestMultipartUploader(unittest.TestCase): + def test_multipart_upload_uses_correct_client_calls(self): + client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.return_value = {'ETag': 'first'} + + uploader.upload_file('filename', 'bucket', 'key', None, {}) + + # We need to check both the sequence of calls (create/upload/complete) + # as well as the params passed between the calls, including + # 1. The upload_id was plumbed through + # 2. The collected etags were added to the complete call. + client.create_multipart_upload.assert_called_with( + Bucket='bucket', Key='key' + ) + # Should be two parts. + client.upload_part.assert_called_with( + Body=mock.ANY, + Bucket='bucket', + UploadId='upload_id', + Key='key', + PartNumber=1, + ) + client.complete_multipart_upload.assert_called_with( + MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]}, + Bucket='bucket', + UploadId='upload_id', + Key='key', + ) + + def test_multipart_upload_injects_proper_kwargs(self): + client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.return_value = {'ETag': 'first'} + + extra_args = { + 'SSECustomerKey': 'fakekey', + 'SSECustomerAlgorithm': 'AES256', + 'StorageClass': 'REDUCED_REDUNDANCY', + } + uploader.upload_file('filename', 'bucket', 'key', None, extra_args) + + client.create_multipart_upload.assert_called_with( + Bucket='bucket', + Key='key', + # The initial call should inject all the storage class params. + SSECustomerKey='fakekey', + SSECustomerAlgorithm='AES256', + StorageClass='REDUCED_REDUNDANCY', + ) + client.upload_part.assert_called_with( + Body=mock.ANY, + Bucket='bucket', + UploadId='upload_id', + Key='key', + PartNumber=1, + # We only have to forward certain **extra_args in subsequent + # UploadPart calls. + SSECustomerKey='fakekey', + SSECustomerAlgorithm='AES256', + ) + client.complete_multipart_upload.assert_called_with( + MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]}, + Bucket='bucket', + UploadId='upload_id', + Key='key', + ) + + def test_multipart_upload_is_aborted_on_error(self): + # If the create_multipart_upload succeeds and any upload_part + # fails, then abort_multipart_upload will be called. 
+ client = mock.Mock() + uploader = MultipartUploader( + client, + TransferConfig(), + InMemoryOSLayer({'filename': b'foobar'}), + SequentialExecutor, + ) + client.create_multipart_upload.return_value = {'UploadId': 'upload_id'} + client.upload_part.side_effect = Exception( + "Some kind of error occurred." + ) + + with self.assertRaises(S3UploadFailedError): + uploader.upload_file('filename', 'bucket', 'key', None, {}) + + client.abort_multipart_upload.assert_called_with( + Bucket='bucket', Key='key', UploadId='upload_id' + ) + + +class TestMultipartDownloader(unittest.TestCase): + + maxDiff = None + + def test_multipart_download_uses_correct_client_calls(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + + downloader = MultipartDownloader( + client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + client.get_object.assert_called_with( + Range='bytes=0-', Bucket='bucket', Key='key' + ) + + def test_multipart_download_with_multiple_parts(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + # For testing purposes, we're testing with a multipart threshold + # of 4 bytes and a chunksize of 4 bytes. Given b'foobarbaz', + # this should result in 3 calls. In python slices this would be: + # r[0:4], r[4:8], r[8:9]. But the Range param will be slightly + # different because they use inclusive ranges. + config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + # We're storing these in **extra because the assertEqual + # below is really about verifying we have the correct value + # for the Range param. + extra = {'Bucket': 'bucket', 'Key': 'key'} + self.assertEqual( + client.get_object.call_args_list, + # Note these are inclusive ranges. + [ + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=4-7', **extra), + mock.call(Range='bytes=8-', **extra), + ], + ) + + def test_retry_on_failures_from_stream_reads(self): + # If we get an exception during a call to the response body's .read() + # method, we should retry the request. + client = mock.Mock() + response_body = b'foobarbaz' + stream_with_errors = mock.Mock() + stream_with_errors.read.side_effect = [ + socket.error("fake error"), + response_body, + ] + client.get_object.return_value = {'Body': stream_with_errors} + config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + # We're storing these in **extra because the assertEqual + # below is really about verifying we have the correct value + # for the Range param. + extra = {'Bucket': 'bucket', 'Key': 'key'} + self.assertEqual( + client.get_object.call_args_list, + # The first call to range=0-3 fails because of the + # side_effect above where we make the .read() raise a + # socket.error. + # The second call to range=0-3 then succeeds. 
+ [ + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=0-3', **extra), + mock.call(Range='bytes=4-7', **extra), + mock.call(Range='bytes=8-', **extra), + ], + ) + + def test_exception_raised_on_exceeded_retries(self): + client = mock.Mock() + response_body = b'foobarbaz' + stream_with_errors = mock.Mock() + stream_with_errors.read.side_effect = socket.error("fake error") + client.get_object.return_value = {'Body': stream_with_errors} + config = TransferConfig(multipart_threshold=4, multipart_chunksize=4) + + downloader = MultipartDownloader( + client, config, InMemoryOSLayer({}), SequentialExecutor + ) + with self.assertRaises(RetriesExceededError): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + def test_io_thread_failure_triggers_shutdown(self): + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + os_layer = mock.Mock() + mock_fileobj = mock.MagicMock() + mock_fileobj.__enter__.return_value = mock_fileobj + mock_fileobj.write.side_effect = Exception("fake IO error") + os_layer.open.return_value = mock_fileobj + + downloader = MultipartDownloader( + client, TransferConfig(), os_layer, SequentialExecutor + ) + # We're verifying that the exception raised from the IO future + # propagates back up via download_file(). + with self.assertRaisesRegex(Exception, "fake IO error"): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + def test_download_futures_fail_triggers_shutdown(self): + class FailedDownloadParts(SequentialExecutor): + def __init__(self, max_workers): + self.is_first = True + + def submit(self, function): + future = futures.Future() + if self.is_first: + # This is the download_parts_thread. 
+ future.set_exception( + Exception("fake download parts error") + ) + self.is_first = False + return future + + client = mock.Mock() + response_body = b'foobarbaz' + client.get_object.return_value = {'Body': BytesIO(response_body)} + + downloader = MultipartDownloader( + client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts + ) + with self.assertRaisesRegex(Exception, "fake download parts error"): + downloader.download_file( + 'bucket', 'key', 'filename', len(response_body), {} + ) + + +class TestS3Transfer(unittest.TestCase): + def setUp(self): + self.client = mock.Mock() + self.random_file_patch = mock.patch('s3transfer.random_file_extension') + self.random_file = self.random_file_patch.start() + self.random_file.return_value = 'RANDOM' + + def tearDown(self): + self.random_file_patch.stop() + + def test_callback_handlers_register_on_put_item(self): + osutil = InMemoryOSLayer({'smallfile': b'foobar'}) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file('smallfile', 'bucket', 'key') + events = self.client.meta.events + events.register_first.assert_called_with( + 'request-created.s3', + disable_upload_callbacks, + unique_id='s3upload-callback-disable', + ) + events.register_last.assert_called_with( + 'request-created.s3', + enable_upload_callbacks, + unique_id='s3upload-callback-enable', + ) + + def test_upload_below_multipart_threshold_uses_put_object(self): + fake_files = { + 'smallfile': b'foobar', + } + osutil = InMemoryOSLayer(fake_files) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file('smallfile', 'bucket', 'key') + self.client.put_object.assert_called_with( + Bucket='bucket', Key='key', Body=mock.ANY + ) + + def test_extra_args_on_uploaded_passed_to_api_call(self): + extra_args = {'ACL': 'public-read'} + fake_files = {'smallfile': b'hello world'} + osutil = InMemoryOSLayer(fake_files) + transfer = S3Transfer(self.client, osutil=osutil) + transfer.upload_file( + 'smallfile', 'bucket', 'key', extra_args=extra_args + ) + self.client.put_object.assert_called_with( + Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read' + ) + + def test_uses_multipart_upload_when_over_threshold(self): + with mock.patch('s3transfer.MultipartUploader') as uploader: + fake_files = { + 'smallfile': b'foobar', + } + osutil = InMemoryOSLayer(fake_files) + config = TransferConfig( + multipart_threshold=2, multipart_chunksize=2 + ) + transfer = S3Transfer(self.client, osutil=osutil, config=config) + transfer.upload_file('smallfile', 'bucket', 'key') + + uploader.return_value.upload_file.assert_called_with( + 'smallfile', 'bucket', 'key', None, {} + ) + + def test_uses_multipart_download_when_over_threshold(self): + with mock.patch('s3transfer.MultipartDownloader') as downloader: + osutil = InMemoryOSLayer({}) + over_multipart_threshold = 100 * 1024 * 1024 + transfer = S3Transfer(self.client, osutil=osutil) + callback = mock.sentinel.CALLBACK + self.client.head_object.return_value = { + 'ContentLength': over_multipart_threshold, + } + transfer.download_file( + 'bucket', 'key', 'filename', callback=callback + ) + + downloader.return_value.download_file.assert_called_with( + # Note how we're downloading to a temporary random file. 
+ 'bucket', + 'key', + 'filename.RANDOM', + over_multipart_threshold, + {}, + callback, + ) + + def test_download_file_with_invalid_extra_args(self): + below_threshold = 20 + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + with self.assertRaises(ValueError): + transfer.download_file( + 'bucket', + 'key', + '/tmp/smallfile', + extra_args={'BadValue': 'foo'}, + ) + + def test_upload_file_with_invalid_extra_args(self): + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + bad_args = {"WebsiteRedirectLocation": "/foo"} + with self.assertRaises(ValueError): + transfer.upload_file( + 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args + ) + + def test_download_file_fowards_extra_args(self): + extra_args = { + 'SSECustomerKey': 'foo', + 'SSECustomerAlgorithm': 'AES256', + } + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} + transfer.download_file( + 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args + ) + + # Note that we need to invoke the HeadObject call + # and the PutObject call with the extra_args. + # This is necessary. Trying to HeadObject an SSE object + # will return a 400 if you don't provide the required + # params. + self.client.get_object.assert_called_with( + Bucket='bucket', + Key='key', + SSECustomerAlgorithm='AES256', + SSECustomerKey='foo', + ) + + def test_get_object_stream_is_retried_and_succeeds(self): + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.side_effect = [ + # First request fails. + socket.error("fake error"), + # Second succeeds. + {'Body': BytesIO(b'foobar')}, + ] + transfer.download_file('bucket', 'key', '/tmp/smallfile') + + self.assertEqual(self.client.get_object.call_count, 2) + + def test_get_object_stream_uses_all_retries_and_errors_out(self): + below_threshold = 20 + osutil = InMemoryOSLayer({}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + # Here we're raising an exception every single time, which + # will exhaust our retry count and propagate a + # RetriesExceededError. + self.client.get_object.side_effect = socket.error("fake error") + with self.assertRaises(RetriesExceededError): + transfer.download_file('bucket', 'key', 'smallfile') + + self.assertEqual(self.client.get_object.call_count, 5) + # We should have also cleaned up the in progress file + # we were downloading to. 
+ self.assertEqual(osutil.filemap, {}) + + def test_download_below_multipart_threshold(self): + below_threshold = 20 + osutil = InMemoryOSLayer({'smallfile': b'hello world'}) + transfer = S3Transfer(self.client, osutil=osutil) + self.client.head_object.return_value = { + 'ContentLength': below_threshold + } + self.client.get_object.return_value = {'Body': BytesIO(b'foobar')} + transfer.download_file('bucket', 'key', 'smallfile') + + self.client.get_object.assert_called_with(Bucket='bucket', Key='key') + + def test_can_create_with_just_client(self): + transfer = S3Transfer(client=mock.Mock()) + self.assertIsInstance(transfer, S3Transfer) + + +class TestShutdownQueue(unittest.TestCase): + def test_handles_normal_put_get_requests(self): + q = ShutdownQueue() + q.put('foo') + self.assertEqual(q.get(), 'foo') + + def test_put_raises_error_on_shutdown(self): + q = ShutdownQueue() + q.trigger_shutdown() + with self.assertRaises(QueueShutdownError): + q.put('foo') + + +class TestRandomFileExtension(unittest.TestCase): + def test_has_proper_length(self): + self.assertEqual(len(random_file_extension(num_digits=4)), 4) + + +class TestCallbackHandlers(unittest.TestCase): + def setUp(self): + self.request = mock.Mock() + + def test_disable_request_on_put_object(self): + disable_upload_callbacks(self.request, 'PutObject') + self.request.body.disable_callback.assert_called_with() + + def test_disable_request_on_upload_part(self): + disable_upload_callbacks(self.request, 'UploadPart') + self.request.body.disable_callback.assert_called_with() + + def test_enable_object_on_put_object(self): + enable_upload_callbacks(self.request, 'PutObject') + self.request.body.enable_callback.assert_called_with() + + def test_enable_object_on_upload_part(self): + enable_upload_callbacks(self.request, 'UploadPart') + self.request.body.enable_callback.assert_called_with() + + def test_dont_disable_if_missing_interface(self): + del self.request.body.disable_callback + disable_upload_callbacks(self.request, 'PutObject') + self.assertEqual(self.request.body.method_calls, []) + + def test_dont_enable_if_missing_interface(self): + del self.request.body.enable_callback + enable_upload_callbacks(self.request, 'PutObject') + self.assertEqual(self.request.body.method_calls, []) + + def test_dont_disable_if_wrong_operation(self): + disable_upload_callbacks(self.request, 'OtherOperation') + self.assertFalse(self.request.body.disable_callback.called) + + def test_dont_enable_if_wrong_operation(self): + enable_upload_callbacks(self.request, 'OtherOperation') + self.assertFalse(self.request.body.enable_callback.called) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py index fd503e212e..a26d3a548c 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py @@ -1,91 +1,91 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the 'license' file accompanying this file. This file is -# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-from s3transfer.exceptions import InvalidSubscriberMethodError -from s3transfer.subscribers import BaseSubscriber -from __tests__ import unittest - - -class ExtraMethodsSubscriber(BaseSubscriber): - def extra_method(self): - return 'called extra method' - - -class NotCallableSubscriber(BaseSubscriber): - on_done = 'foo' - - -class NoKwargsSubscriber(BaseSubscriber): - def on_done(self): - pass - - -class OverrideMethodSubscriber(BaseSubscriber): - def on_queued(self, **kwargs): - return kwargs - - -class OverrideConstructorSubscriber(BaseSubscriber): - def __init__(self, arg1, arg2): - self.arg1 = arg1 - self.arg2 = arg2 - - -class TestSubscribers(unittest.TestCase): - def test_can_instantiate_base_subscriber(self): - try: - BaseSubscriber() - except InvalidSubscriberMethodError: - self.fail('BaseSubscriber should be instantiable') - - def test_can_call_base_subscriber_method(self): - subscriber = BaseSubscriber() - try: - subscriber.on_done(future=None) - except Exception as e: - self.fail( - 'Should be able to call base class subscriber method. ' - 'instead got: %s' % e - ) - - def test_subclass_can_have_and_call_additional_methods(self): - subscriber = ExtraMethodsSubscriber() - self.assertEqual(subscriber.extra_method(), 'called extra method') - - def test_can_subclass_and_override_method_from_base_subscriber(self): - subscriber = OverrideMethodSubscriber() - # Make sure that the overridden method is called - self.assertEqual(subscriber.on_queued(foo='bar'), {'foo': 'bar'}) - - def test_can_subclass_and_override_constructor_from_base_class(self): - subscriber = OverrideConstructorSubscriber('foo', arg2='bar') - # Make sure you can create a custom constructor. - self.assertEqual(subscriber.arg1, 'foo') - self.assertEqual(subscriber.arg2, 'bar') - - def test_invalid_arguments_in_constructor_of_subclass_subscriber(self): - # The override constructor should still have validation of - # constructor args. - with self.assertRaises(TypeError): - OverrideConstructorSubscriber() - - def test_not_callable_in_subclass_subscriber_method(self): - with self.assertRaisesRegex( - InvalidSubscriberMethodError, 'must be callable' - ): - NotCallableSubscriber() - - def test_no_kwargs_in_subclass_subscriber_method(self): - with self.assertRaisesRegex( - InvalidSubscriberMethodError, 'must accept keyword' - ): - NoKwargsSubscriber() +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from s3transfer.exceptions import InvalidSubscriberMethodError +from s3transfer.subscribers import BaseSubscriber +from __tests__ import unittest + + +class ExtraMethodsSubscriber(BaseSubscriber): + def extra_method(self): + return 'called extra method' + + +class NotCallableSubscriber(BaseSubscriber): + on_done = 'foo' + + +class NoKwargsSubscriber(BaseSubscriber): + def on_done(self): + pass + + +class OverrideMethodSubscriber(BaseSubscriber): + def on_queued(self, **kwargs): + return kwargs + + +class OverrideConstructorSubscriber(BaseSubscriber): + def __init__(self, arg1, arg2): + self.arg1 = arg1 + self.arg2 = arg2 + + +class TestSubscribers(unittest.TestCase): + def test_can_instantiate_base_subscriber(self): + try: + BaseSubscriber() + except InvalidSubscriberMethodError: + self.fail('BaseSubscriber should be instantiable') + + def test_can_call_base_subscriber_method(self): + subscriber = BaseSubscriber() + try: + subscriber.on_done(future=None) + except Exception as e: + self.fail( + 'Should be able to call base class subscriber method. ' + 'instead got: %s' % e + ) + + def test_subclass_can_have_and_call_additional_methods(self): + subscriber = ExtraMethodsSubscriber() + self.assertEqual(subscriber.extra_method(), 'called extra method') + + def test_can_subclass_and_override_method_from_base_subscriber(self): + subscriber = OverrideMethodSubscriber() + # Make sure that the overridden method is called + self.assertEqual(subscriber.on_queued(foo='bar'), {'foo': 'bar'}) + + def test_can_subclass_and_override_constructor_from_base_class(self): + subscriber = OverrideConstructorSubscriber('foo', arg2='bar') + # Make sure you can create a custom constructor. + self.assertEqual(subscriber.arg1, 'foo') + self.assertEqual(subscriber.arg2, 'bar') + + def test_invalid_arguments_in_constructor_of_subclass_subscriber(self): + # The override constructor should still have validation of + # constructor args. + with self.assertRaises(TypeError): + OverrideConstructorSubscriber() + + def test_not_callable_in_subclass_subscriber_method(self): + with self.assertRaisesRegex( + InvalidSubscriberMethodError, 'must be callable' + ): + NotCallableSubscriber() + + def test_no_kwargs_in_subclass_subscriber_method(self): + with self.assertRaisesRegex( + InvalidSubscriberMethodError, 'must accept keyword' + ): + NoKwargsSubscriber() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_tasks.py b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py index 8806598fd7..4f0bc4d1cc 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_tasks.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py @@ -1,833 +1,833 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-from concurrent import futures -from functools import partial -from threading import Event - -from s3transfer.futures import BoundedExecutor, TransferCoordinator -from s3transfer.subscribers import BaseSubscriber -from s3transfer.tasks import ( - CompleteMultipartUploadTask, - CreateMultipartUploadTask, - SubmissionTask, - Task, -) -from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks -from __tests__ import ( - BaseSubmissionTaskTest, - BaseTaskTest, - RecordingSubscriber, - unittest, -) - - -class TaskFailureException(Exception): - pass - - -class SuccessTask(Task): - def _main( - self, return_value='success', callbacks=None, failure_cleanups=None - ): - if callbacks: - for callback in callbacks: - callback() - if failure_cleanups: - for failure_cleanup in failure_cleanups: - self._transfer_coordinator.add_failure_cleanup(failure_cleanup) - return return_value - - -class FailureTask(Task): - def _main(self, exception=TaskFailureException): - raise exception() - - -class ReturnKwargsTask(Task): - def _main(self, **kwargs): - return kwargs - - -class SubmitMoreTasksTask(Task): - def _main(self, executor, tasks_to_submit): - for task_to_submit in tasks_to_submit: - self._transfer_coordinator.submit(executor, task_to_submit) - - -class NOOPSubmissionTask(SubmissionTask): - def _submit(self, transfer_future, **kwargs): - pass - - -class ExceptionSubmissionTask(SubmissionTask): - def _submit( - self, - transfer_future, - executor=None, - tasks_to_submit=None, - additional_callbacks=None, - exception=TaskFailureException, - ): - if executor and tasks_to_submit: - for task_to_submit in tasks_to_submit: - self._transfer_coordinator.submit(executor, task_to_submit) - if additional_callbacks: - for callback in additional_callbacks: - callback() - raise exception() - - -class StatusRecordingTransferCoordinator(TransferCoordinator): - def __init__(self, transfer_id=None): - super().__init__(transfer_id) - self.status_changes = [self._status] - - def set_status_to_queued(self): - super().set_status_to_queued() - self._record_status_change() - - def set_status_to_running(self): - super().set_status_to_running() - self._record_status_change() - - def _record_status_change(self): - self.status_changes.append(self._status) - - -class RecordingStateSubscriber(BaseSubscriber): - def __init__(self, transfer_coordinator): - self._transfer_coordinator = transfer_coordinator - self.status_during_on_queued = None - - def on_queued(self, **kwargs): - self.status_during_on_queued = self._transfer_coordinator.status - - -class TestSubmissionTask(BaseSubmissionTaskTest): - def setUp(self): - super().setUp() - self.executor = BoundedExecutor(1000, 5) - self.call_args = CallArgs(subscribers=[]) - self.transfer_future = self.get_transfer_future(self.call_args) - self.main_kwargs = {'transfer_future': self.transfer_future} - - def test_transitions_from_not_started_to_queued_to_running(self): - self.transfer_coordinator = StatusRecordingTransferCoordinator() - submission_task = self.get_task( - NOOPSubmissionTask, main_kwargs=self.main_kwargs - ) - # Status should be queued until submission task has been ran. - self.assertEqual(self.transfer_coordinator.status, 'not-started') - - submission_task() - # Once submission task has been ran, the status should now be running. - self.assertEqual(self.transfer_coordinator.status, 'running') - - # Ensure the transitions were as expected as well. 
- self.assertEqual( - self.transfer_coordinator.status_changes, - ['not-started', 'queued', 'running'], - ) - - def test_on_queued_callbacks(self): - submission_task = self.get_task( - NOOPSubmissionTask, main_kwargs=self.main_kwargs - ) - - subscriber = RecordingSubscriber() - self.call_args.subscribers.append(subscriber) - submission_task() - # Make sure the on_queued callback of the subscriber is called. - self.assertEqual( - subscriber.on_queued_calls, [{'future': self.transfer_future}] - ) - - def test_on_queued_status_in_callbacks(self): - submission_task = self.get_task( - NOOPSubmissionTask, main_kwargs=self.main_kwargs - ) - - subscriber = RecordingStateSubscriber(self.transfer_coordinator) - self.call_args.subscribers.append(subscriber) - submission_task() - # Make sure the status was queued during on_queued callback. - self.assertEqual(subscriber.status_during_on_queued, 'queued') - - def test_sets_exception_from_submit(self): - submission_task = self.get_task( - ExceptionSubmissionTask, main_kwargs=self.main_kwargs - ) - submission_task() - - # Make sure the status of the future is failed - self.assertEqual(self.transfer_coordinator.status, 'failed') - - # Make sure the future propagates the exception encountered in the - # submission task. - with self.assertRaises(TaskFailureException): - self.transfer_future.result() - - def test_catches_and_sets_keyboard_interrupt_exception_from_submit(self): - self.main_kwargs['exception'] = KeyboardInterrupt - submission_task = self.get_task( - ExceptionSubmissionTask, main_kwargs=self.main_kwargs - ) - submission_task() - - self.assertEqual(self.transfer_coordinator.status, 'failed') - with self.assertRaises(KeyboardInterrupt): - self.transfer_future.result() - - def test_calls_done_callbacks_on_exception(self): - submission_task = self.get_task( - ExceptionSubmissionTask, main_kwargs=self.main_kwargs - ) - - subscriber = RecordingSubscriber() - self.call_args.subscribers.append(subscriber) - - # Add the done callback to the callbacks to be invoked when the - # transfer is done. - done_callbacks = get_callbacks(self.transfer_future, 'done') - for done_callback in done_callbacks: - self.transfer_coordinator.add_done_callback(done_callback) - submission_task() - - # Make sure the task failed to start - self.assertEqual(self.transfer_coordinator.status, 'failed') - - # Make sure the on_done callback of the subscriber is called. - self.assertEqual( - subscriber.on_done_calls, [{'future': self.transfer_future}] - ) - - def test_calls_failure_cleanups_on_exception(self): - submission_task = self.get_task( - ExceptionSubmissionTask, main_kwargs=self.main_kwargs - ) - - # Add the callback to the callbacks to be invoked when the - # transfer fails. - invocations_of_cleanup = [] - cleanup_callback = FunctionContainer( - invocations_of_cleanup.append, 'cleanup happened' - ) - self.transfer_coordinator.add_failure_cleanup(cleanup_callback) - submission_task() - - # Make sure the task failed to start - self.assertEqual(self.transfer_coordinator.status, 'failed') - - # Make sure the cleanup was called. - self.assertEqual(invocations_of_cleanup, ['cleanup happened']) - - def test_cleanups_only_ran_once_on_exception(self): - # We want to be able to handle the case where the final task completes - # and anounces done but there is an error in the submission task - # which will cause it to need to announce done as well. In this case, - # we do not want the done callbacks to be invoke more than once. 
- - final_task = self.get_task(FailureTask, is_final=True) - self.main_kwargs['executor'] = self.executor - self.main_kwargs['tasks_to_submit'] = [final_task] - - submission_task = self.get_task( - ExceptionSubmissionTask, main_kwargs=self.main_kwargs - ) - - subscriber = RecordingSubscriber() - self.call_args.subscribers.append(subscriber) - - # Add the done callback to the callbacks to be invoked when the - # transfer is done. - done_callbacks = get_callbacks(self.transfer_future, 'done') - for done_callback in done_callbacks: - self.transfer_coordinator.add_done_callback(done_callback) - - submission_task() - - # Make sure the task failed to start - self.assertEqual(self.transfer_coordinator.status, 'failed') - - # Make sure the on_done callback of the subscriber is called only once. - self.assertEqual( - subscriber.on_done_calls, [{'future': self.transfer_future}] - ) - - def test_done_callbacks_only_ran_once_on_exception(self): - # We want to be able to handle the case where the final task completes - # and anounces done but there is an error in the submission task - # which will cause it to need to announce done as well. In this case, - # we do not want the failure cleanups to be invoked more than once. - - final_task = self.get_task(FailureTask, is_final=True) - self.main_kwargs['executor'] = self.executor - self.main_kwargs['tasks_to_submit'] = [final_task] - - submission_task = self.get_task( - ExceptionSubmissionTask, main_kwargs=self.main_kwargs - ) - - # Add the callback to the callbacks to be invoked when the - # transfer fails. - invocations_of_cleanup = [] - cleanup_callback = FunctionContainer( - invocations_of_cleanup.append, 'cleanup happened' - ) - self.transfer_coordinator.add_failure_cleanup(cleanup_callback) - submission_task() - - # Make sure the task failed to start - self.assertEqual(self.transfer_coordinator.status, 'failed') - - # Make sure the cleanup was called only once. - self.assertEqual(invocations_of_cleanup, ['cleanup happened']) - - def test_handles_cleanups_submitted_in_other_tasks(self): - invocations_of_cleanup = [] - event = Event() - cleanup_callback = FunctionContainer( - invocations_of_cleanup.append, 'cleanup happened' - ) - # We want the cleanup to be added in the execution of the task and - # still be executed by the submission task when it fails. - task = self.get_task( - SuccessTask, - main_kwargs={ - 'callbacks': [event.set], - 'failure_cleanups': [cleanup_callback], - }, - ) - - self.main_kwargs['executor'] = self.executor - self.main_kwargs['tasks_to_submit'] = [task] - self.main_kwargs['additional_callbacks'] = [event.wait] - - submission_task = self.get_task( - ExceptionSubmissionTask, main_kwargs=self.main_kwargs - ) - - submission_task() - self.assertEqual(self.transfer_coordinator.status, 'failed') - - # Make sure the cleanup was called even though the callback got - # added in a completely different task. - self.assertEqual(invocations_of_cleanup, ['cleanup happened']) - - def test_waits_for_tasks_submitted_by_other_tasks_on_exception(self): - # In this test, we want to make sure that any tasks that may be - # submitted in another task complete before we start performing - # cleanups. - # - # This is tested by doing the following: - # - # ExecutionSubmissionTask - # | - # +--submits-->SubmitMoreTasksTask - # | - # +--submits-->SuccessTask - # | - # +-->sleeps-->adds failure cleanup - # - # In the end, the failure cleanup of the SuccessTask should be ran - # when the ExecutionSubmissionTask fails. 
If the - # ExeceptionSubmissionTask did not run the failure cleanup it is most - # likely that it did not wait for the SuccessTask to complete, which - # it needs to because the ExeceptionSubmissionTask does not know - # what failure cleanups it needs to run until all spawned tasks have - # completed. - invocations_of_cleanup = [] - event = Event() - cleanup_callback = FunctionContainer( - invocations_of_cleanup.append, 'cleanup happened' - ) - - cleanup_task = self.get_task( - SuccessTask, - main_kwargs={ - 'callbacks': [event.set], - 'failure_cleanups': [cleanup_callback], - }, - ) - task_for_submitting_cleanup_task = self.get_task( - SubmitMoreTasksTask, - main_kwargs={ - 'executor': self.executor, - 'tasks_to_submit': [cleanup_task], - }, - ) - - self.main_kwargs['executor'] = self.executor - self.main_kwargs['tasks_to_submit'] = [ - task_for_submitting_cleanup_task - ] - self.main_kwargs['additional_callbacks'] = [event.wait] - - submission_task = self.get_task( - ExceptionSubmissionTask, main_kwargs=self.main_kwargs - ) - - submission_task() - self.assertEqual(self.transfer_coordinator.status, 'failed') - self.assertEqual(invocations_of_cleanup, ['cleanup happened']) - - def test_submission_task_announces_done_if_cancelled_before_main(self): - invocations_of_done = [] - done_callback = FunctionContainer( - invocations_of_done.append, 'done announced' - ) - self.transfer_coordinator.add_done_callback(done_callback) - - self.transfer_coordinator.cancel() - submission_task = self.get_task( - NOOPSubmissionTask, main_kwargs=self.main_kwargs - ) - submission_task() - - # Because the submission task was cancelled before being run - # it did not submit any extra tasks so a result it is responsible - # for making sure it announces done as nothing else will. - self.assertEqual(invocations_of_done, ['done announced']) - - -class TestTask(unittest.TestCase): - def setUp(self): - self.transfer_id = 1 - self.transfer_coordinator = TransferCoordinator( - transfer_id=self.transfer_id - ) - - def test_repr(self): - main_kwargs = {'bucket': 'mybucket', 'param_to_not_include': 'foo'} - task = ReturnKwargsTask( - self.transfer_coordinator, main_kwargs=main_kwargs - ) - # The repr should not include the other parameter because it is not - # a desired parameter to include. - self.assertEqual( - repr(task), - 'ReturnKwargsTask(transfer_id={}, {})'.format( - self.transfer_id, {'bucket': 'mybucket'} - ), - ) - - def test_transfer_id(self): - task = SuccessTask(self.transfer_coordinator) - # Make sure that the id is the one provided to the id associated - # to the transfer coordinator. - self.assertEqual(task.transfer_id, self.transfer_id) - - def test_context_status_transitioning_success(self): - # The status should be set to running. - self.transfer_coordinator.set_status_to_running() - self.assertEqual(self.transfer_coordinator.status, 'running') - - # If a task is called, the status still should be running. - SuccessTask(self.transfer_coordinator)() - self.assertEqual(self.transfer_coordinator.status, 'running') - - # Once the final task is called, the status should be set to success. 
- SuccessTask(self.transfer_coordinator, is_final=True)() - self.assertEqual(self.transfer_coordinator.status, 'success') - - def test_context_status_transitioning_failed(self): - self.transfer_coordinator.set_status_to_running() - - SuccessTask(self.transfer_coordinator)() - self.assertEqual(self.transfer_coordinator.status, 'running') - - # A failure task should result in the failed status - FailureTask(self.transfer_coordinator)() - self.assertEqual(self.transfer_coordinator.status, 'failed') - - # Even if the final task comes in and succeeds, it should stay failed. - SuccessTask(self.transfer_coordinator, is_final=True)() - self.assertEqual(self.transfer_coordinator.status, 'failed') - - def test_result_setting_for_success(self): - override_return = 'foo' - SuccessTask(self.transfer_coordinator)() - SuccessTask( - self.transfer_coordinator, - main_kwargs={'return_value': override_return}, - is_final=True, - )() - - # The return value for the transfer future should be of the final - # task. - self.assertEqual(self.transfer_coordinator.result(), override_return) - - def test_result_setting_for_error(self): - FailureTask(self.transfer_coordinator)() - - # If another failure comes in, the result should still throw the - # original exception when result() is eventually called. - FailureTask( - self.transfer_coordinator, main_kwargs={'exception': Exception} - )() - - # Even if a success task comes along, the result of the future - # should be the original exception - SuccessTask(self.transfer_coordinator, is_final=True)() - with self.assertRaises(TaskFailureException): - self.transfer_coordinator.result() - - def test_done_callbacks_success(self): - callback_results = [] - SuccessTask( - self.transfer_coordinator, - done_callbacks=[ - partial(callback_results.append, 'first'), - partial(callback_results.append, 'second'), - ], - )() - # For successful tasks, the done callbacks should get called. - self.assertEqual(callback_results, ['first', 'second']) - - def test_done_callbacks_failure(self): - callback_results = [] - FailureTask( - self.transfer_coordinator, - done_callbacks=[ - partial(callback_results.append, 'first'), - partial(callback_results.append, 'second'), - ], - )() - # For even failed tasks, the done callbacks should get called. - self.assertEqual(callback_results, ['first', 'second']) - - # Callbacks should continue to be called even after a related failure - SuccessTask( - self.transfer_coordinator, - done_callbacks=[ - partial(callback_results.append, 'third'), - partial(callback_results.append, 'fourth'), - ], - )() - self.assertEqual( - callback_results, ['first', 'second', 'third', 'fourth'] - ) - - def test_failure_cleanups_on_failure(self): - callback_results = [] - self.transfer_coordinator.add_failure_cleanup( - callback_results.append, 'first' - ) - self.transfer_coordinator.add_failure_cleanup( - callback_results.append, 'second' - ) - FailureTask(self.transfer_coordinator)() - # The failure callbacks should have not been called yet because it - # is not the last task - self.assertEqual(callback_results, []) - - # Now the failure callbacks should get called. 
- SuccessTask(self.transfer_coordinator, is_final=True)() - self.assertEqual(callback_results, ['first', 'second']) - - def test_no_failure_cleanups_on_success(self): - callback_results = [] - self.transfer_coordinator.add_failure_cleanup( - callback_results.append, 'first' - ) - self.transfer_coordinator.add_failure_cleanup( - callback_results.append, 'second' - ) - SuccessTask(self.transfer_coordinator, is_final=True)() - # The failure cleanups should not have been called because no task - # failed for the transfer context. - self.assertEqual(callback_results, []) - - def test_passing_main_kwargs(self): - main_kwargs = {'foo': 'bar', 'baz': 'biz'} - ReturnKwargsTask( - self.transfer_coordinator, main_kwargs=main_kwargs, is_final=True - )() - # The kwargs should have been passed to the main() - self.assertEqual(self.transfer_coordinator.result(), main_kwargs) - - def test_passing_pending_kwargs_single_futures(self): - pending_kwargs = {} - ref_main_kwargs = {'foo': 'bar', 'baz': 'biz'} - - # Pass some tasks to an executor - with futures.ThreadPoolExecutor(1) as executor: - pending_kwargs['foo'] = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={'return_value': ref_main_kwargs['foo']}, - ) - ) - pending_kwargs['baz'] = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={'return_value': ref_main_kwargs['baz']}, - ) - ) - - # Create a task that depends on the tasks passed to the executor - ReturnKwargsTask( - self.transfer_coordinator, - pending_main_kwargs=pending_kwargs, - is_final=True, - )() - # The result should have the pending keyword arg values flushed - # out. - self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) - - def test_passing_pending_kwargs_list_of_futures(self): - pending_kwargs = {} - ref_main_kwargs = {'foo': ['first', 'second']} - - # Pass some tasks to an executor - with futures.ThreadPoolExecutor(1) as executor: - first_future = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={'return_value': ref_main_kwargs['foo'][0]}, - ) - ) - second_future = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={'return_value': ref_main_kwargs['foo'][1]}, - ) - ) - # Make the pending keyword arg value a list - pending_kwargs['foo'] = [first_future, second_future] - - # Create a task that depends on the tasks passed to the executor - ReturnKwargsTask( - self.transfer_coordinator, - pending_main_kwargs=pending_kwargs, - is_final=True, - )() - # The result should have the pending keyword arg values flushed - # out in the expected order. 
- self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) - - def test_passing_pending_and_non_pending_kwargs(self): - main_kwargs = {'nonpending_value': 'foo'} - pending_kwargs = {} - ref_main_kwargs = { - 'nonpending_value': 'foo', - 'pending_value': 'bar', - 'pending_list': ['first', 'second'], - } - - # Create the pending tasks - with futures.ThreadPoolExecutor(1) as executor: - pending_kwargs['pending_value'] = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={ - 'return_value': ref_main_kwargs['pending_value'] - }, - ) - ) - - first_future = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={ - 'return_value': ref_main_kwargs['pending_list'][0] - }, - ) - ) - second_future = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={ - 'return_value': ref_main_kwargs['pending_list'][1] - }, - ) - ) - # Make the pending keyword arg value a list - pending_kwargs['pending_list'] = [first_future, second_future] - - # Create a task that depends on the tasks passed to the executor - # and just regular nonpending kwargs. - ReturnKwargsTask( - self.transfer_coordinator, - main_kwargs=main_kwargs, - pending_main_kwargs=pending_kwargs, - is_final=True, - )() - # The result should have all of the kwargs (both pending and - # nonpending) - self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) - - def test_single_failed_pending_future(self): - pending_kwargs = {} - - # Pass some tasks to an executor. Make one successful and the other - # a failure. - with futures.ThreadPoolExecutor(1) as executor: - pending_kwargs['foo'] = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={'return_value': 'bar'}, - ) - ) - pending_kwargs['baz'] = executor.submit( - FailureTask(self.transfer_coordinator) - ) - - # Create a task that depends on the tasks passed to the executor - ReturnKwargsTask( - self.transfer_coordinator, - pending_main_kwargs=pending_kwargs, - is_final=True, - )() - # The end result should raise the exception from the initial - # pending future value - with self.assertRaises(TaskFailureException): - self.transfer_coordinator.result() - - def test_single_failed_pending_future_in_list(self): - pending_kwargs = {} - - # Pass some tasks to an executor. Make one successful and the other - # a failure. 
- with futures.ThreadPoolExecutor(1) as executor: - first_future = executor.submit( - SuccessTask( - self.transfer_coordinator, - main_kwargs={'return_value': 'bar'}, - ) - ) - second_future = executor.submit( - FailureTask(self.transfer_coordinator) - ) - - pending_kwargs['pending_list'] = [first_future, second_future] - - # Create a task that depends on the tasks passed to the executor - ReturnKwargsTask( - self.transfer_coordinator, - pending_main_kwargs=pending_kwargs, - is_final=True, - )() - # The end result should raise the exception from the initial - # pending future value in the list - with self.assertRaises(TaskFailureException): - self.transfer_coordinator.result() - - -class BaseMultipartTaskTest(BaseTaskTest): - def setUp(self): - super().setUp() - self.bucket = 'mybucket' - self.key = 'foo' - - -class TestCreateMultipartUploadTask(BaseMultipartTaskTest): - def test_main(self): - upload_id = 'foo' - extra_args = {'Metadata': {'foo': 'bar'}} - response = {'UploadId': upload_id} - task = self.get_task( - CreateMultipartUploadTask, - main_kwargs={ - 'client': self.client, - 'bucket': self.bucket, - 'key': self.key, - 'extra_args': extra_args, - }, - ) - self.stubber.add_response( - method='create_multipart_upload', - service_response=response, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'Metadata': {'foo': 'bar'}, - }, - ) - result_id = task() - self.stubber.assert_no_pending_responses() - # Ensure the upload id returned is correct - self.assertEqual(upload_id, result_id) - - # Make sure that the abort was added as a cleanup failure - self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 1) - - # Make sure if it is called, it will abort correctly - self.stubber.add_response( - method='abort_multipart_upload', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': upload_id, - }, - ) - self.transfer_coordinator.failure_cleanups[0]() - self.stubber.assert_no_pending_responses() - - -class TestCompleteMultipartUploadTask(BaseMultipartTaskTest): - def test_main(self): - upload_id = 'my-id' - parts = [{'ETag': 'etag', 'PartNumber': 0}] - task = self.get_task( - CompleteMultipartUploadTask, - main_kwargs={ - 'client': self.client, - 'bucket': self.bucket, - 'key': self.key, - 'upload_id': upload_id, - 'parts': parts, - 'extra_args': {}, - }, - ) - self.stubber.add_response( - method='complete_multipart_upload', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': upload_id, - 'MultipartUpload': {'Parts': parts}, - }, - ) - task() - self.stubber.assert_no_pending_responses() - - def test_includes_extra_args(self): - upload_id = 'my-id' - parts = [{'ETag': 'etag', 'PartNumber': 0}] - task = self.get_task( - CompleteMultipartUploadTask, - main_kwargs={ - 'client': self.client, - 'bucket': self.bucket, - 'key': self.key, - 'upload_id': upload_id, - 'parts': parts, - 'extra_args': {'RequestPayer': 'requester'}, - }, - ) - self.stubber.add_response( - method='complete_multipart_upload', - service_response={}, - expected_params={ - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': upload_id, - 'MultipartUpload': {'Parts': parts}, - 'RequestPayer': 'requester', - }, - ) - task() - self.stubber.assert_no_pending_responses() +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from concurrent import futures +from functools import partial +from threading import Event + +from s3transfer.futures import BoundedExecutor, TransferCoordinator +from s3transfer.subscribers import BaseSubscriber +from s3transfer.tasks import ( + CompleteMultipartUploadTask, + CreateMultipartUploadTask, + SubmissionTask, + Task, +) +from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks +from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + RecordingSubscriber, + unittest, +) + + +class TaskFailureException(Exception): + pass + + +class SuccessTask(Task): + def _main( + self, return_value='success', callbacks=None, failure_cleanups=None + ): + if callbacks: + for callback in callbacks: + callback() + if failure_cleanups: + for failure_cleanup in failure_cleanups: + self._transfer_coordinator.add_failure_cleanup(failure_cleanup) + return return_value + + +class FailureTask(Task): + def _main(self, exception=TaskFailureException): + raise exception() + + +class ReturnKwargsTask(Task): + def _main(self, **kwargs): + return kwargs + + +class SubmitMoreTasksTask(Task): + def _main(self, executor, tasks_to_submit): + for task_to_submit in tasks_to_submit: + self._transfer_coordinator.submit(executor, task_to_submit) + + +class NOOPSubmissionTask(SubmissionTask): + def _submit(self, transfer_future, **kwargs): + pass + + +class ExceptionSubmissionTask(SubmissionTask): + def _submit( + self, + transfer_future, + executor=None, + tasks_to_submit=None, + additional_callbacks=None, + exception=TaskFailureException, + ): + if executor and tasks_to_submit: + for task_to_submit in tasks_to_submit: + self._transfer_coordinator.submit(executor, task_to_submit) + if additional_callbacks: + for callback in additional_callbacks: + callback() + raise exception() + + +class StatusRecordingTransferCoordinator(TransferCoordinator): + def __init__(self, transfer_id=None): + super().__init__(transfer_id) + self.status_changes = [self._status] + + def set_status_to_queued(self): + super().set_status_to_queued() + self._record_status_change() + + def set_status_to_running(self): + super().set_status_to_running() + self._record_status_change() + + def _record_status_change(self): + self.status_changes.append(self._status) + + +class RecordingStateSubscriber(BaseSubscriber): + def __init__(self, transfer_coordinator): + self._transfer_coordinator = transfer_coordinator + self.status_during_on_queued = None + + def on_queued(self, **kwargs): + self.status_during_on_queued = self._transfer_coordinator.status + + +class TestSubmissionTask(BaseSubmissionTaskTest): + def setUp(self): + super().setUp() + self.executor = BoundedExecutor(1000, 5) + self.call_args = CallArgs(subscribers=[]) + self.transfer_future = self.get_transfer_future(self.call_args) + self.main_kwargs = {'transfer_future': self.transfer_future} + + def test_transitions_from_not_started_to_queued_to_running(self): + self.transfer_coordinator = StatusRecordingTransferCoordinator() + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + # Status should be queued until submission task has been ran. 
+ self.assertEqual(self.transfer_coordinator.status, 'not-started') + + submission_task() + # Once submission task has been ran, the status should now be running. + self.assertEqual(self.transfer_coordinator.status, 'running') + + # Ensure the transitions were as expected as well. + self.assertEqual( + self.transfer_coordinator.status_changes, + ['not-started', 'queued', 'running'], + ) + + def test_on_queued_callbacks(self): + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + submission_task() + # Make sure the on_queued callback of the subscriber is called. + self.assertEqual( + subscriber.on_queued_calls, [{'future': self.transfer_future}] + ) + + def test_on_queued_status_in_callbacks(self): + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingStateSubscriber(self.transfer_coordinator) + self.call_args.subscribers.append(subscriber) + submission_task() + # Make sure the status was queued during on_queued callback. + self.assertEqual(subscriber.status_during_on_queued, 'queued') + + def test_sets_exception_from_submit(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + # Make sure the status of the future is failed + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the future propagates the exception encountered in the + # submission task. + with self.assertRaises(TaskFailureException): + self.transfer_future.result() + + def test_catches_and_sets_keyboard_interrupt_exception_from_submit(self): + self.main_kwargs['exception'] = KeyboardInterrupt + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + self.assertEqual(self.transfer_coordinator.status, 'failed') + with self.assertRaises(KeyboardInterrupt): + self.transfer_future.result() + + def test_calls_done_callbacks_on_exception(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + + # Add the done callback to the callbacks to be invoked when the + # transfer is done. + done_callbacks = get_callbacks(self.transfer_future, 'done') + for done_callback in done_callbacks: + self.transfer_coordinator.add_done_callback(done_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the on_done callback of the subscriber is called. + self.assertEqual( + subscriber.on_done_calls, [{'future': self.transfer_future}] + ) + + def test_calls_failure_cleanups_on_exception(self): + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + # Add the callback to the callbacks to be invoked when the + # transfer fails. + invocations_of_cleanup = [] + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + self.transfer_coordinator.add_failure_cleanup(cleanup_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called. 
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_cleanups_only_ran_once_on_exception(self): + # We want to be able to handle the case where the final task completes + # and anounces done but there is an error in the submission task + # which will cause it to need to announce done as well. In this case, + # we do not want the done callbacks to be invoke more than once. + + final_task = self.get_task(FailureTask, is_final=True) + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [final_task] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + subscriber = RecordingSubscriber() + self.call_args.subscribers.append(subscriber) + + # Add the done callback to the callbacks to be invoked when the + # transfer is done. + done_callbacks = get_callbacks(self.transfer_future, 'done') + for done_callback in done_callbacks: + self.transfer_coordinator.add_done_callback(done_callback) + + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the on_done callback of the subscriber is called only once. + self.assertEqual( + subscriber.on_done_calls, [{'future': self.transfer_future}] + ) + + def test_done_callbacks_only_ran_once_on_exception(self): + # We want to be able to handle the case where the final task completes + # and anounces done but there is an error in the submission task + # which will cause it to need to announce done as well. In this case, + # we do not want the failure cleanups to be invoked more than once. + + final_task = self.get_task(FailureTask, is_final=True) + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [final_task] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + # Add the callback to the callbacks to be invoked when the + # transfer fails. + invocations_of_cleanup = [] + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + self.transfer_coordinator.add_failure_cleanup(cleanup_callback) + submission_task() + + # Make sure the task failed to start + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called only once. + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_handles_cleanups_submitted_in_other_tasks(self): + invocations_of_cleanup = [] + event = Event() + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + # We want the cleanup to be added in the execution of the task and + # still be executed by the submission task when it fails. + task = self.get_task( + SuccessTask, + main_kwargs={ + 'callbacks': [event.set], + 'failure_cleanups': [cleanup_callback], + }, + ) + + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [task] + self.main_kwargs['additional_callbacks'] = [event.wait] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + submission_task() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Make sure the cleanup was called even though the callback got + # added in a completely different task. 
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_waits_for_tasks_submitted_by_other_tasks_on_exception(self): + # In this test, we want to make sure that any tasks that may be + # submitted in another task complete before we start performing + # cleanups. + # + # This is tested by doing the following: + # + # ExecutionSubmissionTask + # | + # +--submits-->SubmitMoreTasksTask + # | + # +--submits-->SuccessTask + # | + # +-->sleeps-->adds failure cleanup + # + # In the end, the failure cleanup of the SuccessTask should be ran + # when the ExecutionSubmissionTask fails. If the + # ExeceptionSubmissionTask did not run the failure cleanup it is most + # likely that it did not wait for the SuccessTask to complete, which + # it needs to because the ExeceptionSubmissionTask does not know + # what failure cleanups it needs to run until all spawned tasks have + # completed. + invocations_of_cleanup = [] + event = Event() + cleanup_callback = FunctionContainer( + invocations_of_cleanup.append, 'cleanup happened' + ) + + cleanup_task = self.get_task( + SuccessTask, + main_kwargs={ + 'callbacks': [event.set], + 'failure_cleanups': [cleanup_callback], + }, + ) + task_for_submitting_cleanup_task = self.get_task( + SubmitMoreTasksTask, + main_kwargs={ + 'executor': self.executor, + 'tasks_to_submit': [cleanup_task], + }, + ) + + self.main_kwargs['executor'] = self.executor + self.main_kwargs['tasks_to_submit'] = [ + task_for_submitting_cleanup_task + ] + self.main_kwargs['additional_callbacks'] = [event.wait] + + submission_task = self.get_task( + ExceptionSubmissionTask, main_kwargs=self.main_kwargs + ) + + submission_task() + self.assertEqual(self.transfer_coordinator.status, 'failed') + self.assertEqual(invocations_of_cleanup, ['cleanup happened']) + + def test_submission_task_announces_done_if_cancelled_before_main(self): + invocations_of_done = [] + done_callback = FunctionContainer( + invocations_of_done.append, 'done announced' + ) + self.transfer_coordinator.add_done_callback(done_callback) + + self.transfer_coordinator.cancel() + submission_task = self.get_task( + NOOPSubmissionTask, main_kwargs=self.main_kwargs + ) + submission_task() + + # Because the submission task was cancelled before being run + # it did not submit any extra tasks so a result it is responsible + # for making sure it announces done as nothing else will. + self.assertEqual(invocations_of_done, ['done announced']) + + +class TestTask(unittest.TestCase): + def setUp(self): + self.transfer_id = 1 + self.transfer_coordinator = TransferCoordinator( + transfer_id=self.transfer_id + ) + + def test_repr(self): + main_kwargs = {'bucket': 'mybucket', 'param_to_not_include': 'foo'} + task = ReturnKwargsTask( + self.transfer_coordinator, main_kwargs=main_kwargs + ) + # The repr should not include the other parameter because it is not + # a desired parameter to include. + self.assertEqual( + repr(task), + 'ReturnKwargsTask(transfer_id={}, {})'.format( + self.transfer_id, {'bucket': 'mybucket'} + ), + ) + + def test_transfer_id(self): + task = SuccessTask(self.transfer_coordinator) + # Make sure that the id is the one provided to the id associated + # to the transfer coordinator. + self.assertEqual(task.transfer_id, self.transfer_id) + + def test_context_status_transitioning_success(self): + # The status should be set to running. + self.transfer_coordinator.set_status_to_running() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # If a task is called, the status still should be running. 
+ SuccessTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # Once the final task is called, the status should be set to success. + SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(self.transfer_coordinator.status, 'success') + + def test_context_status_transitioning_failed(self): + self.transfer_coordinator.set_status_to_running() + + SuccessTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'running') + + # A failure task should result in the failed status + FailureTask(self.transfer_coordinator)() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + # Even if the final task comes in and succeeds, it should stay failed. + SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(self.transfer_coordinator.status, 'failed') + + def test_result_setting_for_success(self): + override_return = 'foo' + SuccessTask(self.transfer_coordinator)() + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': override_return}, + is_final=True, + )() + + # The return value for the transfer future should be of the final + # task. + self.assertEqual(self.transfer_coordinator.result(), override_return) + + def test_result_setting_for_error(self): + FailureTask(self.transfer_coordinator)() + + # If another failure comes in, the result should still throw the + # original exception when result() is eventually called. + FailureTask( + self.transfer_coordinator, main_kwargs={'exception': Exception} + )() + + # Even if a success task comes along, the result of the future + # should be the original exception + SuccessTask(self.transfer_coordinator, is_final=True)() + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + def test_done_callbacks_success(self): + callback_results = [] + SuccessTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'first'), + partial(callback_results.append, 'second'), + ], + )() + # For successful tasks, the done callbacks should get called. + self.assertEqual(callback_results, ['first', 'second']) + + def test_done_callbacks_failure(self): + callback_results = [] + FailureTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'first'), + partial(callback_results.append, 'second'), + ], + )() + # For even failed tasks, the done callbacks should get called. + self.assertEqual(callback_results, ['first', 'second']) + + # Callbacks should continue to be called even after a related failure + SuccessTask( + self.transfer_coordinator, + done_callbacks=[ + partial(callback_results.append, 'third'), + partial(callback_results.append, 'fourth'), + ], + )() + self.assertEqual( + callback_results, ['first', 'second', 'third', 'fourth'] + ) + + def test_failure_cleanups_on_failure(self): + callback_results = [] + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'first' + ) + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'second' + ) + FailureTask(self.transfer_coordinator)() + # The failure callbacks should have not been called yet because it + # is not the last task + self.assertEqual(callback_results, []) + + # Now the failure callbacks should get called. 
+ SuccessTask(self.transfer_coordinator, is_final=True)() + self.assertEqual(callback_results, ['first', 'second']) + + def test_no_failure_cleanups_on_success(self): + callback_results = [] + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'first' + ) + self.transfer_coordinator.add_failure_cleanup( + callback_results.append, 'second' + ) + SuccessTask(self.transfer_coordinator, is_final=True)() + # The failure cleanups should not have been called because no task + # failed for the transfer context. + self.assertEqual(callback_results, []) + + def test_passing_main_kwargs(self): + main_kwargs = {'foo': 'bar', 'baz': 'biz'} + ReturnKwargsTask( + self.transfer_coordinator, main_kwargs=main_kwargs, is_final=True + )() + # The kwargs should have been passed to the main() + self.assertEqual(self.transfer_coordinator.result(), main_kwargs) + + def test_passing_pending_kwargs_single_futures(self): + pending_kwargs = {} + ref_main_kwargs = {'foo': 'bar', 'baz': 'biz'} + + # Pass some tasks to an executor + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['foo'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo']}, + ) + ) + pending_kwargs['baz'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['baz']}, + ) + ) + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have the pending keyword arg values flushed + # out. + self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_passing_pending_kwargs_list_of_futures(self): + pending_kwargs = {} + ref_main_kwargs = {'foo': ['first', 'second']} + + # Pass some tasks to an executor + with futures.ThreadPoolExecutor(1) as executor: + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo'][0]}, + ) + ) + second_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': ref_main_kwargs['foo'][1]}, + ) + ) + # Make the pending keyword arg value a list + pending_kwargs['foo'] = [first_future, second_future] + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have the pending keyword arg values flushed + # out in the expected order. 
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_passing_pending_and_non_pending_kwargs(self): + main_kwargs = {'nonpending_value': 'foo'} + pending_kwargs = {} + ref_main_kwargs = { + 'nonpending_value': 'foo', + 'pending_value': 'bar', + 'pending_list': ['first', 'second'], + } + + # Create the pending tasks + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['pending_value'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_value'] + }, + ) + ) + + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_list'][0] + }, + ) + ) + second_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={ + 'return_value': ref_main_kwargs['pending_list'][1] + }, + ) + ) + # Make the pending keyword arg value a list + pending_kwargs['pending_list'] = [first_future, second_future] + + # Create a task that depends on the tasks passed to the executor + # and just regular nonpending kwargs. + ReturnKwargsTask( + self.transfer_coordinator, + main_kwargs=main_kwargs, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The result should have all of the kwargs (both pending and + # nonpending) + self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs) + + def test_single_failed_pending_future(self): + pending_kwargs = {} + + # Pass some tasks to an executor. Make one successful and the other + # a failure. + with futures.ThreadPoolExecutor(1) as executor: + pending_kwargs['foo'] = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': 'bar'}, + ) + ) + pending_kwargs['baz'] = executor.submit( + FailureTask(self.transfer_coordinator) + ) + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The end result should raise the exception from the initial + # pending future value + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + def test_single_failed_pending_future_in_list(self): + pending_kwargs = {} + + # Pass some tasks to an executor. Make one successful and the other + # a failure. 
+ with futures.ThreadPoolExecutor(1) as executor: + first_future = executor.submit( + SuccessTask( + self.transfer_coordinator, + main_kwargs={'return_value': 'bar'}, + ) + ) + second_future = executor.submit( + FailureTask(self.transfer_coordinator) + ) + + pending_kwargs['pending_list'] = [first_future, second_future] + + # Create a task that depends on the tasks passed to the executor + ReturnKwargsTask( + self.transfer_coordinator, + pending_main_kwargs=pending_kwargs, + is_final=True, + )() + # The end result should raise the exception from the initial + # pending future value in the list + with self.assertRaises(TaskFailureException): + self.transfer_coordinator.result() + + +class BaseMultipartTaskTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.bucket = 'mybucket' + self.key = 'foo' + + +class TestCreateMultipartUploadTask(BaseMultipartTaskTest): + def test_main(self): + upload_id = 'foo' + extra_args = {'Metadata': {'foo': 'bar'}} + response = {'UploadId': upload_id} + task = self.get_task( + CreateMultipartUploadTask, + main_kwargs={ + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': extra_args, + }, + ) + self.stubber.add_response( + method='create_multipart_upload', + service_response=response, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'Metadata': {'foo': 'bar'}, + }, + ) + result_id = task() + self.stubber.assert_no_pending_responses() + # Ensure the upload id returned is correct + self.assertEqual(upload_id, result_id) + + # Make sure that the abort was added as a cleanup failure + self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 1) + + # Make sure if it is called, it will abort correctly + self.stubber.add_response( + method='abort_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + }, + ) + self.transfer_coordinator.failure_cleanups[0]() + self.stubber.assert_no_pending_responses() + + +class TestCompleteMultipartUploadTask(BaseMultipartTaskTest): + def test_main(self): + upload_id = 'my-id' + parts = [{'ETag': 'etag', 'PartNumber': 0}] + task = self.get_task( + CompleteMultipartUploadTask, + main_kwargs={ + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'upload_id': upload_id, + 'parts': parts, + 'extra_args': {}, + }, + ) + self.stubber.add_response( + method='complete_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + 'MultipartUpload': {'Parts': parts}, + }, + ) + task() + self.stubber.assert_no_pending_responses() + + def test_includes_extra_args(self): + upload_id = 'my-id' + parts = [{'ETag': 'etag', 'PartNumber': 0}] + task = self.get_task( + CompleteMultipartUploadTask, + main_kwargs={ + 'client': self.client, + 'bucket': self.bucket, + 'key': self.key, + 'upload_id': upload_id, + 'parts': parts, + 'extra_args': {'RequestPayer': 'requester'}, + }, + ) + self.stubber.add_response( + method='complete_multipart_upload', + service_response={}, + expected_params={ + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + 'MultipartUpload': {'Parts': parts}, + 'RequestPayer': 'requester', + }, + ) + task() + self.stubber.assert_no_pending_responses() diff --git a/contrib/python/s3transfer/py3/tests/unit/test_upload.py b/contrib/python/s3transfer/py3/tests/unit/test_upload.py index 622255417c..1ac38b3616 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_upload.py +++ 
b/contrib/python/s3transfer/py3/tests/unit/test_upload.py @@ -1,694 +1,694 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the 'license' file accompanying this file. This file is -# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. - -import math -import os -import shutil -import tempfile -from io import BytesIO - -from botocore.stub import ANY - -from s3transfer.futures import IN_MEMORY_UPLOAD_TAG -from s3transfer.manager import TransferConfig -from s3transfer.upload import ( - AggregatedProgressCallback, - InterruptReader, - PutObjectTask, - UploadFilenameInputManager, - UploadNonSeekableInputManager, - UploadPartTask, - UploadSeekableInputManager, - UploadSubmissionTask, -) -from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils -from __tests__ import ( - BaseSubmissionTaskTest, - BaseTaskTest, - FileSizeProvider, - NonSeekableReader, - RecordingExecutor, - RecordingSubscriber, - unittest, -) - - -class InterruptionError(Exception): - pass - - -class OSUtilsExceptionOnFileSize(OSUtils): - def get_file_size(self, filename): - raise AssertionError( - "The file %s should not have been stated" % filename - ) - - -class BaseUploadTest(BaseTaskTest): - def setUp(self): - super().setUp() - self.bucket = 'mybucket' - self.key = 'foo' - self.osutil = OSUtils() - - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'myfile') - self.content = b'my content' - self.subscribers = [] - - with open(self.filename, 'wb') as f: - f.write(self.content) - - # A list to keep track of all of the bodies sent over the wire - # and their order. 
- self.sent_bodies = [] - self.client.meta.events.register( - 'before-parameter-build.s3.*', self.collect_body - ) - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.tempdir) - - def collect_body(self, params, **kwargs): - if 'Body' in params: - self.sent_bodies.append(params['Body'].read()) - - -class TestAggregatedProgressCallback(unittest.TestCase): - def setUp(self): - self.aggregated_amounts = [] - self.threshold = 3 - self.aggregated_progress_callback = AggregatedProgressCallback( - [self.callback], self.threshold - ) - - def callback(self, bytes_transferred): - self.aggregated_amounts.append(bytes_transferred) - - def test_under_threshold(self): - one_under_threshold_amount = self.threshold - 1 - self.aggregated_progress_callback(one_under_threshold_amount) - self.assertEqual(self.aggregated_amounts, []) - self.aggregated_progress_callback(1) - self.assertEqual(self.aggregated_amounts, [self.threshold]) - - def test_at_threshold(self): - self.aggregated_progress_callback(self.threshold) - self.assertEqual(self.aggregated_amounts, [self.threshold]) - - def test_over_threshold(self): - over_threshold_amount = self.threshold + 1 - self.aggregated_progress_callback(over_threshold_amount) - self.assertEqual(self.aggregated_amounts, [over_threshold_amount]) - - def test_flush(self): - under_threshold_amount = self.threshold - 1 - self.aggregated_progress_callback(under_threshold_amount) - self.assertEqual(self.aggregated_amounts, []) - self.aggregated_progress_callback.flush() - self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) - - def test_flush_with_nothing_to_flush(self): - under_threshold_amount = self.threshold - 1 - self.aggregated_progress_callback(under_threshold_amount) - self.assertEqual(self.aggregated_amounts, []) - self.aggregated_progress_callback.flush() - self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) - # Flushing again should do nothing as it was just flushed - self.aggregated_progress_callback.flush() - self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) - - -class TestInterruptReader(BaseUploadTest): - def test_read_raises_exception(self): - with open(self.filename, 'rb') as f: - reader = InterruptReader(f, self.transfer_coordinator) - # Read some bytes to show it can be read. 
- self.assertEqual(reader.read(1), self.content[0:1]) - # Then set an exception in the transfer coordinator - self.transfer_coordinator.set_exception(InterruptionError()) - # The next read should have the exception propograte - with self.assertRaises(InterruptionError): - reader.read() - - def test_seek(self): - with open(self.filename, 'rb') as f: - reader = InterruptReader(f, self.transfer_coordinator) - # Ensure it can seek correctly - reader.seek(1) - self.assertEqual(reader.read(1), self.content[1:2]) - - def test_tell(self): - with open(self.filename, 'rb') as f: - reader = InterruptReader(f, self.transfer_coordinator) - # Ensure it can tell correctly - reader.seek(1) - self.assertEqual(reader.tell(), 1) - - -class BaseUploadInputManagerTest(BaseUploadTest): - def setUp(self): - super().setUp() - self.osutil = OSUtils() - self.config = TransferConfig() - self.recording_subscriber = RecordingSubscriber() - self.subscribers.append(self.recording_subscriber) - - def _get_expected_body_for_part(self, part_number): - # A helper method for retrieving the expected body for a specific - # part number of the data - total_size = len(self.content) - chunk_size = self.config.multipart_chunksize - start_index = (part_number - 1) * chunk_size - end_index = part_number * chunk_size - if end_index >= total_size: - return self.content[start_index:] - return self.content[start_index:end_index] - - -class TestUploadFilenameInputManager(BaseUploadInputManagerTest): - def setUp(self): - super().setUp() - self.upload_input_manager = UploadFilenameInputManager( - self.osutil, self.transfer_coordinator - ) - self.call_args = CallArgs( - fileobj=self.filename, subscribers=self.subscribers - ) - self.future = self.get_transfer_future(self.call_args) - - def test_is_compatible(self): - self.assertTrue( - self.upload_input_manager.is_compatible( - self.future.meta.call_args.fileobj - ) - ) - - def test_stores_bodies_in_memory_put_object(self): - self.assertFalse( - self.upload_input_manager.stores_body_in_memory('put_object') - ) - - def test_stores_bodies_in_memory_upload_part(self): - self.assertFalse( - self.upload_input_manager.stores_body_in_memory('upload_part') - ) - - def test_provide_transfer_size(self): - self.upload_input_manager.provide_transfer_size(self.future) - # The provided file size should be equal to size of the contents of - # the file. - self.assertEqual(self.future.meta.size, len(self.content)) - - def test_requires_multipart_upload(self): - self.future.meta.provide_transfer_size(len(self.content)) - # With the default multipart threshold, the length of the content - # should be smaller than the threshold thus not requiring a multipart - # transfer. - self.assertFalse( - self.upload_input_manager.requires_multipart_upload( - self.future, self.config - ) - ) - # Decreasing the threshold to that of the length of the content of - # the file should trigger the need for a multipart upload. - self.config.multipart_threshold = len(self.content) - self.assertTrue( - self.upload_input_manager.requires_multipart_upload( - self.future, self.config - ) - ) - - def test_get_put_object_body(self): - self.future.meta.provide_transfer_size(len(self.content)) - read_file_chunk = self.upload_input_manager.get_put_object_body( - self.future - ) - read_file_chunk.enable_callback() - # The file-like object provided back should be the same as the content - # of the file. 
- with read_file_chunk: - self.assertEqual(read_file_chunk.read(), self.content) - # The file-like object should also have been wrapped with the - # on_queued callbacks to track the amount of bytes being transferred. - self.assertEqual( - self.recording_subscriber.calculate_bytes_seen(), len(self.content) - ) - - def test_get_put_object_body_is_interruptable(self): - self.future.meta.provide_transfer_size(len(self.content)) - read_file_chunk = self.upload_input_manager.get_put_object_body( - self.future - ) - - # Set an exception in the transfer coordinator - self.transfer_coordinator.set_exception(InterruptionError) - # Ensure the returned read file chunk can be interrupted with that - # error. - with self.assertRaises(InterruptionError): - read_file_chunk.read() - - def test_yield_upload_part_bodies(self): - # Adjust the chunk size to something more grainular for testing. - self.config.multipart_chunksize = 4 - self.future.meta.provide_transfer_size(len(self.content)) - - # Get an iterator that will yield all of the bodies and their - # respective part number. - part_iterator = self.upload_input_manager.yield_upload_part_bodies( - self.future, self.config.multipart_chunksize - ) - expected_part_number = 1 - for part_number, read_file_chunk in part_iterator: - # Ensure that the part number is as expected - self.assertEqual(part_number, expected_part_number) - read_file_chunk.enable_callback() - # Ensure that the body is correct for that part. - with read_file_chunk: - self.assertEqual( - read_file_chunk.read(), - self._get_expected_body_for_part(part_number), - ) - expected_part_number += 1 - - # All of the file-like object should also have been wrapped with the - # on_queued callbacks to track the amount of bytes being transferred. - self.assertEqual( - self.recording_subscriber.calculate_bytes_seen(), len(self.content) - ) - - def test_yield_upload_part_bodies_are_interruptable(self): - # Adjust the chunk size to something more grainular for testing. - self.config.multipart_chunksize = 4 - self.future.meta.provide_transfer_size(len(self.content)) - - # Get an iterator that will yield all of the bodies and their - # respective part number. - part_iterator = self.upload_input_manager.yield_upload_part_bodies( - self.future, self.config.multipart_chunksize - ) - - # Set an exception in the transfer coordinator - self.transfer_coordinator.set_exception(InterruptionError) - for _, read_file_chunk in part_iterator: - # Ensure that each read file chunk yielded can be interrupted - # with that error. 
- with self.assertRaises(InterruptionError): - read_file_chunk.read() - - -class TestUploadSeekableInputManager(TestUploadFilenameInputManager): - def setUp(self): - super().setUp() - self.upload_input_manager = UploadSeekableInputManager( - self.osutil, self.transfer_coordinator - ) - self.fileobj = open(self.filename, 'rb') - self.call_args = CallArgs( - fileobj=self.fileobj, subscribers=self.subscribers - ) - self.future = self.get_transfer_future(self.call_args) - - def tearDown(self): - self.fileobj.close() - super().tearDown() - - def test_is_compatible_bytes_io(self): - self.assertTrue(self.upload_input_manager.is_compatible(BytesIO())) - - def test_not_compatible_for_non_filelike_obj(self): - self.assertFalse(self.upload_input_manager.is_compatible(object())) - - def test_stores_bodies_in_memory_upload_part(self): - self.assertTrue( - self.upload_input_manager.stores_body_in_memory('upload_part') - ) - - def test_get_put_object_body(self): - start_pos = 3 - self.fileobj.seek(start_pos) - adjusted_size = len(self.content) - start_pos - self.future.meta.provide_transfer_size(adjusted_size) - read_file_chunk = self.upload_input_manager.get_put_object_body( - self.future - ) - - read_file_chunk.enable_callback() - # The fact that the file was seeked to start should be taken into - # account in length and content for the read file chunk. - with read_file_chunk: - self.assertEqual(len(read_file_chunk), adjusted_size) - self.assertEqual(read_file_chunk.read(), self.content[start_pos:]) - self.assertEqual( - self.recording_subscriber.calculate_bytes_seen(), adjusted_size - ) - - -class TestUploadNonSeekableInputManager(TestUploadFilenameInputManager): - def setUp(self): - super().setUp() - self.upload_input_manager = UploadNonSeekableInputManager( - self.osutil, self.transfer_coordinator - ) - self.fileobj = NonSeekableReader(self.content) - self.call_args = CallArgs( - fileobj=self.fileobj, subscribers=self.subscribers - ) - self.future = self.get_transfer_future(self.call_args) - - def assert_multipart_parts(self): - """ - Asserts that the input manager will generate a multipart upload - and that each part is in order and the correct size. - """ - # Assert that a multipart upload is required. - self.assertTrue( - self.upload_input_manager.requires_multipart_upload( - self.future, self.config - ) - ) - - # Get a list of all the parts that would be sent. - parts = list( - self.upload_input_manager.yield_upload_part_bodies( - self.future, self.config.multipart_chunksize - ) - ) - - # Assert that the actual number of parts is what we would expect - # based on the configuration. - size = self.config.multipart_chunksize - num_parts = math.ceil(len(self.content) / size) - self.assertEqual(len(parts), num_parts) - - # Run for every part but the last part. - for i, part in enumerate(parts[:-1]): - # Assert the part number is correct. - self.assertEqual(part[0], i + 1) - # Assert the part contains the right amount of data. - data = part[1].read() - self.assertEqual(len(data), size) - - # Assert that the last part is the correct size. - expected_final_size = len(self.content) - ((num_parts - 1) * size) - final_part = parts[-1] - self.assertEqual(len(final_part[1].read()), expected_final_size) - - # Assert that the last part has the correct part number. - self.assertEqual(final_part[0], len(parts)) - - def test_provide_transfer_size(self): - self.upload_input_manager.provide_transfer_size(self.future) - # There is no way to get the size without reading the entire body. 
- self.assertEqual(self.future.meta.size, None) - - def test_stores_bodies_in_memory_upload_part(self): - self.assertTrue( - self.upload_input_manager.stores_body_in_memory('upload_part') - ) - - def test_stores_bodies_in_memory_put_object(self): - self.assertTrue( - self.upload_input_manager.stores_body_in_memory('put_object') - ) - - def test_initial_data_parts_threshold_lesser(self): - # threshold < size - self.config.multipart_chunksize = 4 - self.config.multipart_threshold = 2 - self.assert_multipart_parts() - - def test_initial_data_parts_threshold_equal(self): - # threshold == size - self.config.multipart_chunksize = 4 - self.config.multipart_threshold = 4 - self.assert_multipart_parts() - - def test_initial_data_parts_threshold_greater(self): - # threshold > size - self.config.multipart_chunksize = 4 - self.config.multipart_threshold = 8 - self.assert_multipart_parts() - - -class TestUploadSubmissionTask(BaseSubmissionTaskTest): - def setUp(self): - super().setUp() - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'myfile') - self.content = b'0' * (MIN_UPLOAD_CHUNKSIZE * 3) - self.config.multipart_chunksize = MIN_UPLOAD_CHUNKSIZE - self.config.multipart_threshold = MIN_UPLOAD_CHUNKSIZE * 5 - - with open(self.filename, 'wb') as f: - f.write(self.content) - - self.bucket = 'mybucket' - self.key = 'mykey' - self.extra_args = {} - self.subscribers = [] - - # A list to keep track of all of the bodies sent over the wire - # and their order. - self.sent_bodies = [] - self.client.meta.events.register( - 'before-parameter-build.s3.*', self.collect_body - ) - - self.call_args = self.get_call_args() - self.transfer_future = self.get_transfer_future(self.call_args) - self.submission_main_kwargs = { - 'client': self.client, - 'config': self.config, - 'osutil': self.osutil, - 'request_executor': self.executor, - 'transfer_future': self.transfer_future, - } - self.submission_task = self.get_task( - UploadSubmissionTask, main_kwargs=self.submission_main_kwargs - ) - - def tearDown(self): - super().tearDown() - shutil.rmtree(self.tempdir) - - def collect_body(self, params, **kwargs): - if 'Body' in params: - self.sent_bodies.append(params['Body'].read()) - - def get_call_args(self, **kwargs): - default_call_args = { - 'fileobj': self.filename, - 'bucket': self.bucket, - 'key': self.key, - 'extra_args': self.extra_args, - 'subscribers': self.subscribers, - } - default_call_args.update(kwargs) - return CallArgs(**default_call_args) - - def add_multipart_upload_stubbed_responses(self): - self.stubber.add_response( - method='create_multipart_upload', - service_response={'UploadId': 'my-id'}, - ) - self.stubber.add_response( - method='upload_part', service_response={'ETag': 'etag-1'} - ) - self.stubber.add_response( - method='upload_part', service_response={'ETag': 'etag-2'} - ) - self.stubber.add_response( - method='upload_part', service_response={'ETag': 'etag-3'} - ) - self.stubber.add_response( - method='complete_multipart_upload', service_response={} - ) - - def wrap_executor_in_recorder(self): - self.executor = RecordingExecutor(self.executor) - self.submission_main_kwargs['request_executor'] = self.executor - - def use_fileobj_in_call_args(self, fileobj): - self.call_args = self.get_call_args(fileobj=fileobj) - self.transfer_future = self.get_transfer_future(self.call_args) - self.submission_main_kwargs['transfer_future'] = self.transfer_future - - def assert_tag_value_for_put_object(self, tag_value): - self.assertEqual(self.executor.submissions[0]['tag'], 
tag_value)
-
- def assert_tag_value_for_upload_parts(self, tag_value):
- for submission in self.executor.submissions[1:-1]:
- self.assertEqual(submission['tag'], tag_value)
-
- def test_provide_file_size_on_put(self):
- self.call_args.subscribers.append(FileSizeProvider(len(self.content)))
- self.stubber.add_response(
- method='put_object',
- service_response={},
- expected_params={
- 'Body': ANY,
- 'Bucket': self.bucket,
- 'Key': self.key,
- },
- )
-
- # With this submitter, it will fail to stat the file if a transfer
- # size is not provided.
- self.submission_main_kwargs['osutil'] = OSUtilsExceptionOnFileSize()
-
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
- self.assertEqual(self.sent_bodies, [self.content])
-
- def test_submits_no_tag_for_put_object_filename(self):
- self.wrap_executor_in_recorder()
- self.stubber.add_response('put_object', {})
-
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- # Make sure that no tag to limit that task specifically was
- # associated with that task submission.
- self.assert_tag_value_for_put_object(None)
-
- def test_submits_no_tag_for_multipart_filename(self):
- self.wrap_executor_in_recorder()
-
- # Set up for a multipart upload.
- self.add_multipart_upload_stubbed_responses()
- self.config.multipart_threshold = 1
-
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- # Make sure that no tags to limit any of the upload part tasks
- # were associated when they were submitted to the executor.
- self.assert_tag_value_for_upload_parts(None)
-
- def test_submits_no_tag_for_put_object_fileobj(self):
- self.wrap_executor_in_recorder()
- self.stubber.add_response('put_object', {})
-
- with open(self.filename, 'rb') as f:
- self.use_fileobj_in_call_args(f)
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- # Make sure that no tag to limit that task specifically was
- # associated with that task submission.
- self.assert_tag_value_for_put_object(None)
-
- def test_submits_tag_for_multipart_fileobj(self):
- self.wrap_executor_in_recorder()
-
- # Set up for a multipart upload.
- self.add_multipart_upload_stubbed_responses()
- self.config.multipart_threshold = 1
-
- with open(self.filename, 'rb') as f:
- self.use_fileobj_in_call_args(f)
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- # Make sure that tags to limit all of the upload part tasks were
- # associated when they were submitted to the executor, as these
- # tasks will have chunks of data stored with them in memory.
- self.assert_tag_value_for_upload_parts(IN_MEMORY_UPLOAD_TAG) - - -class TestPutObjectTask(BaseUploadTest): - def test_main(self): - extra_args = {'Metadata': {'foo': 'bar'}} - with open(self.filename, 'rb') as fileobj: - task = self.get_task( - PutObjectTask, - main_kwargs={ - 'client': self.client, - 'fileobj': fileobj, - 'bucket': self.bucket, - 'key': self.key, - 'extra_args': extra_args, - }, - ) - self.stubber.add_response( - method='put_object', - service_response={}, - expected_params={ - 'Body': ANY, - 'Bucket': self.bucket, - 'Key': self.key, - 'Metadata': {'foo': 'bar'}, - }, - ) - task() - self.stubber.assert_no_pending_responses() - self.assertEqual(self.sent_bodies, [self.content]) - - -class TestUploadPartTask(BaseUploadTest): - def test_main(self): - extra_args = {'RequestPayer': 'requester'} - upload_id = 'my-id' - part_number = 1 - etag = 'foo' - with open(self.filename, 'rb') as fileobj: - task = self.get_task( - UploadPartTask, - main_kwargs={ - 'client': self.client, - 'fileobj': fileobj, - 'bucket': self.bucket, - 'key': self.key, - 'upload_id': upload_id, - 'part_number': part_number, - 'extra_args': extra_args, - }, - ) - self.stubber.add_response( - method='upload_part', - service_response={'ETag': etag}, - expected_params={ - 'Body': ANY, - 'Bucket': self.bucket, - 'Key': self.key, - 'UploadId': upload_id, - 'PartNumber': part_number, - 'RequestPayer': 'requester', - }, - ) - rval = task() - self.stubber.assert_no_pending_responses() - self.assertEqual(rval, {'ETag': etag, 'PartNumber': part_number}) - self.assertEqual(self.sent_bodies, [self.content]) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import math +import os +import shutil +import tempfile +from io import BytesIO + +from botocore.stub import ANY + +from s3transfer.futures import IN_MEMORY_UPLOAD_TAG +from s3transfer.manager import TransferConfig +from s3transfer.upload import ( + AggregatedProgressCallback, + InterruptReader, + PutObjectTask, + UploadFilenameInputManager, + UploadNonSeekableInputManager, + UploadPartTask, + UploadSeekableInputManager, + UploadSubmissionTask, +) +from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils +from __tests__ import ( + BaseSubmissionTaskTest, + BaseTaskTest, + FileSizeProvider, + NonSeekableReader, + RecordingExecutor, + RecordingSubscriber, + unittest, +) + + +class InterruptionError(Exception): + pass + + +class OSUtilsExceptionOnFileSize(OSUtils): + def get_file_size(self, filename): + raise AssertionError( + "The file %s should not have been stated" % filename + ) + + +class BaseUploadTest(BaseTaskTest): + def setUp(self): + super().setUp() + self.bucket = 'mybucket' + self.key = 'foo' + self.osutil = OSUtils() + + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + self.content = b'my content' + self.subscribers = [] + + with open(self.filename, 'wb') as f: + f.write(self.content) + + # A list to keep track of all of the bodies sent over the wire + # and their order. 
+ self.sent_bodies = [] + self.client.meta.events.register( + 'before-parameter-build.s3.*', self.collect_body + ) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + + def collect_body(self, params, **kwargs): + if 'Body' in params: + self.sent_bodies.append(params['Body'].read()) + + +class TestAggregatedProgressCallback(unittest.TestCase): + def setUp(self): + self.aggregated_amounts = [] + self.threshold = 3 + self.aggregated_progress_callback = AggregatedProgressCallback( + [self.callback], self.threshold + ) + + def callback(self, bytes_transferred): + self.aggregated_amounts.append(bytes_transferred) + + def test_under_threshold(self): + one_under_threshold_amount = self.threshold - 1 + self.aggregated_progress_callback(one_under_threshold_amount) + self.assertEqual(self.aggregated_amounts, []) + self.aggregated_progress_callback(1) + self.assertEqual(self.aggregated_amounts, [self.threshold]) + + def test_at_threshold(self): + self.aggregated_progress_callback(self.threshold) + self.assertEqual(self.aggregated_amounts, [self.threshold]) + + def test_over_threshold(self): + over_threshold_amount = self.threshold + 1 + self.aggregated_progress_callback(over_threshold_amount) + self.assertEqual(self.aggregated_amounts, [over_threshold_amount]) + + def test_flush(self): + under_threshold_amount = self.threshold - 1 + self.aggregated_progress_callback(under_threshold_amount) + self.assertEqual(self.aggregated_amounts, []) + self.aggregated_progress_callback.flush() + self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) + + def test_flush_with_nothing_to_flush(self): + under_threshold_amount = self.threshold - 1 + self.aggregated_progress_callback(under_threshold_amount) + self.assertEqual(self.aggregated_amounts, []) + self.aggregated_progress_callback.flush() + self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) + # Flushing again should do nothing as it was just flushed + self.aggregated_progress_callback.flush() + self.assertEqual(self.aggregated_amounts, [under_threshold_amount]) + + +class TestInterruptReader(BaseUploadTest): + def test_read_raises_exception(self): + with open(self.filename, 'rb') as f: + reader = InterruptReader(f, self.transfer_coordinator) + # Read some bytes to show it can be read. 
+ self.assertEqual(reader.read(1), self.content[0:1]) + # Then set an exception in the transfer coordinator + self.transfer_coordinator.set_exception(InterruptionError()) + # The next read should have the exception propograte + with self.assertRaises(InterruptionError): + reader.read() + + def test_seek(self): + with open(self.filename, 'rb') as f: + reader = InterruptReader(f, self.transfer_coordinator) + # Ensure it can seek correctly + reader.seek(1) + self.assertEqual(reader.read(1), self.content[1:2]) + + def test_tell(self): + with open(self.filename, 'rb') as f: + reader = InterruptReader(f, self.transfer_coordinator) + # Ensure it can tell correctly + reader.seek(1) + self.assertEqual(reader.tell(), 1) + + +class BaseUploadInputManagerTest(BaseUploadTest): + def setUp(self): + super().setUp() + self.osutil = OSUtils() + self.config = TransferConfig() + self.recording_subscriber = RecordingSubscriber() + self.subscribers.append(self.recording_subscriber) + + def _get_expected_body_for_part(self, part_number): + # A helper method for retrieving the expected body for a specific + # part number of the data + total_size = len(self.content) + chunk_size = self.config.multipart_chunksize + start_index = (part_number - 1) * chunk_size + end_index = part_number * chunk_size + if end_index >= total_size: + return self.content[start_index:] + return self.content[start_index:end_index] + + +class TestUploadFilenameInputManager(BaseUploadInputManagerTest): + def setUp(self): + super().setUp() + self.upload_input_manager = UploadFilenameInputManager( + self.osutil, self.transfer_coordinator + ) + self.call_args = CallArgs( + fileobj=self.filename, subscribers=self.subscribers + ) + self.future = self.get_transfer_future(self.call_args) + + def test_is_compatible(self): + self.assertTrue( + self.upload_input_manager.is_compatible( + self.future.meta.call_args.fileobj + ) + ) + + def test_stores_bodies_in_memory_put_object(self): + self.assertFalse( + self.upload_input_manager.stores_body_in_memory('put_object') + ) + + def test_stores_bodies_in_memory_upload_part(self): + self.assertFalse( + self.upload_input_manager.stores_body_in_memory('upload_part') + ) + + def test_provide_transfer_size(self): + self.upload_input_manager.provide_transfer_size(self.future) + # The provided file size should be equal to size of the contents of + # the file. + self.assertEqual(self.future.meta.size, len(self.content)) + + def test_requires_multipart_upload(self): + self.future.meta.provide_transfer_size(len(self.content)) + # With the default multipart threshold, the length of the content + # should be smaller than the threshold thus not requiring a multipart + # transfer. + self.assertFalse( + self.upload_input_manager.requires_multipart_upload( + self.future, self.config + ) + ) + # Decreasing the threshold to that of the length of the content of + # the file should trigger the need for a multipart upload. + self.config.multipart_threshold = len(self.content) + self.assertTrue( + self.upload_input_manager.requires_multipart_upload( + self.future, self.config + ) + ) + + def test_get_put_object_body(self): + self.future.meta.provide_transfer_size(len(self.content)) + read_file_chunk = self.upload_input_manager.get_put_object_body( + self.future + ) + read_file_chunk.enable_callback() + # The file-like object provided back should be the same as the content + # of the file. 
+ with read_file_chunk:
+ self.assertEqual(read_file_chunk.read(), self.content)
+ # The file-like object should also have been wrapped with the
+ # on_queued callbacks to track the amount of bytes being transferred.
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ def test_get_put_object_body_is_interruptable(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+
+ # Set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError)
+ # Ensure the returned read file chunk can be interrupted with that
+ # error.
+ with self.assertRaises(InterruptionError):
+ read_file_chunk.read()
+
+ def test_yield_upload_part_bodies(self):
+ # Adjust the chunk size to something more granular for testing.
+ self.config.multipart_chunksize = 4
+ self.future.meta.provide_transfer_size(len(self.content))
+
+ # Get an iterator that will yield all of the bodies and their
+ # respective part number.
+ part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+ expected_part_number = 1
+ for part_number, read_file_chunk in part_iterator:
+ # Ensure that the part number is as expected
+ self.assertEqual(part_number, expected_part_number)
+ read_file_chunk.enable_callback()
+ # Ensure that the body is correct for that part.
+ with read_file_chunk:
+ self.assertEqual(
+ read_file_chunk.read(),
+ self._get_expected_body_for_part(part_number),
+ )
+ expected_part_number += 1
+
+ # All of the file-like objects should also have been wrapped with the
+ # on_queued callbacks to track the amount of bytes being transferred.
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ def test_yield_upload_part_bodies_are_interruptable(self):
+ # Adjust the chunk size to something more granular for testing.
+ self.config.multipart_chunksize = 4
+ self.future.meta.provide_transfer_size(len(self.content))
+
+ # Get an iterator that will yield all of the bodies and their
+ # respective part number.
+ part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+
+ # Set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError)
+ for _, read_file_chunk in part_iterator:
+ # Ensure that each read file chunk yielded can be interrupted
+ # with that error.
+ with self.assertRaises(InterruptionError): + read_file_chunk.read() + + +class TestUploadSeekableInputManager(TestUploadFilenameInputManager): + def setUp(self): + super().setUp() + self.upload_input_manager = UploadSeekableInputManager( + self.osutil, self.transfer_coordinator + ) + self.fileobj = open(self.filename, 'rb') + self.call_args = CallArgs( + fileobj=self.fileobj, subscribers=self.subscribers + ) + self.future = self.get_transfer_future(self.call_args) + + def tearDown(self): + self.fileobj.close() + super().tearDown() + + def test_is_compatible_bytes_io(self): + self.assertTrue(self.upload_input_manager.is_compatible(BytesIO())) + + def test_not_compatible_for_non_filelike_obj(self): + self.assertFalse(self.upload_input_manager.is_compatible(object())) + + def test_stores_bodies_in_memory_upload_part(self): + self.assertTrue( + self.upload_input_manager.stores_body_in_memory('upload_part') + ) + + def test_get_put_object_body(self): + start_pos = 3 + self.fileobj.seek(start_pos) + adjusted_size = len(self.content) - start_pos + self.future.meta.provide_transfer_size(adjusted_size) + read_file_chunk = self.upload_input_manager.get_put_object_body( + self.future + ) + + read_file_chunk.enable_callback() + # The fact that the file was seeked to start should be taken into + # account in length and content for the read file chunk. + with read_file_chunk: + self.assertEqual(len(read_file_chunk), adjusted_size) + self.assertEqual(read_file_chunk.read(), self.content[start_pos:]) + self.assertEqual( + self.recording_subscriber.calculate_bytes_seen(), adjusted_size + ) + + +class TestUploadNonSeekableInputManager(TestUploadFilenameInputManager): + def setUp(self): + super().setUp() + self.upload_input_manager = UploadNonSeekableInputManager( + self.osutil, self.transfer_coordinator + ) + self.fileobj = NonSeekableReader(self.content) + self.call_args = CallArgs( + fileobj=self.fileobj, subscribers=self.subscribers + ) + self.future = self.get_transfer_future(self.call_args) + + def assert_multipart_parts(self): + """ + Asserts that the input manager will generate a multipart upload + and that each part is in order and the correct size. + """ + # Assert that a multipart upload is required. + self.assertTrue( + self.upload_input_manager.requires_multipart_upload( + self.future, self.config + ) + ) + + # Get a list of all the parts that would be sent. + parts = list( + self.upload_input_manager.yield_upload_part_bodies( + self.future, self.config.multipart_chunksize + ) + ) + + # Assert that the actual number of parts is what we would expect + # based on the configuration. + size = self.config.multipart_chunksize + num_parts = math.ceil(len(self.content) / size) + self.assertEqual(len(parts), num_parts) + + # Run for every part but the last part. + for i, part in enumerate(parts[:-1]): + # Assert the part number is correct. + self.assertEqual(part[0], i + 1) + # Assert the part contains the right amount of data. + data = part[1].read() + self.assertEqual(len(data), size) + + # Assert that the last part is the correct size. + expected_final_size = len(self.content) - ((num_parts - 1) * size) + final_part = parts[-1] + self.assertEqual(len(final_part[1].read()), expected_final_size) + + # Assert that the last part has the correct part number. + self.assertEqual(final_part[0], len(parts)) + + def test_provide_transfer_size(self): + self.upload_input_manager.provide_transfer_size(self.future) + # There is no way to get the size without reading the entire body. 
+ self.assertEqual(self.future.meta.size, None) + + def test_stores_bodies_in_memory_upload_part(self): + self.assertTrue( + self.upload_input_manager.stores_body_in_memory('upload_part') + ) + + def test_stores_bodies_in_memory_put_object(self): + self.assertTrue( + self.upload_input_manager.stores_body_in_memory('put_object') + ) + + def test_initial_data_parts_threshold_lesser(self): + # threshold < size + self.config.multipart_chunksize = 4 + self.config.multipart_threshold = 2 + self.assert_multipart_parts() + + def test_initial_data_parts_threshold_equal(self): + # threshold == size + self.config.multipart_chunksize = 4 + self.config.multipart_threshold = 4 + self.assert_multipart_parts() + + def test_initial_data_parts_threshold_greater(self): + # threshold > size + self.config.multipart_chunksize = 4 + self.config.multipart_threshold = 8 + self.assert_multipart_parts() + + +class TestUploadSubmissionTask(BaseSubmissionTaskTest): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'myfile') + self.content = b'0' * (MIN_UPLOAD_CHUNKSIZE * 3) + self.config.multipart_chunksize = MIN_UPLOAD_CHUNKSIZE + self.config.multipart_threshold = MIN_UPLOAD_CHUNKSIZE * 5 + + with open(self.filename, 'wb') as f: + f.write(self.content) + + self.bucket = 'mybucket' + self.key = 'mykey' + self.extra_args = {} + self.subscribers = [] + + # A list to keep track of all of the bodies sent over the wire + # and their order. + self.sent_bodies = [] + self.client.meta.events.register( + 'before-parameter-build.s3.*', self.collect_body + ) + + self.call_args = self.get_call_args() + self.transfer_future = self.get_transfer_future(self.call_args) + self.submission_main_kwargs = { + 'client': self.client, + 'config': self.config, + 'osutil': self.osutil, + 'request_executor': self.executor, + 'transfer_future': self.transfer_future, + } + self.submission_task = self.get_task( + UploadSubmissionTask, main_kwargs=self.submission_main_kwargs + ) + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tempdir) + + def collect_body(self, params, **kwargs): + if 'Body' in params: + self.sent_bodies.append(params['Body'].read()) + + def get_call_args(self, **kwargs): + default_call_args = { + 'fileobj': self.filename, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': self.extra_args, + 'subscribers': self.subscribers, + } + default_call_args.update(kwargs) + return CallArgs(**default_call_args) + + def add_multipart_upload_stubbed_responses(self): + self.stubber.add_response( + method='create_multipart_upload', + service_response={'UploadId': 'my-id'}, + ) + self.stubber.add_response( + method='upload_part', service_response={'ETag': 'etag-1'} + ) + self.stubber.add_response( + method='upload_part', service_response={'ETag': 'etag-2'} + ) + self.stubber.add_response( + method='upload_part', service_response={'ETag': 'etag-3'} + ) + self.stubber.add_response( + method='complete_multipart_upload', service_response={} + ) + + def wrap_executor_in_recorder(self): + self.executor = RecordingExecutor(self.executor) + self.submission_main_kwargs['request_executor'] = self.executor + + def use_fileobj_in_call_args(self, fileobj): + self.call_args = self.get_call_args(fileobj=fileobj) + self.transfer_future = self.get_transfer_future(self.call_args) + self.submission_main_kwargs['transfer_future'] = self.transfer_future + + def assert_tag_value_for_put_object(self, tag_value): + self.assertEqual(self.executor.submissions[0]['tag'], 
tag_value)
+
+ def assert_tag_value_for_upload_parts(self, tag_value):
+ for submission in self.executor.submissions[1:-1]:
+ self.assertEqual(submission['tag'], tag_value)
+
+ def test_provide_file_size_on_put(self):
+ self.call_args.subscribers.append(FileSizeProvider(len(self.content)))
+ self.stubber.add_response(
+ method='put_object',
+ service_response={},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ },
+ )
+
+ # With this submitter, it will fail to stat the file if a transfer
+ # size is not provided.
+ self.submission_main_kwargs['osutil'] = OSUtilsExceptionOnFileSize()
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(self.sent_bodies, [self.content])
+
+ def test_submits_no_tag_for_put_object_filename(self):
+ self.wrap_executor_in_recorder()
+ self.stubber.add_response('put_object', {})
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure that no tag to limit that task specifically was
+ # associated with that task submission.
+ self.assert_tag_value_for_put_object(None)
+
+ def test_submits_no_tag_for_multipart_filename(self):
+ self.wrap_executor_in_recorder()
+
+ # Set up for a multipart upload.
+ self.add_multipart_upload_stubbed_responses()
+ self.config.multipart_threshold = 1
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure that no tags to limit any of the upload part tasks
+ # were associated when they were submitted to the executor.
+ self.assert_tag_value_for_upload_parts(None)
+
+ def test_submits_no_tag_for_put_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.stubber.add_response('put_object', {})
+
+ with open(self.filename, 'rb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure that no tag to limit that task specifically was
+ # associated with that task submission.
+ self.assert_tag_value_for_put_object(None)
+
+ def test_submits_tag_for_multipart_fileobj(self):
+ self.wrap_executor_in_recorder()
+
+ # Set up for a multipart upload.
+ self.add_multipart_upload_stubbed_responses()
+ self.config.multipart_threshold = 1
+
+ with open(self.filename, 'rb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure that tags to limit all of the upload part tasks were
+ # associated when they were submitted to the executor, as these
+ # tasks will have chunks of data stored with them in memory.
+ self.assert_tag_value_for_upload_parts(IN_MEMORY_UPLOAD_TAG) + + +class TestPutObjectTask(BaseUploadTest): + def test_main(self): + extra_args = {'Metadata': {'foo': 'bar'}} + with open(self.filename, 'rb') as fileobj: + task = self.get_task( + PutObjectTask, + main_kwargs={ + 'client': self.client, + 'fileobj': fileobj, + 'bucket': self.bucket, + 'key': self.key, + 'extra_args': extra_args, + }, + ) + self.stubber.add_response( + method='put_object', + service_response={}, + expected_params={ + 'Body': ANY, + 'Bucket': self.bucket, + 'Key': self.key, + 'Metadata': {'foo': 'bar'}, + }, + ) + task() + self.stubber.assert_no_pending_responses() + self.assertEqual(self.sent_bodies, [self.content]) + + +class TestUploadPartTask(BaseUploadTest): + def test_main(self): + extra_args = {'RequestPayer': 'requester'} + upload_id = 'my-id' + part_number = 1 + etag = 'foo' + with open(self.filename, 'rb') as fileobj: + task = self.get_task( + UploadPartTask, + main_kwargs={ + 'client': self.client, + 'fileobj': fileobj, + 'bucket': self.bucket, + 'key': self.key, + 'upload_id': upload_id, + 'part_number': part_number, + 'extra_args': extra_args, + }, + ) + self.stubber.add_response( + method='upload_part', + service_response={'ETag': etag}, + expected_params={ + 'Body': ANY, + 'Bucket': self.bucket, + 'Key': self.key, + 'UploadId': upload_id, + 'PartNumber': part_number, + 'RequestPayer': 'requester', + }, + ) + rval = task() + self.stubber.assert_no_pending_responses() + self.assertEqual(rval, {'ETag': etag, 'PartNumber': part_number}) + self.assertEqual(self.sent_bodies, [self.content]) diff --git a/contrib/python/s3transfer/py3/tests/unit/test_utils.py b/contrib/python/s3transfer/py3/tests/unit/test_utils.py index b7a51c7845..a1ff904e7a 100644 --- a/contrib/python/s3transfer/py3/tests/unit/test_utils.py +++ b/contrib/python/s3transfer/py3/tests/unit/test_utils.py @@ -1,1189 +1,1189 @@ -# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. 
-import io -import os.path -import random -import re -import shutil -import tempfile -import threading -import time -from io import BytesIO, StringIO - -from s3transfer.futures import TransferFuture, TransferMeta -from s3transfer.utils import ( - MAX_PARTS, - MAX_SINGLE_UPLOAD_SIZE, - MIN_UPLOAD_CHUNKSIZE, - CallArgs, - ChunksizeAdjuster, - CountCallbackInvoker, - DeferredOpenFile, - FunctionContainer, - NoResourcesAvailable, - OSUtils, - ReadFileChunk, - SlidingWindowSemaphore, - StreamReaderProgress, - TaskSemaphore, - calculate_num_parts, - calculate_range_parameter, - get_callbacks, - get_filtered_dict, - invoke_progress_callbacks, - random_file_extension, -) -from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest - - -class TestGetCallbacks(unittest.TestCase): - def setUp(self): - self.subscriber = RecordingSubscriber() - self.second_subscriber = RecordingSubscriber() - self.call_args = CallArgs( - subscribers=[self.subscriber, self.second_subscriber] - ) - self.transfer_meta = TransferMeta(self.call_args) - self.transfer_future = TransferFuture(self.transfer_meta) - - def test_get_callbacks(self): - callbacks = get_callbacks(self.transfer_future, 'queued') - # Make sure two callbacks were added as both subscribers had - # an on_queued method. - self.assertEqual(len(callbacks), 2) - - # Ensure that the callback was injected with the future by calling - # one of them and checking that the future was used in the call. - callbacks[0]() - self.assertEqual( - self.subscriber.on_queued_calls, [{'future': self.transfer_future}] - ) - - def test_get_callbacks_for_missing_type(self): - callbacks = get_callbacks(self.transfer_future, 'fake_state') - # There should be no callbacks as the subscribers will not have the - # on_fake_state method - self.assertEqual(len(callbacks), 0) - - -class TestGetFilteredDict(unittest.TestCase): - def test_get_filtered_dict(self): - original = {'Include': 'IncludeValue', 'NotInlude': 'NotIncludeValue'} - whitelist = ['Include'] - self.assertEqual( - get_filtered_dict(original, whitelist), {'Include': 'IncludeValue'} - ) - - -class TestCallArgs(unittest.TestCase): - def test_call_args(self): - call_args = CallArgs(foo='bar', biz='baz') - self.assertEqual(call_args.foo, 'bar') - self.assertEqual(call_args.biz, 'baz') - - -class TestFunctionContainer(unittest.TestCase): - def get_args_kwargs(self, *args, **kwargs): - return args, kwargs - - def test_call(self): - func_container = FunctionContainer( - self.get_args_kwargs, 'foo', bar='baz' - ) - self.assertEqual(func_container(), (('foo',), {'bar': 'baz'})) - - def test_repr(self): - func_container = FunctionContainer( - self.get_args_kwargs, 'foo', bar='baz' - ) - self.assertEqual( - str(func_container), - 'Function: {} with args {} and kwargs {}'.format( - self.get_args_kwargs, ('foo',), {'bar': 'baz'} - ), - ) - - -class TestCountCallbackInvoker(unittest.TestCase): - def invoke_callback(self): - self.ref_results.append('callback invoked') - - def assert_callback_invoked(self): - self.assertEqual(self.ref_results, ['callback invoked']) - - def assert_callback_not_invoked(self): - self.assertEqual(self.ref_results, []) - - def setUp(self): - self.ref_results = [] - self.invoker = CountCallbackInvoker(self.invoke_callback) - - def test_increment(self): - self.invoker.increment() - self.assertEqual(self.invoker.current_count, 1) - - def test_decrement(self): - self.invoker.increment() - self.invoker.increment() - self.invoker.decrement() - self.assertEqual(self.invoker.current_count, 1) - - def 
test_count_cannot_go_below_zero(self): - with self.assertRaises(RuntimeError): - self.invoker.decrement() - - def test_callback_invoked_only_once_finalized(self): - self.invoker.increment() - self.invoker.decrement() - self.assert_callback_not_invoked() - self.invoker.finalize() - # Callback should only be invoked once finalized - self.assert_callback_invoked() - - def test_callback_invoked_after_finalizing_and_count_reaching_zero(self): - self.invoker.increment() - self.invoker.finalize() - # Make sure that it does not get invoked immediately after - # finalizing as the count is currently one - self.assert_callback_not_invoked() - self.invoker.decrement() - self.assert_callback_invoked() - - def test_cannot_increment_after_finalization(self): - self.invoker.finalize() - with self.assertRaises(RuntimeError): - self.invoker.increment() - - -class TestRandomFileExtension(unittest.TestCase): - def test_has_proper_length(self): - self.assertEqual(len(random_file_extension(num_digits=4)), 4) - - -class TestInvokeProgressCallbacks(unittest.TestCase): - def test_invoke_progress_callbacks(self): - recording_subscriber = RecordingSubscriber() - invoke_progress_callbacks([recording_subscriber.on_progress], 2) - self.assertEqual(recording_subscriber.calculate_bytes_seen(), 2) - - def test_invoke_progress_callbacks_with_no_progress(self): - recording_subscriber = RecordingSubscriber() - invoke_progress_callbacks([recording_subscriber.on_progress], 0) - self.assertEqual(len(recording_subscriber.on_progress_calls), 0) - - -class TestCalculateNumParts(unittest.TestCase): - def test_calculate_num_parts_divisible(self): - self.assertEqual(calculate_num_parts(size=4, part_size=2), 2) - - def test_calculate_num_parts_not_divisible(self): - self.assertEqual(calculate_num_parts(size=3, part_size=2), 2) - - -class TestCalculateRangeParameter(unittest.TestCase): - def setUp(self): - self.part_size = 5 - self.part_index = 1 - self.num_parts = 3 - - def test_calculate_range_paramter(self): - range_val = calculate_range_parameter( - self.part_size, self.part_index, self.num_parts - ) - self.assertEqual(range_val, 'bytes=5-9') - - def test_last_part_with_no_total_size(self): - range_val = calculate_range_parameter( - self.part_size, self.part_index, num_parts=2 - ) - self.assertEqual(range_val, 'bytes=5-') - - def test_last_part_with_total_size(self): - range_val = calculate_range_parameter( - self.part_size, self.part_index, num_parts=2, total_size=8 - ) - self.assertEqual(range_val, 'bytes=5-7') - - -class BaseUtilsTest(unittest.TestCase): - def setUp(self): - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, 'foo') - self.content = b'abc' - with open(self.filename, 'wb') as f: - f.write(self.content) - self.amounts_seen = [] - self.num_close_callback_calls = 0 - - def tearDown(self): - shutil.rmtree(self.tempdir) - - def callback(self, bytes_transferred): - self.amounts_seen.append(bytes_transferred) - - def close_callback(self): - self.num_close_callback_calls += 1 - - -class TestOSUtils(BaseUtilsTest): - def test_get_file_size(self): - self.assertEqual( - OSUtils().get_file_size(self.filename), len(self.content) - ) - - def test_open_file_chunk_reader(self): - reader = OSUtils().open_file_chunk_reader( - self.filename, 0, 3, [self.callback] - ) - - # The returned reader should be a ReadFileChunk. - self.assertIsInstance(reader, ReadFileChunk) - # The content of the reader should be correct. 
- self.assertEqual(reader.read(), self.content) - # Callbacks should be disabled depspite being passed in. - self.assertEqual(self.amounts_seen, []) - - def test_open_file_chunk_reader_from_fileobj(self): - with open(self.filename, 'rb') as f: - reader = OSUtils().open_file_chunk_reader_from_fileobj( - f, len(self.content), len(self.content), [self.callback] - ) - - # The returned reader should be a ReadFileChunk. - self.assertIsInstance(reader, ReadFileChunk) - # The content of the reader should be correct. - self.assertEqual(reader.read(), self.content) - reader.close() - # Callbacks should be disabled depspite being passed in. - self.assertEqual(self.amounts_seen, []) - self.assertEqual(self.num_close_callback_calls, 0) - - def test_open_file(self): - fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w') - self.assertTrue(hasattr(fileobj, 'write')) - - def test_remove_file_ignores_errors(self): - non_existent_file = os.path.join(self.tempdir, 'no-exist') - # This should not exist to start. - self.assertFalse(os.path.exists(non_existent_file)) - try: - OSUtils().remove_file(non_existent_file) - except OSError as e: - self.fail('OSError should have been caught: %s' % e) - - def test_remove_file_proxies_remove_file(self): - OSUtils().remove_file(self.filename) - self.assertFalse(os.path.exists(self.filename)) - - def test_rename_file(self): - new_filename = os.path.join(self.tempdir, 'newfoo') - OSUtils().rename_file(self.filename, new_filename) - self.assertFalse(os.path.exists(self.filename)) - self.assertTrue(os.path.exists(new_filename)) - - def test_is_special_file_for_normal_file(self): - self.assertFalse(OSUtils().is_special_file(self.filename)) - - def test_is_special_file_for_non_existant_file(self): - non_existant_filename = os.path.join(self.tempdir, 'no-exist') - self.assertFalse(os.path.exists(non_existant_filename)) - self.assertFalse(OSUtils().is_special_file(non_existant_filename)) - - def test_get_temp_filename(self): - filename = 'myfile' - self.assertIsNotNone( - re.match( - r'%s\.[0-9A-Fa-f]{8}$' % filename, - OSUtils().get_temp_filename(filename), - ) - ) - - def test_get_temp_filename_len_255(self): - filename = 'a' * 255 - temp_filename = OSUtils().get_temp_filename(filename) - self.assertLessEqual(len(temp_filename), 255) - - def test_get_temp_filename_len_gt_255(self): - filename = 'a' * 280 - temp_filename = OSUtils().get_temp_filename(filename) - self.assertLessEqual(len(temp_filename), 255) - - def test_allocate(self): - truncate_size = 1 - OSUtils().allocate(self.filename, truncate_size) - with open(self.filename, 'rb') as f: - self.assertEqual(len(f.read()), truncate_size) - - @mock.patch('s3transfer.utils.fallocate') - def test_allocate_with_io_error(self, mock_fallocate): - mock_fallocate.side_effect = IOError() - with self.assertRaises(IOError): - OSUtils().allocate(self.filename, 1) - self.assertFalse(os.path.exists(self.filename)) - - @mock.patch('s3transfer.utils.fallocate') - def test_allocate_with_os_error(self, mock_fallocate): - mock_fallocate.side_effect = OSError() - with self.assertRaises(OSError): - OSUtils().allocate(self.filename, 1) - self.assertFalse(os.path.exists(self.filename)) - - -class TestDeferredOpenFile(BaseUtilsTest): - def setUp(self): - super().setUp() - self.filename = os.path.join(self.tempdir, 'foo') - self.contents = b'my contents' - with open(self.filename, 'wb') as f: - f.write(self.contents) - self.deferred_open_file = DeferredOpenFile( - self.filename, open_function=self.recording_open_function - ) - 
self.open_call_args = [] - - def tearDown(self): - self.deferred_open_file.close() - super().tearDown() - - def recording_open_function(self, filename, mode): - self.open_call_args.append((filename, mode)) - return open(filename, mode) - - def open_nonseekable(self, filename, mode): - self.open_call_args.append((filename, mode)) - return NonSeekableWriter(BytesIO(self.content)) - - def test_instantiation_does_not_open_file(self): - DeferredOpenFile( - self.filename, open_function=self.recording_open_function - ) - self.assertEqual(len(self.open_call_args), 0) - - def test_name(self): - self.assertEqual(self.deferred_open_file.name, self.filename) - - def test_read(self): - content = self.deferred_open_file.read(2) - self.assertEqual(content, self.contents[0:2]) - content = self.deferred_open_file.read(2) - self.assertEqual(content, self.contents[2:4]) - self.assertEqual(len(self.open_call_args), 1) - - def test_write(self): - self.deferred_open_file = DeferredOpenFile( - self.filename, - mode='wb', - open_function=self.recording_open_function, - ) - - write_content = b'foo' - self.deferred_open_file.write(write_content) - self.deferred_open_file.write(write_content) - self.deferred_open_file.close() - # Both of the writes should now be in the file. - with open(self.filename, 'rb') as f: - self.assertEqual(f.read(), write_content * 2) - # Open should have only been called once. - self.assertEqual(len(self.open_call_args), 1) - - def test_seek(self): - self.deferred_open_file.seek(2) - content = self.deferred_open_file.read(2) - self.assertEqual(content, self.contents[2:4]) - self.assertEqual(len(self.open_call_args), 1) - - def test_open_does_not_seek_with_zero_start_byte(self): - self.deferred_open_file = DeferredOpenFile( - self.filename, - mode='wb', - start_byte=0, - open_function=self.open_nonseekable, - ) - - try: - # If this seeks, an UnsupportedOperation error will be raised. - self.deferred_open_file.write(b'data') - except io.UnsupportedOperation: - self.fail('DeferredOpenFile seeked upon opening') - - def test_open_seeks_with_nonzero_start_byte(self): - self.deferred_open_file = DeferredOpenFile( - self.filename, - mode='wb', - start_byte=5, - open_function=self.open_nonseekable, - ) - - # Since a non-seekable file is being opened, calling Seek will raise - # an UnsupportedOperation error. - with self.assertRaises(io.UnsupportedOperation): - self.deferred_open_file.write(b'data') - - def test_tell(self): - self.deferred_open_file.tell() - # tell() should not have opened the file if it has not been seeked - # or read because we know the start bytes upfront. 
- self.assertEqual(len(self.open_call_args), 0) - - self.deferred_open_file.seek(2) - self.assertEqual(self.deferred_open_file.tell(), 2) - self.assertEqual(len(self.open_call_args), 1) - - def test_open_args(self): - self.deferred_open_file = DeferredOpenFile( - self.filename, - mode='ab+', - open_function=self.recording_open_function, - ) - # Force an open - self.deferred_open_file.write(b'data') - self.assertEqual(len(self.open_call_args), 1) - self.assertEqual(self.open_call_args[0], (self.filename, 'ab+')) - - def test_context_handler(self): - with self.deferred_open_file: - self.assertEqual(len(self.open_call_args), 1) - - -class TestReadFileChunk(BaseUtilsTest): - def test_read_entire_chunk(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=0, chunk_size=3 - ) - self.assertEqual(chunk.read(), b'one') - self.assertEqual(chunk.read(), b'') - - def test_read_with_amount_size(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=11, chunk_size=4 - ) - self.assertEqual(chunk.read(1), b'f') - self.assertEqual(chunk.read(1), b'o') - self.assertEqual(chunk.read(1), b'u') - self.assertEqual(chunk.read(1), b'r') - self.assertEqual(chunk.read(1), b'') - - def test_reset_stream_emulation(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=11, chunk_size=4 - ) - self.assertEqual(chunk.read(), b'four') - chunk.seek(0) - self.assertEqual(chunk.read(), b'four') - - def test_read_past_end_of_file(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=36, chunk_size=100000 - ) - self.assertEqual(chunk.read(), b'ten') - self.assertEqual(chunk.read(), b'') - self.assertEqual(len(chunk), 3) - - def test_tell_and_seek(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'onetwothreefourfivesixseveneightnineten') - chunk = ReadFileChunk.from_filename( - filename, start_byte=36, chunk_size=100000 - ) - self.assertEqual(chunk.tell(), 0) - self.assertEqual(chunk.read(), b'ten') - self.assertEqual(chunk.tell(), 3) - chunk.seek(0) - self.assertEqual(chunk.tell(), 0) - chunk.seek(1, whence=1) - self.assertEqual(chunk.tell(), 1) - chunk.seek(-1, whence=1) - self.assertEqual(chunk.tell(), 0) - chunk.seek(-1, whence=2) - self.assertEqual(chunk.tell(), 2) - - def test_tell_and_seek_boundaries(self): - # Test to ensure ReadFileChunk behaves the same as the - # Python standard library around seeking and reading out - # of bounds in a file object. 
- data = b'abcdefghij12345678klmnopqrst' - start_pos = 10 - chunk_size = 8 - - # Create test file - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(data) - - # ReadFileChunk should be a substring of only numbers - file_objects = [ - ReadFileChunk.from_filename( - filename, start_byte=start_pos, chunk_size=chunk_size - ) - ] - - # Uncomment next line to validate we match Python's io.BytesIO - # file_objects.append(io.BytesIO(data[start_pos:start_pos+chunk_size])) - - for obj in file_objects: - self._assert_whence_start_behavior(obj) - self._assert_whence_end_behavior(obj) - self._assert_whence_relative_behavior(obj) - self._assert_boundary_behavior(obj) - - def _assert_whence_start_behavior(self, file_obj): - self.assertEqual(file_obj.tell(), 0) - - file_obj.seek(1, 0) - self.assertEqual(file_obj.tell(), 1) - - file_obj.seek(1) - self.assertEqual(file_obj.tell(), 1) - self.assertEqual(file_obj.read(), b'2345678') - - file_obj.seek(3, 0) - self.assertEqual(file_obj.tell(), 3) - - file_obj.seek(0, 0) - self.assertEqual(file_obj.tell(), 0) - - def _assert_whence_relative_behavior(self, file_obj): - self.assertEqual(file_obj.tell(), 0) - - file_obj.seek(2, 1) - self.assertEqual(file_obj.tell(), 2) - - file_obj.seek(1, 1) - self.assertEqual(file_obj.tell(), 3) - self.assertEqual(file_obj.read(), b'45678') - - file_obj.seek(20, 1) - self.assertEqual(file_obj.tell(), 28) - - file_obj.seek(-30, 1) - self.assertEqual(file_obj.tell(), 0) - self.assertEqual(file_obj.read(), b'12345678') - - file_obj.seek(-8, 1) - self.assertEqual(file_obj.tell(), 0) - - def _assert_whence_end_behavior(self, file_obj): - self.assertEqual(file_obj.tell(), 0) - - file_obj.seek(-1, 2) - self.assertEqual(file_obj.tell(), 7) - - file_obj.seek(1, 2) - self.assertEqual(file_obj.tell(), 9) - - file_obj.seek(3, 2) - self.assertEqual(file_obj.tell(), 11) - self.assertEqual(file_obj.read(), b'') - - file_obj.seek(-15, 2) - self.assertEqual(file_obj.tell(), 0) - self.assertEqual(file_obj.read(), b'12345678') - - file_obj.seek(-8, 2) - self.assertEqual(file_obj.tell(), 0) - - def _assert_boundary_behavior(self, file_obj): - # Verify we're at the start - self.assertEqual(file_obj.tell(), 0) - - # Verify we can't move backwards beyond start of file - file_obj.seek(-10, 1) - self.assertEqual(file_obj.tell(), 0) - - # Verify we *can* move after end of file, but return nothing - file_obj.seek(10, 2) - self.assertEqual(file_obj.tell(), 18) - self.assertEqual(file_obj.read(), b'') - self.assertEqual(file_obj.read(10), b'') - - # Verify we can partially rewind - file_obj.seek(-12, 1) - self.assertEqual(file_obj.tell(), 6) - self.assertEqual(file_obj.read(), b'78') - self.assertEqual(file_obj.tell(), 8) - - # Verify we can rewind to start - file_obj.seek(0) - self.assertEqual(file_obj.tell(), 0) - - def test_file_chunk_supports_context_manager(self): - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(b'abc') - with ReadFileChunk.from_filename( - filename, start_byte=0, chunk_size=2 - ) as chunk: - val = chunk.read() - self.assertEqual(val, b'ab') - - def test_iter_is_always_empty(self): - # This tests the workaround for the httplib bug (see - # the source for more info). 
- filename = os.path.join(self.tempdir, 'foo') - open(filename, 'wb').close() - chunk = ReadFileChunk.from_filename( - filename, start_byte=0, chunk_size=10 - ) - self.assertEqual(list(chunk), []) - - def test_callback_is_invoked_on_read(self): - chunk = ReadFileChunk.from_filename( - self.filename, - start_byte=0, - chunk_size=3, - callbacks=[self.callback], - ) - chunk.read(1) - chunk.read(1) - chunk.read(1) - self.assertEqual(self.amounts_seen, [1, 1, 1]) - - def test_all_callbacks_invoked_on_read(self): - chunk = ReadFileChunk.from_filename( - self.filename, - start_byte=0, - chunk_size=3, - callbacks=[self.callback, self.callback], - ) - chunk.read(1) - chunk.read(1) - chunk.read(1) - # The list should be twice as long because there are two callbacks - # recording the amount read. - self.assertEqual(self.amounts_seen, [1, 1, 1, 1, 1, 1]) - - def test_callback_can_be_disabled(self): - chunk = ReadFileChunk.from_filename( - self.filename, - start_byte=0, - chunk_size=3, - callbacks=[self.callback], - ) - chunk.disable_callback() - # Now reading from the ReadFileChunk should not invoke - # the callback. - chunk.read() - self.assertEqual(self.amounts_seen, []) - - def test_callback_will_also_be_triggered_by_seek(self): - chunk = ReadFileChunk.from_filename( - self.filename, - start_byte=0, - chunk_size=3, - callbacks=[self.callback], - ) - chunk.read(2) - chunk.seek(0) - chunk.read(2) - chunk.seek(1) - chunk.read(2) - self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2]) - - def test_callback_triggered_by_out_of_bound_seeks(self): - data = b'abcdefghij1234567890klmnopqr' - - # Create test file - filename = os.path.join(self.tempdir, 'foo') - with open(filename, 'wb') as f: - f.write(data) - chunk = ReadFileChunk.from_filename( - filename, start_byte=10, chunk_size=10, callbacks=[self.callback] - ) - - # Seek calls that generate "0" progress are skipped by - # invoke_progress_callbacks and won't appear in the list. 
- expected_callback_prog = [10, -5, 5, -1, 1, -1, 1, -5, 5, -10] - - self._assert_out_of_bound_start_seek(chunk, expected_callback_prog) - self._assert_out_of_bound_relative_seek(chunk, expected_callback_prog) - self._assert_out_of_bound_end_seek(chunk, expected_callback_prog) - - def _assert_out_of_bound_start_seek(self, chunk, expected): - # clear amounts_seen - self.amounts_seen = [] - self.assertEqual(self.amounts_seen, []) - - # (position, change) - chunk.seek(20) # (20, 10) - chunk.seek(5) # (5, -5) - chunk.seek(20) # (20, 5) - chunk.seek(9) # (9, -1) - chunk.seek(20) # (20, 1) - chunk.seek(11) # (11, 0) - chunk.seek(20) # (20, 0) - chunk.seek(9) # (9, -1) - chunk.seek(20) # (20, 1) - chunk.seek(5) # (5, -5) - chunk.seek(20) # (20, 5) - chunk.seek(0) # (0, -10) - chunk.seek(0) # (0, 0) - - self.assertEqual(self.amounts_seen, expected) - - def _assert_out_of_bound_relative_seek(self, chunk, expected): - # clear amounts_seen - self.amounts_seen = [] - self.assertEqual(self.amounts_seen, []) - - # (position, change) - chunk.seek(20, 1) # (20, 10) - chunk.seek(-15, 1) # (5, -5) - chunk.seek(15, 1) # (20, 5) - chunk.seek(-11, 1) # (9, -1) - chunk.seek(11, 1) # (20, 1) - chunk.seek(-9, 1) # (11, 0) - chunk.seek(9, 1) # (20, 0) - chunk.seek(-11, 1) # (9, -1) - chunk.seek(11, 1) # (20, 1) - chunk.seek(-15, 1) # (5, -5) - chunk.seek(15, 1) # (20, 5) - chunk.seek(-20, 1) # (0, -10) - chunk.seek(-1000, 1) # (0, 0) - - self.assertEqual(self.amounts_seen, expected) - - def _assert_out_of_bound_end_seek(self, chunk, expected): - # clear amounts_seen - self.amounts_seen = [] - self.assertEqual(self.amounts_seen, []) - - # (position, change) - chunk.seek(10, 2) # (20, 10) - chunk.seek(-5, 2) # (5, -5) - chunk.seek(10, 2) # (20, 5) - chunk.seek(-1, 2) # (9, -1) - chunk.seek(10, 2) # (20, 1) - chunk.seek(1, 2) # (11, 0) - chunk.seek(10, 2) # (20, 0) - chunk.seek(-1, 2) # (9, -1) - chunk.seek(10, 2) # (20, 1) - chunk.seek(-5, 2) # (5, -5) - chunk.seek(10, 2) # (20, 5) - chunk.seek(-10, 2) # (0, -10) - chunk.seek(-1000, 2) # (0, 0) - - self.assertEqual(self.amounts_seen, expected) - - def test_close_callbacks(self): - with open(self.filename) as f: - chunk = ReadFileChunk( - f, - chunk_size=1, - full_file_size=3, - close_callbacks=[self.close_callback], - ) - chunk.close() - self.assertEqual(self.num_close_callback_calls, 1) - - def test_close_callbacks_when_not_enabled(self): - with open(self.filename) as f: - chunk = ReadFileChunk( - f, - chunk_size=1, - full_file_size=3, - enable_callbacks=False, - close_callbacks=[self.close_callback], - ) - chunk.close() - self.assertEqual(self.num_close_callback_calls, 0) - - def test_close_callbacks_when_context_handler_is_used(self): - with open(self.filename) as f: - with ReadFileChunk( - f, - chunk_size=1, - full_file_size=3, - close_callbacks=[self.close_callback], - ) as chunk: - chunk.read(1) - self.assertEqual(self.num_close_callback_calls, 1) - - def test_signal_transferring(self): - chunk = ReadFileChunk.from_filename( - self.filename, - start_byte=0, - chunk_size=3, - callbacks=[self.callback], - ) - chunk.signal_not_transferring() - chunk.read(1) - self.assertEqual(self.amounts_seen, []) - chunk.signal_transferring() - chunk.read(1) - self.assertEqual(self.amounts_seen, [1]) - - def test_signal_transferring_to_underlying_fileobj(self): - underlying_stream = mock.Mock() - underlying_stream.tell.return_value = 0 - chunk = ReadFileChunk(underlying_stream, 3, 3) - chunk.signal_transferring() - self.assertTrue(underlying_stream.signal_transferring.called) - 
- def test_no_call_signal_transferring_to_underlying_fileobj(self): - underlying_stream = mock.Mock(io.RawIOBase) - underlying_stream.tell.return_value = 0 - chunk = ReadFileChunk(underlying_stream, 3, 3) - try: - chunk.signal_transferring() - except AttributeError: - self.fail( - 'The stream should not have tried to call signal_transferring ' - 'to the underlying stream.' - ) - - def test_signal_not_transferring_to_underlying_fileobj(self): - underlying_stream = mock.Mock() - underlying_stream.tell.return_value = 0 - chunk = ReadFileChunk(underlying_stream, 3, 3) - chunk.signal_not_transferring() - self.assertTrue(underlying_stream.signal_not_transferring.called) - - def test_no_call_signal_not_transferring_to_underlying_fileobj(self): - underlying_stream = mock.Mock(io.RawIOBase) - underlying_stream.tell.return_value = 0 - chunk = ReadFileChunk(underlying_stream, 3, 3) - try: - chunk.signal_not_transferring() - except AttributeError: - self.fail( - 'The stream should not have tried to call ' - 'signal_not_transferring to the underlying stream.' - ) - - -class TestStreamReaderProgress(BaseUtilsTest): - def test_proxies_to_wrapped_stream(self): - original_stream = StringIO('foobarbaz') - wrapped = StreamReaderProgress(original_stream) - self.assertEqual(wrapped.read(), 'foobarbaz') - - def test_callback_invoked(self): - original_stream = StringIO('foobarbaz') - wrapped = StreamReaderProgress( - original_stream, [self.callback, self.callback] - ) - self.assertEqual(wrapped.read(), 'foobarbaz') - self.assertEqual(self.amounts_seen, [9, 9]) - - -class TestTaskSemaphore(unittest.TestCase): - def setUp(self): - self.semaphore = TaskSemaphore(1) - - def test_should_block_at_max_capacity(self): - self.semaphore.acquire('a', blocking=False) - with self.assertRaises(NoResourcesAvailable): - self.semaphore.acquire('a', blocking=False) - - def test_release_capacity(self): - acquire_token = self.semaphore.acquire('a', blocking=False) - self.semaphore.release('a', acquire_token) - try: - self.semaphore.acquire('a', blocking=False) - except NoResourcesAvailable: - self.fail( - 'The release of the semaphore should have allowed for ' - 'the second acquire to not be blocked' - ) - - -class TestSlidingWindowSemaphore(unittest.TestCase): - # These tests use block=False to tests will fail - # instead of hang the test runner in the case of x - # incorrect behavior. - def test_acquire_release_basic_case(self): - sem = SlidingWindowSemaphore(1) - # Count is 1 - - num = sem.acquire('a', blocking=False) - self.assertEqual(num, 0) - sem.release('a', 0) - # Count now back to 1. - - def test_can_acquire_release_multiple_times(self): - sem = SlidingWindowSemaphore(1) - num = sem.acquire('a', blocking=False) - self.assertEqual(num, 0) - sem.release('a', num) - - num = sem.acquire('a', blocking=False) - self.assertEqual(num, 1) - sem.release('a', num) - - def test_can_acquire_a_range(self): - sem = SlidingWindowSemaphore(3) - self.assertEqual(sem.acquire('a', blocking=False), 0) - self.assertEqual(sem.acquire('a', blocking=False), 1) - self.assertEqual(sem.acquire('a', blocking=False), 2) - sem.release('a', 0) - sem.release('a', 1) - sem.release('a', 2) - # Now we're reset so we should be able to acquire the same - # sequence again. 
- self.assertEqual(sem.acquire('a', blocking=False), 3) - self.assertEqual(sem.acquire('a', blocking=False), 4) - self.assertEqual(sem.acquire('a', blocking=False), 5) - self.assertEqual(sem.current_count(), 0) - - def test_counter_release_only_on_min_element(self): - sem = SlidingWindowSemaphore(3) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - - # The count only increases when we free the min - # element. This means if we're currently failing to - # acquire now: - with self.assertRaises(NoResourcesAvailable): - sem.acquire('a', blocking=False) - - # Then freeing a non-min element: - sem.release('a', 1) - - # doesn't change anything. We still fail to acquire. - with self.assertRaises(NoResourcesAvailable): - sem.acquire('a', blocking=False) - self.assertEqual(sem.current_count(), 0) - - def test_raises_error_when_count_is_zero(self): - sem = SlidingWindowSemaphore(3) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - - # Count is now 0 so trying to acquire should fail. - with self.assertRaises(NoResourcesAvailable): - sem.acquire('a', blocking=False) - - def test_release_counters_can_increment_counter_repeatedly(self): - sem = SlidingWindowSemaphore(3) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - - # These two releases don't increment the counter - # because we're waiting on 0. - sem.release('a', 1) - sem.release('a', 2) - self.assertEqual(sem.current_count(), 0) - # But as soon as we release 0, we free up 0, 1, and 2. - sem.release('a', 0) - self.assertEqual(sem.current_count(), 3) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - - def test_error_to_release_unknown_tag(self): - sem = SlidingWindowSemaphore(3) - with self.assertRaises(ValueError): - sem.release('a', 0) - - def test_can_track_multiple_tags(self): - sem = SlidingWindowSemaphore(3) - self.assertEqual(sem.acquire('a', blocking=False), 0) - self.assertEqual(sem.acquire('b', blocking=False), 0) - self.assertEqual(sem.acquire('a', blocking=False), 1) - - # We're at our max of 3 even though 2 are for A and 1 is for B. - with self.assertRaises(NoResourcesAvailable): - sem.acquire('a', blocking=False) - with self.assertRaises(NoResourcesAvailable): - sem.acquire('b', blocking=False) - - def test_can_handle_multiple_tags_released(self): - sem = SlidingWindowSemaphore(4) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - sem.acquire('b', blocking=False) - sem.acquire('b', blocking=False) - - sem.release('b', 1) - sem.release('a', 1) - self.assertEqual(sem.current_count(), 0) - - sem.release('b', 0) - self.assertEqual(sem.acquire('a', blocking=False), 2) - - sem.release('a', 0) - self.assertEqual(sem.acquire('b', blocking=False), 2) - - def test_is_error_to_release_unknown_sequence_number(self): - sem = SlidingWindowSemaphore(3) - sem.acquire('a', blocking=False) - with self.assertRaises(ValueError): - sem.release('a', 1) - - def test_is_error_to_double_release(self): - # This is different than other error tests because - # we're verifying we can reset the state after an - # acquire/release cycle. 
- sem = SlidingWindowSemaphore(2) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - sem.release('a', 0) - sem.release('a', 1) - self.assertEqual(sem.current_count(), 2) - with self.assertRaises(ValueError): - sem.release('a', 0) - - def test_can_check_in_partial_range(self): - sem = SlidingWindowSemaphore(4) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - - sem.release('a', 1) - sem.release('a', 3) - sem.release('a', 0) - self.assertEqual(sem.current_count(), 2) - - -class TestThreadingPropertiesForSlidingWindowSemaphore(unittest.TestCase): - # These tests focus on mutithreaded properties of the range - # semaphore. Basic functionality is tested in TestSlidingWindowSemaphore. - def setUp(self): - self.threads = [] - - def tearDown(self): - self.join_threads() - - def join_threads(self): - for thread in self.threads: - thread.join() - self.threads = [] - - def start_threads(self): - for thread in self.threads: - thread.start() - - def test_acquire_blocks_until_release_is_called(self): - sem = SlidingWindowSemaphore(2) - sem.acquire('a', blocking=False) - sem.acquire('a', blocking=False) - - def acquire(): - # This next call to acquire will block. - self.assertEqual(sem.acquire('a', blocking=True), 2) - - t = threading.Thread(target=acquire) - self.threads.append(t) - # Starting the thread will block the sem.acquire() - # in the acquire function above. - t.start() - # This still will keep the thread blocked. - sem.release('a', 1) - # Releasing the min element will unblock the thread. - sem.release('a', 0) - t.join() - sem.release('a', 2) - - def test_stress_invariants_random_order(self): - sem = SlidingWindowSemaphore(100) - for _ in range(10): - recorded = [] - for _ in range(100): - recorded.append(sem.acquire('a', blocking=False)) - # Release them in randomized order. As long as we - # eventually free all 100, we should have all the - # resources released. - random.shuffle(recorded) - for i in recorded: - sem.release('a', i) - - # Everything's freed so should be back at count == 100 - self.assertEqual(sem.current_count(), 100) - - def test_blocking_stress(self): - sem = SlidingWindowSemaphore(5) - num_threads = 10 - num_iterations = 50 - - def acquire(): - for _ in range(num_iterations): - num = sem.acquire('a', blocking=True) - time.sleep(0.001) - sem.release('a', num) - - for i in range(num_threads): - t = threading.Thread(target=acquire) - self.threads.append(t) - self.start_threads() - self.join_threads() - # Should have all the available resources freed. 
- self.assertEqual(sem.current_count(), 5) - # Should have acquired num_threads * num_iterations - self.assertEqual( - sem.acquire('a', blocking=False), num_threads * num_iterations - ) - - -class TestAdjustChunksize(unittest.TestCase): - def setUp(self): - self.adjuster = ChunksizeAdjuster() - - def test_valid_chunksize(self): - chunksize = 7 * (1024 ** 2) - file_size = 8 * (1024 ** 2) - new_size = self.adjuster.adjust_chunksize(chunksize, file_size) - self.assertEqual(new_size, chunksize) - - def test_chunksize_below_minimum(self): - chunksize = MIN_UPLOAD_CHUNKSIZE - 1 - file_size = 3 * MIN_UPLOAD_CHUNKSIZE - new_size = self.adjuster.adjust_chunksize(chunksize, file_size) - self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE) - - def test_chunksize_above_maximum(self): - chunksize = MAX_SINGLE_UPLOAD_SIZE + 1 - file_size = MAX_SINGLE_UPLOAD_SIZE * 2 - new_size = self.adjuster.adjust_chunksize(chunksize, file_size) - self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE) - - def test_chunksize_too_small(self): - chunksize = 7 * (1024 ** 2) - file_size = 5 * (1024 ** 4) - # If we try to upload a 5TB file, we'll need to use 896MB part - # sizes. - new_size = self.adjuster.adjust_chunksize(chunksize, file_size) - self.assertEqual(new_size, 896 * (1024 ** 2)) - num_parts = file_size / new_size - self.assertLessEqual(num_parts, MAX_PARTS) - - def test_unknown_file_size_with_valid_chunksize(self): - chunksize = 7 * (1024 ** 2) - new_size = self.adjuster.adjust_chunksize(chunksize) - self.assertEqual(new_size, chunksize) - - def test_unknown_file_size_below_minimum(self): - chunksize = MIN_UPLOAD_CHUNKSIZE - 1 - new_size = self.adjuster.adjust_chunksize(chunksize) - self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE) - - def test_unknown_file_size_above_maximum(self): - chunksize = MAX_SINGLE_UPLOAD_SIZE + 1 - new_size = self.adjuster.adjust_chunksize(chunksize) - self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE) +# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import io +import os.path +import random +import re +import shutil +import tempfile +import threading +import time +from io import BytesIO, StringIO + +from s3transfer.futures import TransferFuture, TransferMeta +from s3transfer.utils import ( + MAX_PARTS, + MAX_SINGLE_UPLOAD_SIZE, + MIN_UPLOAD_CHUNKSIZE, + CallArgs, + ChunksizeAdjuster, + CountCallbackInvoker, + DeferredOpenFile, + FunctionContainer, + NoResourcesAvailable, + OSUtils, + ReadFileChunk, + SlidingWindowSemaphore, + StreamReaderProgress, + TaskSemaphore, + calculate_num_parts, + calculate_range_parameter, + get_callbacks, + get_filtered_dict, + invoke_progress_callbacks, + random_file_extension, +) +from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest + + +class TestGetCallbacks(unittest.TestCase): + def setUp(self): + self.subscriber = RecordingSubscriber() + self.second_subscriber = RecordingSubscriber() + self.call_args = CallArgs( + subscribers=[self.subscriber, self.second_subscriber] + ) + self.transfer_meta = TransferMeta(self.call_args) + self.transfer_future = TransferFuture(self.transfer_meta) + + def test_get_callbacks(self): + callbacks = get_callbacks(self.transfer_future, 'queued') + # Make sure two callbacks were added as both subscribers had + # an on_queued method. + self.assertEqual(len(callbacks), 2) + + # Ensure that the callback was injected with the future by calling + # one of them and checking that the future was used in the call. + callbacks[0]() + self.assertEqual( + self.subscriber.on_queued_calls, [{'future': self.transfer_future}] + ) + + def test_get_callbacks_for_missing_type(self): + callbacks = get_callbacks(self.transfer_future, 'fake_state') + # There should be no callbacks as the subscribers will not have the + # on_fake_state method + self.assertEqual(len(callbacks), 0) + + +class TestGetFilteredDict(unittest.TestCase): + def test_get_filtered_dict(self): + original = {'Include': 'IncludeValue', 'NotInlude': 'NotIncludeValue'} + whitelist = ['Include'] + self.assertEqual( + get_filtered_dict(original, whitelist), {'Include': 'IncludeValue'} + ) + + +class TestCallArgs(unittest.TestCase): + def test_call_args(self): + call_args = CallArgs(foo='bar', biz='baz') + self.assertEqual(call_args.foo, 'bar') + self.assertEqual(call_args.biz, 'baz') + + +class TestFunctionContainer(unittest.TestCase): + def get_args_kwargs(self, *args, **kwargs): + return args, kwargs + + def test_call(self): + func_container = FunctionContainer( + self.get_args_kwargs, 'foo', bar='baz' + ) + self.assertEqual(func_container(), (('foo',), {'bar': 'baz'})) + + def test_repr(self): + func_container = FunctionContainer( + self.get_args_kwargs, 'foo', bar='baz' + ) + self.assertEqual( + str(func_container), + 'Function: {} with args {} and kwargs {}'.format( + self.get_args_kwargs, ('foo',), {'bar': 'baz'} + ), + ) + + +class TestCountCallbackInvoker(unittest.TestCase): + def invoke_callback(self): + self.ref_results.append('callback invoked') + + def assert_callback_invoked(self): + self.assertEqual(self.ref_results, ['callback invoked']) + + def assert_callback_not_invoked(self): + self.assertEqual(self.ref_results, []) + + def setUp(self): + self.ref_results = [] + self.invoker = CountCallbackInvoker(self.invoke_callback) + + def test_increment(self): + self.invoker.increment() + self.assertEqual(self.invoker.current_count, 1) + + def test_decrement(self): + self.invoker.increment() + self.invoker.increment() + self.invoker.decrement() + self.assertEqual(self.invoker.current_count, 1) + + def 
test_count_cannot_go_below_zero(self): + with self.assertRaises(RuntimeError): + self.invoker.decrement() + + def test_callback_invoked_only_once_finalized(self): + self.invoker.increment() + self.invoker.decrement() + self.assert_callback_not_invoked() + self.invoker.finalize() + # Callback should only be invoked once finalized + self.assert_callback_invoked() + + def test_callback_invoked_after_finalizing_and_count_reaching_zero(self): + self.invoker.increment() + self.invoker.finalize() + # Make sure that it does not get invoked immediately after + # finalizing as the count is currently one + self.assert_callback_not_invoked() + self.invoker.decrement() + self.assert_callback_invoked() + + def test_cannot_increment_after_finalization(self): + self.invoker.finalize() + with self.assertRaises(RuntimeError): + self.invoker.increment() + + +class TestRandomFileExtension(unittest.TestCase): + def test_has_proper_length(self): + self.assertEqual(len(random_file_extension(num_digits=4)), 4) + + +class TestInvokeProgressCallbacks(unittest.TestCase): + def test_invoke_progress_callbacks(self): + recording_subscriber = RecordingSubscriber() + invoke_progress_callbacks([recording_subscriber.on_progress], 2) + self.assertEqual(recording_subscriber.calculate_bytes_seen(), 2) + + def test_invoke_progress_callbacks_with_no_progress(self): + recording_subscriber = RecordingSubscriber() + invoke_progress_callbacks([recording_subscriber.on_progress], 0) + self.assertEqual(len(recording_subscriber.on_progress_calls), 0) + + +class TestCalculateNumParts(unittest.TestCase): + def test_calculate_num_parts_divisible(self): + self.assertEqual(calculate_num_parts(size=4, part_size=2), 2) + + def test_calculate_num_parts_not_divisible(self): + self.assertEqual(calculate_num_parts(size=3, part_size=2), 2) + + +class TestCalculateRangeParameter(unittest.TestCase): + def setUp(self): + self.part_size = 5 + self.part_index = 1 + self.num_parts = 3 + + def test_calculate_range_paramter(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, self.num_parts + ) + self.assertEqual(range_val, 'bytes=5-9') + + def test_last_part_with_no_total_size(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, num_parts=2 + ) + self.assertEqual(range_val, 'bytes=5-') + + def test_last_part_with_total_size(self): + range_val = calculate_range_parameter( + self.part_size, self.part_index, num_parts=2, total_size=8 + ) + self.assertEqual(range_val, 'bytes=5-7') + + +class BaseUtilsTest(unittest.TestCase): + def setUp(self): + self.tempdir = tempfile.mkdtemp() + self.filename = os.path.join(self.tempdir, 'foo') + self.content = b'abc' + with open(self.filename, 'wb') as f: + f.write(self.content) + self.amounts_seen = [] + self.num_close_callback_calls = 0 + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def callback(self, bytes_transferred): + self.amounts_seen.append(bytes_transferred) + + def close_callback(self): + self.num_close_callback_calls += 1 + + +class TestOSUtils(BaseUtilsTest): + def test_get_file_size(self): + self.assertEqual( + OSUtils().get_file_size(self.filename), len(self.content) + ) + + def test_open_file_chunk_reader(self): + reader = OSUtils().open_file_chunk_reader( + self.filename, 0, 3, [self.callback] + ) + + # The returned reader should be a ReadFileChunk. + self.assertIsInstance(reader, ReadFileChunk) + # The content of the reader should be correct. 
+        self.assertEqual(reader.read(), self.content)
+        # Callbacks should be disabled despite being passed in.
+        self.assertEqual(self.amounts_seen, [])
+
+    def test_open_file_chunk_reader_from_fileobj(self):
+        with open(self.filename, 'rb') as f:
+            reader = OSUtils().open_file_chunk_reader_from_fileobj(
+                f, len(self.content), len(self.content), [self.callback]
+            )
+
+            # The returned reader should be a ReadFileChunk.
+            self.assertIsInstance(reader, ReadFileChunk)
+            # The content of the reader should be correct.
+            self.assertEqual(reader.read(), self.content)
+            reader.close()
+            # Callbacks should be disabled despite being passed in.
+            self.assertEqual(self.amounts_seen, [])
+            self.assertEqual(self.num_close_callback_calls, 0)
+
+    def test_open_file(self):
+        fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
+        self.assertTrue(hasattr(fileobj, 'write'))
+
+    def test_remove_file_ignores_errors(self):
+        non_existent_file = os.path.join(self.tempdir, 'no-exist')
+        # This should not exist to start.
+        self.assertFalse(os.path.exists(non_existent_file))
+        try:
+            OSUtils().remove_file(non_existent_file)
+        except OSError as e:
+            self.fail('OSError should have been caught: %s' % e)
+
+    def test_remove_file_proxies_remove_file(self):
+        OSUtils().remove_file(self.filename)
+        self.assertFalse(os.path.exists(self.filename))
+
+    def test_rename_file(self):
+        new_filename = os.path.join(self.tempdir, 'newfoo')
+        OSUtils().rename_file(self.filename, new_filename)
+        self.assertFalse(os.path.exists(self.filename))
+        self.assertTrue(os.path.exists(new_filename))
+
+    def test_is_special_file_for_normal_file(self):
+        self.assertFalse(OSUtils().is_special_file(self.filename))
+
+    def test_is_special_file_for_non_existant_file(self):
+        non_existant_filename = os.path.join(self.tempdir, 'no-exist')
+        self.assertFalse(os.path.exists(non_existant_filename))
+        self.assertFalse(OSUtils().is_special_file(non_existant_filename))
+
+    def test_get_temp_filename(self):
+        filename = 'myfile'
+        self.assertIsNotNone(
+            re.match(
+                r'%s\.[0-9A-Fa-f]{8}$' % filename,
+                OSUtils().get_temp_filename(filename),
+            )
+        )
+
+    def test_get_temp_filename_len_255(self):
+        filename = 'a' * 255
+        temp_filename = OSUtils().get_temp_filename(filename)
+        self.assertLessEqual(len(temp_filename), 255)
+
+    def test_get_temp_filename_len_gt_255(self):
+        filename = 'a' * 280
+        temp_filename = OSUtils().get_temp_filename(filename)
+        self.assertLessEqual(len(temp_filename), 255)
+
+    def test_allocate(self):
+        truncate_size = 1
+        OSUtils().allocate(self.filename, truncate_size)
+        with open(self.filename, 'rb') as f:
+            self.assertEqual(len(f.read()), truncate_size)
+
+    @mock.patch('s3transfer.utils.fallocate')
+    def test_allocate_with_io_error(self, mock_fallocate):
+        mock_fallocate.side_effect = IOError()
+        with self.assertRaises(IOError):
+            OSUtils().allocate(self.filename, 1)
+        self.assertFalse(os.path.exists(self.filename))
+
+    @mock.patch('s3transfer.utils.fallocate')
+    def test_allocate_with_os_error(self, mock_fallocate):
+        mock_fallocate.side_effect = OSError()
+        with self.assertRaises(OSError):
+            OSUtils().allocate(self.filename, 1)
+        self.assertFalse(os.path.exists(self.filename))
+
+
+class TestDeferredOpenFile(BaseUtilsTest):
+    def setUp(self):
+        super().setUp()
+        self.filename = os.path.join(self.tempdir, 'foo')
+        self.contents = b'my contents'
+        with open(self.filename, 'wb') as f:
+            f.write(self.contents)
+        self.deferred_open_file = DeferredOpenFile(
+            self.filename, open_function=self.recording_open_function
+        )
+        
self.open_call_args = [] + + def tearDown(self): + self.deferred_open_file.close() + super().tearDown() + + def recording_open_function(self, filename, mode): + self.open_call_args.append((filename, mode)) + return open(filename, mode) + + def open_nonseekable(self, filename, mode): + self.open_call_args.append((filename, mode)) + return NonSeekableWriter(BytesIO(self.content)) + + def test_instantiation_does_not_open_file(self): + DeferredOpenFile( + self.filename, open_function=self.recording_open_function + ) + self.assertEqual(len(self.open_call_args), 0) + + def test_name(self): + self.assertEqual(self.deferred_open_file.name, self.filename) + + def test_read(self): + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[0:2]) + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[2:4]) + self.assertEqual(len(self.open_call_args), 1) + + def test_write(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + open_function=self.recording_open_function, + ) + + write_content = b'foo' + self.deferred_open_file.write(write_content) + self.deferred_open_file.write(write_content) + self.deferred_open_file.close() + # Both of the writes should now be in the file. + with open(self.filename, 'rb') as f: + self.assertEqual(f.read(), write_content * 2) + # Open should have only been called once. + self.assertEqual(len(self.open_call_args), 1) + + def test_seek(self): + self.deferred_open_file.seek(2) + content = self.deferred_open_file.read(2) + self.assertEqual(content, self.contents[2:4]) + self.assertEqual(len(self.open_call_args), 1) + + def test_open_does_not_seek_with_zero_start_byte(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + start_byte=0, + open_function=self.open_nonseekable, + ) + + try: + # If this seeks, an UnsupportedOperation error will be raised. + self.deferred_open_file.write(b'data') + except io.UnsupportedOperation: + self.fail('DeferredOpenFile seeked upon opening') + + def test_open_seeks_with_nonzero_start_byte(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='wb', + start_byte=5, + open_function=self.open_nonseekable, + ) + + # Since a non-seekable file is being opened, calling Seek will raise + # an UnsupportedOperation error. + with self.assertRaises(io.UnsupportedOperation): + self.deferred_open_file.write(b'data') + + def test_tell(self): + self.deferred_open_file.tell() + # tell() should not have opened the file if it has not been seeked + # or read because we know the start bytes upfront. 
+ self.assertEqual(len(self.open_call_args), 0) + + self.deferred_open_file.seek(2) + self.assertEqual(self.deferred_open_file.tell(), 2) + self.assertEqual(len(self.open_call_args), 1) + + def test_open_args(self): + self.deferred_open_file = DeferredOpenFile( + self.filename, + mode='ab+', + open_function=self.recording_open_function, + ) + # Force an open + self.deferred_open_file.write(b'data') + self.assertEqual(len(self.open_call_args), 1) + self.assertEqual(self.open_call_args[0], (self.filename, 'ab+')) + + def test_context_handler(self): + with self.deferred_open_file: + self.assertEqual(len(self.open_call_args), 1) + + +class TestReadFileChunk(BaseUtilsTest): + def test_read_entire_chunk(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=3 + ) + self.assertEqual(chunk.read(), b'one') + self.assertEqual(chunk.read(), b'') + + def test_read_with_amount_size(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(1), b'f') + self.assertEqual(chunk.read(1), b'o') + self.assertEqual(chunk.read(1), b'u') + self.assertEqual(chunk.read(1), b'r') + self.assertEqual(chunk.read(1), b'') + + def test_reset_stream_emulation(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=11, chunk_size=4 + ) + self.assertEqual(chunk.read(), b'four') + chunk.seek(0) + self.assertEqual(chunk.read(), b'four') + + def test_read_past_end_of_file(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.read(), b'') + self.assertEqual(len(chunk), 3) + + def test_tell_and_seek(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'onetwothreefourfivesixseveneightnineten') + chunk = ReadFileChunk.from_filename( + filename, start_byte=36, chunk_size=100000 + ) + self.assertEqual(chunk.tell(), 0) + self.assertEqual(chunk.read(), b'ten') + self.assertEqual(chunk.tell(), 3) + chunk.seek(0) + self.assertEqual(chunk.tell(), 0) + chunk.seek(1, whence=1) + self.assertEqual(chunk.tell(), 1) + chunk.seek(-1, whence=1) + self.assertEqual(chunk.tell(), 0) + chunk.seek(-1, whence=2) + self.assertEqual(chunk.tell(), 2) + + def test_tell_and_seek_boundaries(self): + # Test to ensure ReadFileChunk behaves the same as the + # Python standard library around seeking and reading out + # of bounds in a file object. 
+ data = b'abcdefghij12345678klmnopqrst' + start_pos = 10 + chunk_size = 8 + + # Create test file + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(data) + + # ReadFileChunk should be a substring of only numbers + file_objects = [ + ReadFileChunk.from_filename( + filename, start_byte=start_pos, chunk_size=chunk_size + ) + ] + + # Uncomment next line to validate we match Python's io.BytesIO + # file_objects.append(io.BytesIO(data[start_pos:start_pos+chunk_size])) + + for obj in file_objects: + self._assert_whence_start_behavior(obj) + self._assert_whence_end_behavior(obj) + self._assert_whence_relative_behavior(obj) + self._assert_boundary_behavior(obj) + + def _assert_whence_start_behavior(self, file_obj): + self.assertEqual(file_obj.tell(), 0) + + file_obj.seek(1, 0) + self.assertEqual(file_obj.tell(), 1) + + file_obj.seek(1) + self.assertEqual(file_obj.tell(), 1) + self.assertEqual(file_obj.read(), b'2345678') + + file_obj.seek(3, 0) + self.assertEqual(file_obj.tell(), 3) + + file_obj.seek(0, 0) + self.assertEqual(file_obj.tell(), 0) + + def _assert_whence_relative_behavior(self, file_obj): + self.assertEqual(file_obj.tell(), 0) + + file_obj.seek(2, 1) + self.assertEqual(file_obj.tell(), 2) + + file_obj.seek(1, 1) + self.assertEqual(file_obj.tell(), 3) + self.assertEqual(file_obj.read(), b'45678') + + file_obj.seek(20, 1) + self.assertEqual(file_obj.tell(), 28) + + file_obj.seek(-30, 1) + self.assertEqual(file_obj.tell(), 0) + self.assertEqual(file_obj.read(), b'12345678') + + file_obj.seek(-8, 1) + self.assertEqual(file_obj.tell(), 0) + + def _assert_whence_end_behavior(self, file_obj): + self.assertEqual(file_obj.tell(), 0) + + file_obj.seek(-1, 2) + self.assertEqual(file_obj.tell(), 7) + + file_obj.seek(1, 2) + self.assertEqual(file_obj.tell(), 9) + + file_obj.seek(3, 2) + self.assertEqual(file_obj.tell(), 11) + self.assertEqual(file_obj.read(), b'') + + file_obj.seek(-15, 2) + self.assertEqual(file_obj.tell(), 0) + self.assertEqual(file_obj.read(), b'12345678') + + file_obj.seek(-8, 2) + self.assertEqual(file_obj.tell(), 0) + + def _assert_boundary_behavior(self, file_obj): + # Verify we're at the start + self.assertEqual(file_obj.tell(), 0) + + # Verify we can't move backwards beyond start of file + file_obj.seek(-10, 1) + self.assertEqual(file_obj.tell(), 0) + + # Verify we *can* move after end of file, but return nothing + file_obj.seek(10, 2) + self.assertEqual(file_obj.tell(), 18) + self.assertEqual(file_obj.read(), b'') + self.assertEqual(file_obj.read(10), b'') + + # Verify we can partially rewind + file_obj.seek(-12, 1) + self.assertEqual(file_obj.tell(), 6) + self.assertEqual(file_obj.read(), b'78') + self.assertEqual(file_obj.tell(), 8) + + # Verify we can rewind to start + file_obj.seek(0) + self.assertEqual(file_obj.tell(), 0) + + def test_file_chunk_supports_context_manager(self): + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(b'abc') + with ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=2 + ) as chunk: + val = chunk.read() + self.assertEqual(val, b'ab') + + def test_iter_is_always_empty(self): + # This tests the workaround for the httplib bug (see + # the source for more info). 
+ filename = os.path.join(self.tempdir, 'foo') + open(filename, 'wb').close() + chunk = ReadFileChunk.from_filename( + filename, start_byte=0, chunk_size=10 + ) + self.assertEqual(list(chunk), []) + + def test_callback_is_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + self.assertEqual(self.amounts_seen, [1, 1, 1]) + + def test_all_callbacks_invoked_on_read(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback, self.callback], + ) + chunk.read(1) + chunk.read(1) + chunk.read(1) + # The list should be twice as long because there are two callbacks + # recording the amount read. + self.assertEqual(self.amounts_seen, [1, 1, 1, 1, 1, 1]) + + def test_callback_can_be_disabled(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.disable_callback() + # Now reading from the ReadFileChunk should not invoke + # the callback. + chunk.read() + self.assertEqual(self.amounts_seen, []) + + def test_callback_will_also_be_triggered_by_seek(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.read(2) + chunk.seek(0) + chunk.read(2) + chunk.seek(1) + chunk.read(2) + self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2]) + + def test_callback_triggered_by_out_of_bound_seeks(self): + data = b'abcdefghij1234567890klmnopqr' + + # Create test file + filename = os.path.join(self.tempdir, 'foo') + with open(filename, 'wb') as f: + f.write(data) + chunk = ReadFileChunk.from_filename( + filename, start_byte=10, chunk_size=10, callbacks=[self.callback] + ) + + # Seek calls that generate "0" progress are skipped by + # invoke_progress_callbacks and won't appear in the list. 
+ expected_callback_prog = [10, -5, 5, -1, 1, -1, 1, -5, 5, -10] + + self._assert_out_of_bound_start_seek(chunk, expected_callback_prog) + self._assert_out_of_bound_relative_seek(chunk, expected_callback_prog) + self._assert_out_of_bound_end_seek(chunk, expected_callback_prog) + + def _assert_out_of_bound_start_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(20) # (20, 10) + chunk.seek(5) # (5, -5) + chunk.seek(20) # (20, 5) + chunk.seek(9) # (9, -1) + chunk.seek(20) # (20, 1) + chunk.seek(11) # (11, 0) + chunk.seek(20) # (20, 0) + chunk.seek(9) # (9, -1) + chunk.seek(20) # (20, 1) + chunk.seek(5) # (5, -5) + chunk.seek(20) # (20, 5) + chunk.seek(0) # (0, -10) + chunk.seek(0) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def _assert_out_of_bound_relative_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(20, 1) # (20, 10) + chunk.seek(-15, 1) # (5, -5) + chunk.seek(15, 1) # (20, 5) + chunk.seek(-11, 1) # (9, -1) + chunk.seek(11, 1) # (20, 1) + chunk.seek(-9, 1) # (11, 0) + chunk.seek(9, 1) # (20, 0) + chunk.seek(-11, 1) # (9, -1) + chunk.seek(11, 1) # (20, 1) + chunk.seek(-15, 1) # (5, -5) + chunk.seek(15, 1) # (20, 5) + chunk.seek(-20, 1) # (0, -10) + chunk.seek(-1000, 1) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def _assert_out_of_bound_end_seek(self, chunk, expected): + # clear amounts_seen + self.amounts_seen = [] + self.assertEqual(self.amounts_seen, []) + + # (position, change) + chunk.seek(10, 2) # (20, 10) + chunk.seek(-5, 2) # (5, -5) + chunk.seek(10, 2) # (20, 5) + chunk.seek(-1, 2) # (9, -1) + chunk.seek(10, 2) # (20, 1) + chunk.seek(1, 2) # (11, 0) + chunk.seek(10, 2) # (20, 0) + chunk.seek(-1, 2) # (9, -1) + chunk.seek(10, 2) # (20, 1) + chunk.seek(-5, 2) # (5, -5) + chunk.seek(10, 2) # (20, 5) + chunk.seek(-10, 2) # (0, -10) + chunk.seek(-1000, 2) # (0, 0) + + self.assertEqual(self.amounts_seen, expected) + + def test_close_callbacks(self): + with open(self.filename) as f: + chunk = ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + close_callbacks=[self.close_callback], + ) + chunk.close() + self.assertEqual(self.num_close_callback_calls, 1) + + def test_close_callbacks_when_not_enabled(self): + with open(self.filename) as f: + chunk = ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + enable_callbacks=False, + close_callbacks=[self.close_callback], + ) + chunk.close() + self.assertEqual(self.num_close_callback_calls, 0) + + def test_close_callbacks_when_context_handler_is_used(self): + with open(self.filename) as f: + with ReadFileChunk( + f, + chunk_size=1, + full_file_size=3, + close_callbacks=[self.close_callback], + ) as chunk: + chunk.read(1) + self.assertEqual(self.num_close_callback_calls, 1) + + def test_signal_transferring(self): + chunk = ReadFileChunk.from_filename( + self.filename, + start_byte=0, + chunk_size=3, + callbacks=[self.callback], + ) + chunk.signal_not_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, []) + chunk.signal_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, [1]) + + def test_signal_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock() + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + chunk.signal_transferring() + self.assertTrue(underlying_stream.signal_transferring.called) + 
+    def test_no_call_signal_transferring_to_underlying_fileobj(self):
+        underlying_stream = mock.Mock(io.RawIOBase)
+        underlying_stream.tell.return_value = 0
+        chunk = ReadFileChunk(underlying_stream, 3, 3)
+        try:
+            chunk.signal_transferring()
+        except AttributeError:
+            self.fail(
+                'The stream should not have tried to call signal_transferring '
+                'to the underlying stream.'
+            )
+
+    def test_signal_not_transferring_to_underlying_fileobj(self):
+        underlying_stream = mock.Mock()
+        underlying_stream.tell.return_value = 0
+        chunk = ReadFileChunk(underlying_stream, 3, 3)
+        chunk.signal_not_transferring()
+        self.assertTrue(underlying_stream.signal_not_transferring.called)
+
+    def test_no_call_signal_not_transferring_to_underlying_fileobj(self):
+        underlying_stream = mock.Mock(io.RawIOBase)
+        underlying_stream.tell.return_value = 0
+        chunk = ReadFileChunk(underlying_stream, 3, 3)
+        try:
+            chunk.signal_not_transferring()
+        except AttributeError:
+            self.fail(
+                'The stream should not have tried to call '
+                'signal_not_transferring to the underlying stream.'
+            )
+
+
+class TestStreamReaderProgress(BaseUtilsTest):
+    def test_proxies_to_wrapped_stream(self):
+        original_stream = StringIO('foobarbaz')
+        wrapped = StreamReaderProgress(original_stream)
+        self.assertEqual(wrapped.read(), 'foobarbaz')
+
+    def test_callback_invoked(self):
+        original_stream = StringIO('foobarbaz')
+        wrapped = StreamReaderProgress(
+            original_stream, [self.callback, self.callback]
+        )
+        self.assertEqual(wrapped.read(), 'foobarbaz')
+        self.assertEqual(self.amounts_seen, [9, 9])
+
+
+class TestTaskSemaphore(unittest.TestCase):
+    def setUp(self):
+        self.semaphore = TaskSemaphore(1)
+
+    def test_should_block_at_max_capacity(self):
+        self.semaphore.acquire('a', blocking=False)
+        with self.assertRaises(NoResourcesAvailable):
+            self.semaphore.acquire('a', blocking=False)
+
+    def test_release_capacity(self):
+        acquire_token = self.semaphore.acquire('a', blocking=False)
+        self.semaphore.release('a', acquire_token)
+        try:
+            self.semaphore.acquire('a', blocking=False)
+        except NoResourcesAvailable:
+            self.fail(
+                'The release of the semaphore should have allowed for '
+                'the second acquire to not be blocked'
+            )
+
+
+class TestSlidingWindowSemaphore(unittest.TestCase):
+    # These tests use blocking=False so the tests will fail
+    # instead of hanging the test runner in the case of
+    # incorrect behavior.
+    def test_acquire_release_basic_case(self):
+        sem = SlidingWindowSemaphore(1)
+        # Count is 1
+
+        num = sem.acquire('a', blocking=False)
+        self.assertEqual(num, 0)
+        sem.release('a', 0)
+        # Count now back to 1.
+
+    def test_can_acquire_release_multiple_times(self):
+        sem = SlidingWindowSemaphore(1)
+        num = sem.acquire('a', blocking=False)
+        self.assertEqual(num, 0)
+        sem.release('a', num)
+
+        num = sem.acquire('a', blocking=False)
+        self.assertEqual(num, 1)
+        sem.release('a', num)
+
+    def test_can_acquire_a_range(self):
+        sem = SlidingWindowSemaphore(3)
+        self.assertEqual(sem.acquire('a', blocking=False), 0)
+        self.assertEqual(sem.acquire('a', blocking=False), 1)
+        self.assertEqual(sem.acquire('a', blocking=False), 2)
+        sem.release('a', 0)
+        sem.release('a', 1)
+        sem.release('a', 2)
+        # Now we're reset so we should be able to acquire the same
+        # sequence again.
+ self.assertEqual(sem.acquire('a', blocking=False), 3) + self.assertEqual(sem.acquire('a', blocking=False), 4) + self.assertEqual(sem.acquire('a', blocking=False), 5) + self.assertEqual(sem.current_count(), 0) + + def test_counter_release_only_on_min_element(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # The count only increases when we free the min + # element. This means if we're currently failing to + # acquire now: + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + + # Then freeing a non-min element: + sem.release('a', 1) + + # doesn't change anything. We still fail to acquire. + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + self.assertEqual(sem.current_count(), 0) + + def test_raises_error_when_count_is_zero(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # Count is now 0 so trying to acquire should fail. + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + + def test_release_counters_can_increment_counter_repeatedly(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + # These two releases don't increment the counter + # because we're waiting on 0. + sem.release('a', 1) + sem.release('a', 2) + self.assertEqual(sem.current_count(), 0) + # But as soon as we release 0, we free up 0, 1, and 2. + sem.release('a', 0) + self.assertEqual(sem.current_count(), 3) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + + def test_error_to_release_unknown_tag(self): + sem = SlidingWindowSemaphore(3) + with self.assertRaises(ValueError): + sem.release('a', 0) + + def test_can_track_multiple_tags(self): + sem = SlidingWindowSemaphore(3) + self.assertEqual(sem.acquire('a', blocking=False), 0) + self.assertEqual(sem.acquire('b', blocking=False), 0) + self.assertEqual(sem.acquire('a', blocking=False), 1) + + # We're at our max of 3 even though 2 are for A and 1 is for B. + with self.assertRaises(NoResourcesAvailable): + sem.acquire('a', blocking=False) + with self.assertRaises(NoResourcesAvailable): + sem.acquire('b', blocking=False) + + def test_can_handle_multiple_tags_released(self): + sem = SlidingWindowSemaphore(4) + sem.acquire('a', blocking=False) + sem.acquire('a', blocking=False) + sem.acquire('b', blocking=False) + sem.acquire('b', blocking=False) + + sem.release('b', 1) + sem.release('a', 1) + self.assertEqual(sem.current_count(), 0) + + sem.release('b', 0) + self.assertEqual(sem.acquire('a', blocking=False), 2) + + sem.release('a', 0) + self.assertEqual(sem.acquire('b', blocking=False), 2) + + def test_is_error_to_release_unknown_sequence_number(self): + sem = SlidingWindowSemaphore(3) + sem.acquire('a', blocking=False) + with self.assertRaises(ValueError): + sem.release('a', 1) + + def test_is_error_to_double_release(self): + # This is different than other error tests because + # we're verifying we can reset the state after an + # acquire/release cycle. 
+        sem = SlidingWindowSemaphore(2)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+        sem.release('a', 0)
+        sem.release('a', 1)
+        self.assertEqual(sem.current_count(), 2)
+        with self.assertRaises(ValueError):
+            sem.release('a', 0)
+
+    def test_can_check_in_partial_range(self):
+        sem = SlidingWindowSemaphore(4)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+
+        sem.release('a', 1)
+        sem.release('a', 3)
+        sem.release('a', 0)
+        self.assertEqual(sem.current_count(), 2)
+
+
+class TestThreadingPropertiesForSlidingWindowSemaphore(unittest.TestCase):
+    # These tests focus on multithreaded properties of the range
+    # semaphore. Basic functionality is tested in TestSlidingWindowSemaphore.
+    def setUp(self):
+        self.threads = []
+
+    def tearDown(self):
+        self.join_threads()
+
+    def join_threads(self):
+        for thread in self.threads:
+            thread.join()
+        self.threads = []
+
+    def start_threads(self):
+        for thread in self.threads:
+            thread.start()
+
+    def test_acquire_blocks_until_release_is_called(self):
+        sem = SlidingWindowSemaphore(2)
+        sem.acquire('a', blocking=False)
+        sem.acquire('a', blocking=False)
+
+        def acquire():
+            # This next call to acquire will block.
+            self.assertEqual(sem.acquire('a', blocking=True), 2)
+
+        t = threading.Thread(target=acquire)
+        self.threads.append(t)
+        # Starting the thread will block the sem.acquire()
+        # in the acquire function above.
+        t.start()
+        # This will still keep the thread blocked.
+        sem.release('a', 1)
+        # Releasing the min element will unblock the thread.
+        sem.release('a', 0)
+        t.join()
+        sem.release('a', 2)
+
+    def test_stress_invariants_random_order(self):
+        sem = SlidingWindowSemaphore(100)
+        for _ in range(10):
+            recorded = []
+            for _ in range(100):
+                recorded.append(sem.acquire('a', blocking=False))
+            # Release them in randomized order. As long as we
+            # eventually free all 100, we should have all the
+            # resources released.
+            random.shuffle(recorded)
+            for i in recorded:
+                sem.release('a', i)
+
+        # Everything's freed, so we should be back at count == 100.
+        self.assertEqual(sem.current_count(), 100)
+
+    def test_blocking_stress(self):
+        sem = SlidingWindowSemaphore(5)
+        num_threads = 10
+        num_iterations = 50
+
+        def acquire():
+            for _ in range(num_iterations):
+                num = sem.acquire('a', blocking=True)
+                time.sleep(0.001)
+                sem.release('a', num)
+
+        for i in range(num_threads):
+            t = threading.Thread(target=acquire)
+            self.threads.append(t)
+        self.start_threads()
+        self.join_threads()
+        # Should have all the available resources freed.
+ self.assertEqual(sem.current_count(), 5) + # Should have acquired num_threads * num_iterations + self.assertEqual( + sem.acquire('a', blocking=False), num_threads * num_iterations + ) + + +class TestAdjustChunksize(unittest.TestCase): + def setUp(self): + self.adjuster = ChunksizeAdjuster() + + def test_valid_chunksize(self): + chunksize = 7 * (1024 ** 2) + file_size = 8 * (1024 ** 2) + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, chunksize) + + def test_chunksize_below_minimum(self): + chunksize = MIN_UPLOAD_CHUNKSIZE - 1 + file_size = 3 * MIN_UPLOAD_CHUNKSIZE + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE) + + def test_chunksize_above_maximum(self): + chunksize = MAX_SINGLE_UPLOAD_SIZE + 1 + file_size = MAX_SINGLE_UPLOAD_SIZE * 2 + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE) + + def test_chunksize_too_small(self): + chunksize = 7 * (1024 ** 2) + file_size = 5 * (1024 ** 4) + # If we try to upload a 5TB file, we'll need to use 896MB part + # sizes. + new_size = self.adjuster.adjust_chunksize(chunksize, file_size) + self.assertEqual(new_size, 896 * (1024 ** 2)) + num_parts = file_size / new_size + self.assertLessEqual(num_parts, MAX_PARTS) + + def test_unknown_file_size_with_valid_chunksize(self): + chunksize = 7 * (1024 ** 2) + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, chunksize) + + def test_unknown_file_size_below_minimum(self): + chunksize = MIN_UPLOAD_CHUNKSIZE - 1 + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE) + + def test_unknown_file_size_above_maximum(self): + chunksize = MAX_SINGLE_UPLOAD_SIZE + 1 + new_size = self.adjuster.adjust_chunksize(chunksize) + self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE) diff --git a/contrib/python/s3transfer/py3/tests/ya.make b/contrib/python/s3transfer/py3/tests/ya.make index 2fdeaf8ca2..fdbf22b0c5 100644 --- a/contrib/python/s3transfer/py3/tests/ya.make +++ b/contrib/python/s3transfer/py3/tests/ya.make @@ -1,44 +1,44 @@ -PY3TEST() - -OWNER(g:python-contrib) - -SIZE(MEDIUM) - -FORK_TESTS() - -PEERDIR( - contrib/python/mock - contrib/python/s3transfer -) - -TEST_SRCS( - functional/__init__.py - functional/test_copy.py - functional/test_crt.py - functional/test_delete.py - functional/test_download.py - functional/test_manager.py - functional/test_processpool.py - functional/test_upload.py - functional/test_utils.py - __init__.py - unit/__init__.py - unit/test_bandwidth.py - unit/test_compat.py - unit/test_copies.py - unit/test_crt.py - unit/test_delete.py - unit/test_download.py - unit/test_futures.py - unit/test_manager.py - unit/test_processpool.py - unit/test_s3transfer.py - unit/test_subscribers.py - unit/test_tasks.py - unit/test_upload.py - unit/test_utils.py -) - -NO_LINT() - -END() +PY3TEST() + +OWNER(g:python-contrib) + +SIZE(MEDIUM) + +FORK_TESTS() + +PEERDIR( + contrib/python/mock + contrib/python/s3transfer +) + +TEST_SRCS( + functional/__init__.py + functional/test_copy.py + functional/test_crt.py + functional/test_delete.py + functional/test_download.py + functional/test_manager.py + functional/test_processpool.py + functional/test_upload.py + functional/test_utils.py + __init__.py + unit/__init__.py + unit/test_bandwidth.py + unit/test_compat.py + unit/test_copies.py + unit/test_crt.py + unit/test_delete.py + unit/test_download.py + 
unit/test_futures.py + unit/test_manager.py + unit/test_processpool.py + unit/test_s3transfer.py + unit/test_subscribers.py + unit/test_tasks.py + unit/test_upload.py + unit/test_utils.py +) + +NO_LINT() + +END() diff --git a/contrib/python/s3transfer/py3/ya.make b/contrib/python/s3transfer/py3/ya.make index 73a52aafa9..964a630639 100644 --- a/contrib/python/s3transfer/py3/ya.make +++ b/contrib/python/s3transfer/py3/ya.make @@ -1,51 +1,51 @@ -# Generated by devtools/yamaker (pypi). - -PY3_LIBRARY() +# Generated by devtools/yamaker (pypi). -OWNER(gebetix g:python-contrib) +PY3_LIBRARY() -VERSION(0.5.1) +OWNER(gebetix g:python-contrib) -LICENSE(Apache-2.0) +VERSION(0.5.1) + +LICENSE(Apache-2.0) PEERDIR( contrib/python/botocore ) -NO_LINT() - -NO_CHECK_IMPORTS( - s3transfer.crt -) - +NO_LINT() + +NO_CHECK_IMPORTS( + s3transfer.crt +) + PY_SRCS( TOP_LEVEL s3transfer/__init__.py s3transfer/bandwidth.py s3transfer/compat.py - s3transfer/constants.py - s3transfer/copies.py - s3transfer/crt.py + s3transfer/constants.py + s3transfer/copies.py + s3transfer/crt.py s3transfer/delete.py s3transfer/download.py s3transfer/exceptions.py s3transfer/futures.py s3transfer/manager.py - s3transfer/processpool.py + s3transfer/processpool.py s3transfer/subscribers.py s3transfer/tasks.py s3transfer/upload.py s3transfer/utils.py ) -RESOURCE_FILES( - PREFIX contrib/python/s3transfer/py3/ - .dist-info/METADATA - .dist-info/top_level.txt -) - +RESOURCE_FILES( + PREFIX contrib/python/s3transfer/py3/ + .dist-info/METADATA + .dist-info/top_level.txt +) + END() - -RECURSE_FOR_TESTS( - tests -) + +RECURSE_FOR_TESTS( + tests +) |
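
The SlidingWindowSemaphore tests in the diff above repeatedly assert one behavior: releasing a sequence number only returns capacity once the lowest outstanding number, and any already-released numbers directly above it, have been checked in. The snippet below is a minimal toy model written only to make that window bookkeeping concrete; it is not the s3transfer implementation and deliberately omits tags, blocking, and thread safety.

class SlidingWindowModel:
    """Toy model of the behavior asserted above, not the library's code."""

    def __init__(self, count):
        self.count = count      # free slots
        self.next_seq = 0       # next sequence number to hand out
        self.min_seq = 0        # lowest sequence number not yet checked in
        self.pending = set()    # acquired but not yet released

    def acquire(self):
        if self.count == 0:
            raise RuntimeError('no resources available')
        self.count -= 1
        seq = self.next_seq
        self.next_seq += 1
        self.pending.add(seq)
        return seq

    def release(self, seq):
        if seq not in self.pending:
            raise ValueError('unknown or already released sequence number')
        self.pending.remove(seq)
        # Slide the window: only releases at the low edge free capacity,
        # but they also sweep up higher numbers released earlier.
        while self.min_seq < self.next_seq and self.min_seq not in self.pending:
            self.min_seq += 1
            self.count += 1

sem = SlidingWindowModel(3)
assert [sem.acquire() for _ in range(3)] == [0, 1, 2]
sem.release(1)     # non-min release: window cannot slide yet
assert sem.count == 0
sem.release(0)     # min release: frees 0 and the already-released 1
assert sem.count == 2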
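
Similarly, test_chunksize_too_small asserts that a 7 MB chunk size for a 5 TB upload is adjusted to 896 MB. One strategy consistent with that result, shown here as a standalone sketch rather than the library's actual code, is to keep doubling the chunk size until the part count fits under MAX_PARTS (assumed here to be 10000, the S3 part limit); the real logic lives in s3transfer.utils.ChunksizeAdjuster and may differ in detail. Doubling from 7 MiB gives 14, 28, 56, 112, 224, 448 and finally 896 MiB: 448 MiB would still need 11703 parts, while 896 MiB needs 5852.

import math

MAX_PARTS = 10000  # assumption: mirrors s3transfer.utils.MAX_PARTS

def grow_chunksize_for_max_parts(chunksize, file_size, max_parts=MAX_PARTS):
    """Hypothetical helper for illustration: double the chunk size until the
    file fits in at most max_parts parts."""
    while math.ceil(file_size / chunksize) > max_parts:
        chunksize *= 2
    return chunksize

# The 5 TB case from test_chunksize_too_small.
new_size = grow_chunksize_for_max_parts(7 * 1024 ** 2, 5 * 1024 ** 4)
assert new_size == 896 * 1024 ** 2
assert math.ceil(5 * 1024 ** 4 / new_size) <= MAX_PARTS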