author     shadchin <shadchin@yandex-team.ru>            2022-02-10 16:44:39 +0300
committer  Daniil Cherednik <dcherednik@yandex-team.ru>  2022-02-10 16:44:39 +0300
commit     e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (patch)
tree       64175d5cadab313b3e7039ebaa06c5bc3295e274 /contrib/python/s3transfer/py3
parent     2598ef1d0aee359b4b6d5fdd1758916d5907d04f (diff)
download   ydb-e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0.tar.gz
Restoring authorship annotation for <shadchin@yandex-team.ru>. Commit 2 of 2.
Diffstat (limited to 'contrib/python/s3transfer/py3')
-rw-r--r--  contrib/python/s3transfer/py3/.dist-info/METADATA  84
-rw-r--r--  contrib/python/s3transfer/py3/.dist-info/top_level.txt  2
-rw-r--r--  contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml  4
-rw-r--r--  contrib/python/s3transfer/py3/LICENSE.txt  404
-rw-r--r--  contrib/python/s3transfer/py3/NOTICE.txt  4
-rw-r--r--  contrib/python/s3transfer/py3/README.rst  26
-rw-r--r--  contrib/python/s3transfer/py3/patches/01-fix-tests.patch  484
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/__init__.py  568
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/bandwidth.py  122
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/compat.py  46
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/constants.py  58
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/copies.py  198
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/crt.py  1288
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/delete.py  6
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/download.py  384
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/exceptions.py  4
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/futures.py  176
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/manager.py  344
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/processpool.py  2016
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/subscribers.py  20
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/tasks.py  106
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/upload.py  324
-rw-r--r--  contrib/python/s3transfer/py3/s3transfer/utils.py  312
-rw-r--r--  contrib/python/s3transfer/py3/tests/__init__.py  1062
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/__init__.py  24
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/test_copy.py  1108
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/test_crt.py  534
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/test_delete.py  150
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/test_download.py  994
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/test_manager.py  382
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/test_processpool.py  562
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/test_upload.py  1076
-rw-r--r--  contrib/python/s3transfer/py3/tests/functional/test_utils.py  82
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/__init__.py  24
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py  904
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_compat.py  210
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_copies.py  354
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_crt.py  346
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_delete.py  132
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_download.py  1998
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_futures.py  1392
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_manager.py  286
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_processpool.py  1456
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py  1560
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_subscribers.py  182
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_tasks.py  1666
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_upload.py  1388
-rw-r--r--  contrib/python/s3transfer/py3/tests/unit/test_utils.py  2378
-rw-r--r--  contrib/python/s3transfer/py3/tests/ya.make  88
-rw-r--r--  contrib/python/s3transfer/py3/ya.make  52
50 files changed, 13685 insertions, 13685 deletions
diff --git a/contrib/python/s3transfer/py3/.dist-info/METADATA b/contrib/python/s3transfer/py3/.dist-info/METADATA
index 2805466e80..7d635068d7 100644
--- a/contrib/python/s3transfer/py3/.dist-info/METADATA
+++ b/contrib/python/s3transfer/py3/.dist-info/METADATA
@@ -1,42 +1,42 @@
-Metadata-Version: 2.1
-Name: s3transfer
-Version: 0.5.1
-Summary: An Amazon S3 Transfer Manager
-Home-page: https://github.com/boto/s3transfer
-Author: Amazon Web Services
-Author-email: kyknapp1@gmail.com
-License: Apache License 2.0
-Platform: UNKNOWN
-Classifier: Development Status :: 3 - Alpha
-Classifier: Intended Audience :: Developers
-Classifier: Natural Language :: English
-Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Programming Language :: Python
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.6
-Classifier: Programming Language :: Python :: 3.7
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Requires-Python: >= 3.6
-License-File: LICENSE.txt
-License-File: NOTICE.txt
-Requires-Dist: botocore (<2.0a.0,>=1.12.36)
-Provides-Extra: crt
-Requires-Dist: botocore[crt] (<2.0a.0,>=1.20.29) ; extra == 'crt'
-
-=====================================================
-s3transfer - An Amazon S3 Transfer Manager for Python
-=====================================================
-
-S3transfer is a Python library for managing Amazon S3 transfers.
-This project is maintained and published by Amazon Web Services.
-
-.. note::
-
- This project is not currently GA. If you are planning to use this code in
- production, make sure to lock to a minor version as interfaces may break
- from minor version to minor version. For a basic, stable interface of
- s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__
-
-
+Metadata-Version: 2.1
+Name: s3transfer
+Version: 0.5.1
+Summary: An Amazon S3 Transfer Manager
+Home-page: https://github.com/boto/s3transfer
+Author: Amazon Web Services
+Author-email: kyknapp1@gmail.com
+License: Apache License 2.0
+Platform: UNKNOWN
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Natural Language :: English
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Requires-Python: >= 3.6
+License-File: LICENSE.txt
+License-File: NOTICE.txt
+Requires-Dist: botocore (<2.0a.0,>=1.12.36)
+Provides-Extra: crt
+Requires-Dist: botocore[crt] (<2.0a.0,>=1.20.29) ; extra == 'crt'
+
+=====================================================
+s3transfer - An Amazon S3 Transfer Manager for Python
+=====================================================
+
+S3transfer is a Python library for managing Amazon S3 transfers.
+This project is maintained and published by Amazon Web Services.
+
+.. note::
+
+ This project is not currently GA. If you are planning to use this code in
+ production, make sure to lock to a minor version as interfaces may break
+ from minor version to minor version. For a basic, stable interface of
+ s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__
+
+
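The METADATA note above recommends pinning to a minor version and, for a basic stable interface, using the transfer manager exposed through boto3 rather than s3transfer's own pre-GA API. A minimal sketch of that boto3 usage follows; the bucket, key, and file names are placeholders, and the TransferConfig values are illustrative only.

```python
# Sketch: the stable transfer interface exposed by boto3, as the note suggests.
import boto3
from boto3.s3.transfer import TransferConfig

s3 = boto3.client("s3")

# Multipart behaviour is tuned via TransferConfig; these numbers are examples.
config = TransferConfig(multipart_threshold=8 * 1024 * 1024, max_concurrency=4)

s3.upload_file("local.bin", "example-bucket", "remote.bin", Config=config)
s3.download_file("example-bucket", "remote.bin", "copy-of-remote.bin")
```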
diff --git a/contrib/python/s3transfer/py3/.dist-info/top_level.txt b/contrib/python/s3transfer/py3/.dist-info/top_level.txt
index 305b80c924..572c6a92fb 100644
--- a/contrib/python/s3transfer/py3/.dist-info/top_level.txt
+++ b/contrib/python/s3transfer/py3/.dist-info/top_level.txt
@@ -1 +1 @@
-s3transfer
+s3transfer
diff --git a/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml b/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml
index d5ef0ef5e6..f2f140fb3c 100644
--- a/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml
+++ b/contrib/python/s3transfer/py3/.yandex_meta/yamaker.yaml
@@ -1,2 +1,2 @@
-exclude:
-- tests/integration/.+
+exclude:
+- tests/integration/.+
diff --git a/contrib/python/s3transfer/py3/LICENSE.txt b/contrib/python/s3transfer/py3/LICENSE.txt
index c0fd617439..d645695673 100644
--- a/contrib/python/s3transfer/py3/LICENSE.txt
+++ b/contrib/python/s3transfer/py3/LICENSE.txt
@@ -1,202 +1,202 @@
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/contrib/python/s3transfer/py3/NOTICE.txt b/contrib/python/s3transfer/py3/NOTICE.txt
index 96e3cc3530..3e616fdf0c 100644
--- a/contrib/python/s3transfer/py3/NOTICE.txt
+++ b/contrib/python/s3transfer/py3/NOTICE.txt
@@ -1,2 +1,2 @@
-s3transfer
-Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+s3transfer
+Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
diff --git a/contrib/python/s3transfer/py3/README.rst b/contrib/python/s3transfer/py3/README.rst
index 923be4f461..441029109e 100644
--- a/contrib/python/s3transfer/py3/README.rst
+++ b/contrib/python/s3transfer/py3/README.rst
@@ -1,13 +1,13 @@
-=====================================================
-s3transfer - An Amazon S3 Transfer Manager for Python
-=====================================================
-
-S3transfer is a Python library for managing Amazon S3 transfers.
-This project is maintained and published by Amazon Web Services.
-
-.. note::
-
- This project is not currently GA. If you are planning to use this code in
- production, make sure to lock to a minor version as interfaces may break
- from minor version to minor version. For a basic, stable interface of
- s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__
+=====================================================
+s3transfer - An Amazon S3 Transfer Manager for Python
+=====================================================
+
+S3transfer is a Python library for managing Amazon S3 transfers.
+This project is maintained and published by Amazon Web Services.
+
+.. note::
+
+ This project is not currently GA. If you are planning to use this code in
+ production, make sure to lock to a minor version as interfaces may break
+ from minor version to minor version. For a basic, stable interface of
+ s3transfer, try the interfaces exposed in `boto3 <https://boto3.readthedocs.io/en/latest/guide/s3.html#using-the-transfer-manager>`__
diff --git a/contrib/python/s3transfer/py3/patches/01-fix-tests.patch b/contrib/python/s3transfer/py3/patches/01-fix-tests.patch
index 4f20d019e9..aa8d3fab4e 100644
--- a/contrib/python/s3transfer/py3/patches/01-fix-tests.patch
+++ b/contrib/python/s3transfer/py3/patches/01-fix-tests.patch
@@ -1,242 +1,242 @@
---- contrib/python/s3transfer/py3/tests/functional/test_copy.py (index)
-+++ contrib/python/s3transfer/py3/tests/functional/test_copy.py (working tree)
-@@ -15,7 +15,7 @@ from botocore.stub import Stubber
-
- from s3transfer.manager import TransferConfig, TransferManager
- from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE
--from tests import BaseGeneralInterfaceTest, FileSizeProvider
-+from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider
-
-
- class BaseCopyTest(BaseGeneralInterfaceTest):
---- contrib/python/s3transfer/py3/tests/functional/test_crt.py (index)
-+++ contrib/python/s3transfer/py3/tests/functional/test_crt.py (working tree)
-@@ -18,7 +18,7 @@ from concurrent.futures import Future
- from botocore.session import Session
-
- from s3transfer.subscribers import BaseSubscriber
--from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest
-+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
-
- if HAS_CRT:
- import awscrt
---- contrib/python/s3transfer/py3/tests/functional/test_delete.py (index)
-+++ contrib/python/s3transfer/py3/tests/functional/test_delete.py (working tree)
-@@ -11,7 +11,7 @@
- # ANY KIND, either express or implied. See the License for the specific
- # language governing permissions and limitations under the License.
- from s3transfer.manager import TransferManager
--from tests import BaseGeneralInterfaceTest
-+from __tests__ import BaseGeneralInterfaceTest
-
-
- class TestDeleteObject(BaseGeneralInterfaceTest):
---- contrib/python/s3transfer/py3/tests/functional/test_download.py (index)
-+++ contrib/python/s3transfer/py3/tests/functional/test_download.py (working tree)
-@@ -23,7 +23,7 @@ from botocore.exceptions import ClientError
- from s3transfer.compat import SOCKET_ERROR
- from s3transfer.exceptions import RetriesExceededError
- from s3transfer.manager import TransferConfig, TransferManager
--from tests import (
-+from __tests__ import (
- BaseGeneralInterfaceTest,
- FileSizeProvider,
- NonSeekableWriter,
---- contrib/python/s3transfer/py3/tests/functional/test_manager.py (index)
-+++ contrib/python/s3transfer/py3/tests/functional/test_manager.py (working tree)
-@@ -17,7 +17,7 @@ from botocore.awsrequest import create_request_object
- from s3transfer.exceptions import CancelledError, FatalError
- from s3transfer.futures import BaseExecutor
- from s3transfer.manager import TransferConfig, TransferManager
--from tests import StubbedClientTest, mock, skip_if_using_serial_implementation
-+from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation
-
-
- class ArbitraryException(Exception):
---- contrib/python/s3transfer/py3/tests/functional/test_processpool.py (index)
-+++ contrib/python/s3transfer/py3/tests/functional/test_processpool.py (working tree)
-@@ -21,7 +21,7 @@ from botocore.stub import Stubber
-
- from s3transfer.exceptions import CancelledError
- from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig
--from tests import FileCreator, mock, unittest
-+from __tests__ import FileCreator, mock, unittest
-
-
- class StubbedClient:
---- contrib/python/s3transfer/py3/tests/functional/test_upload.py (index)
-+++ contrib/python/s3transfer/py3/tests/functional/test_upload.py (working tree)
-@@ -23,7 +23,7 @@ from botocore.stub import ANY
-
- from s3transfer.manager import TransferConfig, TransferManager
- from s3transfer.utils import ChunksizeAdjuster
--from tests import (
-+from __tests__ import (
- BaseGeneralInterfaceTest,
- NonSeekableReader,
- RecordingOSUtils,
---- contrib/python/s3transfer/py3/tests/functional/test_utils.py (index)
-+++ contrib/python/s3transfer/py3/tests/functional/test_utils.py (working tree)
-@@ -16,7 +16,7 @@ import socket
- import tempfile
-
- from s3transfer.utils import OSUtils
--from tests import skip_if_windows, unittest
-+from __tests__ import skip_if_windows, unittest
-
-
- @skip_if_windows('Windows does not support UNIX special files')
---- contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (working tree)
-@@ -25,7 +25,7 @@ from s3transfer.bandwidth import (
- TimeUtils,
- )
- from s3transfer.futures import TransferCoordinator
--from tests import mock, unittest
-+from __tests__ import mock, unittest
-
-
- class FixedIncrementalTickTimeUtils(TimeUtils):
---- contrib/python/s3transfer/py3/tests/unit/test_compat.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_compat.py (working tree)
-@@ -17,7 +17,7 @@ import tempfile
- from io import BytesIO
-
- from s3transfer.compat import BaseManager, readable, seekable
--from tests import skip_if_windows, unittest
-+from __tests__ import skip_if_windows, unittest
-
-
- class ErrorRaisingSeekWrapper:
---- contrib/python/s3transfer/py3/tests/unit/test_copies.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_copies.py (working tree)
-@@ -11,7 +11,7 @@
- # ANY KIND, either express or implied. See the License for the specific
- # language governing permissions and limitations under the License.
- from s3transfer.copies import CopyObjectTask, CopyPartTask
--from tests import BaseTaskTest, RecordingSubscriber
-+from __tests__ import BaseTaskTest, RecordingSubscriber
-
-
- class BaseCopyTaskTest(BaseTaskTest):
---- contrib/python/s3transfer/py3/tests/unit/test_crt.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_crt.py (working tree)
-@@ -15,7 +15,7 @@ from botocore.session import Session
-
- from s3transfer.exceptions import TransferNotDoneError
- from s3transfer.utils import CallArgs
--from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest
-+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
-
- if HAS_CRT:
- import awscrt.s3
---- contrib/python/s3transfer/py3/tests/unit/test_delete.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_delete.py (working tree)
-@@ -11,7 +11,7 @@
- # ANY KIND, either express or implied. See the License for the specific
- # language governing permissions and limitations under the License.
- from s3transfer.delete import DeleteObjectTask
--from tests import BaseTaskTest
-+from __tests__ import BaseTaskTest
-
-
- class TestDeleteObjectTask(BaseTaskTest):
---- contrib/python/s3transfer/py3/tests/unit/test_download.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_download.py (working tree)
-@@ -37,7 +37,7 @@ from s3transfer.download import (
- from s3transfer.exceptions import RetriesExceededError
- from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor
- from s3transfer.utils import CallArgs, OSUtils
--from tests import (
-+from __tests__ import (
- BaseSubmissionTaskTest,
- BaseTaskTest,
- FileCreator,
---- contrib/python/s3transfer/py3/tests/unit/test_futures.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_futures.py (working tree)
-@@ -37,7 +37,7 @@ from s3transfer.utils import (
- NoResourcesAvailable,
- TaskSemaphore,
- )
--from tests import (
-+from __tests__ import (
- RecordingExecutor,
- TransferCoordinatorWithInterrupt,
- mock,
---- contrib/python/s3transfer/py3/tests/unit/test_manager.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_manager.py (working tree)
-@@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor
- from s3transfer.exceptions import CancelledError, FatalError
- from s3transfer.futures import TransferCoordinator
- from s3transfer.manager import TransferConfig, TransferCoordinatorController
--from tests import TransferCoordinatorWithInterrupt, unittest
-+from __tests__ import TransferCoordinatorWithInterrupt, unittest
-
-
- class FutureResultException(Exception):
---- contrib/python/s3transfer/py3/tests/unit/test_processpool.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_processpool.py (working tree)
-@@ -39,7 +39,7 @@ from s3transfer.processpool import (
- ignore_ctrl_c,
- )
- from s3transfer.utils import CallArgs, OSUtils
--from tests import (
-+from __tests__ import (
- FileCreator,
- StreamWithError,
- StubbedClientTest,
---- contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (working tree)
-@@ -33,7 +33,7 @@ from s3transfer import (
- random_file_extension,
- )
- from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
--from tests import mock, unittest
-+from __tests__ import mock, unittest
-
-
- class InMemoryOSLayer(OSUtils):
---- contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (working tree)
-@@ -12,7 +12,7 @@
- # language governing permissions and limitations under the License.
- from s3transfer.exceptions import InvalidSubscriberMethodError
- from s3transfer.subscribers import BaseSubscriber
--from tests import unittest
-+from __tests__ import unittest
-
-
- class ExtraMethodsSubscriber(BaseSubscriber):
---- contrib/python/s3transfer/py3/tests/unit/test_tasks.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_tasks.py (working tree)
-@@ -23,7 +23,7 @@ from s3transfer.tasks import (
- Task,
- )
- from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks
--from tests import (
-+from __tests__ import (
- BaseSubmissionTaskTest,
- BaseTaskTest,
- RecordingSubscriber,
---- contrib/python/s3transfer/py3/tests/unit/test_upload.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_upload.py (working tree)
-@@ -32,7 +32,7 @@ from s3transfer.upload import (
- UploadSubmissionTask,
- )
- from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils
--from tests import (
-+from __tests__ import (
- BaseSubmissionTaskTest,
- BaseTaskTest,
- FileSizeProvider,
---- contrib/python/s3transfer/py3/tests/unit/test_utils.py (index)
-+++ contrib/python/s3transfer/py3/tests/unit/test_utils.py (working tree)
-@@ -43,7 +43,7 @@ from s3transfer.utils import (
- invoke_progress_callbacks,
- random_file_extension,
- )
--from tests import NonSeekableWriter, RecordingSubscriber, mock, unittest
-+from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest
-
-
- class TestGetCallbacks(unittest.TestCase):
+--- contrib/python/s3transfer/py3/tests/functional/test_copy.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_copy.py (working tree)
+@@ -15,7 +15,7 @@ from botocore.stub import Stubber
+
+ from s3transfer.manager import TransferConfig, TransferManager
+ from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE
+-from tests import BaseGeneralInterfaceTest, FileSizeProvider
++from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider
+
+
+ class BaseCopyTest(BaseGeneralInterfaceTest):
+--- contrib/python/s3transfer/py3/tests/functional/test_crt.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_crt.py (working tree)
+@@ -18,7 +18,7 @@ from concurrent.futures import Future
+ from botocore.session import Session
+
+ from s3transfer.subscribers import BaseSubscriber
+-from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest
++from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+ if HAS_CRT:
+ import awscrt
+--- contrib/python/s3transfer/py3/tests/functional/test_delete.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_delete.py (working tree)
+@@ -11,7 +11,7 @@
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ from s3transfer.manager import TransferManager
+-from tests import BaseGeneralInterfaceTest
++from __tests__ import BaseGeneralInterfaceTest
+
+
+ class TestDeleteObject(BaseGeneralInterfaceTest):
+--- contrib/python/s3transfer/py3/tests/functional/test_download.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_download.py (working tree)
+@@ -23,7 +23,7 @@ from botocore.exceptions import ClientError
+ from s3transfer.compat import SOCKET_ERROR
+ from s3transfer.exceptions import RetriesExceededError
+ from s3transfer.manager import TransferConfig, TransferManager
+-from tests import (
++from __tests__ import (
+ BaseGeneralInterfaceTest,
+ FileSizeProvider,
+ NonSeekableWriter,
+--- contrib/python/s3transfer/py3/tests/functional/test_manager.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_manager.py (working tree)
+@@ -17,7 +17,7 @@ from botocore.awsrequest import create_request_object
+ from s3transfer.exceptions import CancelledError, FatalError
+ from s3transfer.futures import BaseExecutor
+ from s3transfer.manager import TransferConfig, TransferManager
+-from tests import StubbedClientTest, mock, skip_if_using_serial_implementation
++from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation
+
+
+ class ArbitraryException(Exception):
+--- contrib/python/s3transfer/py3/tests/functional/test_processpool.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_processpool.py (working tree)
+@@ -21,7 +21,7 @@ from botocore.stub import Stubber
+
+ from s3transfer.exceptions import CancelledError
+ from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig
+-from tests import FileCreator, mock, unittest
++from __tests__ import FileCreator, mock, unittest
+
+
+ class StubbedClient:
+--- contrib/python/s3transfer/py3/tests/functional/test_upload.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_upload.py (working tree)
+@@ -23,7 +23,7 @@ from botocore.stub import ANY
+
+ from s3transfer.manager import TransferConfig, TransferManager
+ from s3transfer.utils import ChunksizeAdjuster
+-from tests import (
++from __tests__ import (
+ BaseGeneralInterfaceTest,
+ NonSeekableReader,
+ RecordingOSUtils,
+--- contrib/python/s3transfer/py3/tests/functional/test_utils.py (index)
++++ contrib/python/s3transfer/py3/tests/functional/test_utils.py (working tree)
+@@ -16,7 +16,7 @@ import socket
+ import tempfile
+
+ from s3transfer.utils import OSUtils
+-from tests import skip_if_windows, unittest
++from __tests__ import skip_if_windows, unittest
+
+
+ @skip_if_windows('Windows does not support UNIX special files')
+--- contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py (working tree)
+@@ -25,7 +25,7 @@ from s3transfer.bandwidth import (
+ TimeUtils,
+ )
+ from s3transfer.futures import TransferCoordinator
+-from tests import mock, unittest
++from __tests__ import mock, unittest
+
+
+ class FixedIncrementalTickTimeUtils(TimeUtils):
+--- contrib/python/s3transfer/py3/tests/unit/test_compat.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_compat.py (working tree)
+@@ -17,7 +17,7 @@ import tempfile
+ from io import BytesIO
+
+ from s3transfer.compat import BaseManager, readable, seekable
+-from tests import skip_if_windows, unittest
++from __tests__ import skip_if_windows, unittest
+
+
+ class ErrorRaisingSeekWrapper:
+--- contrib/python/s3transfer/py3/tests/unit/test_copies.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_copies.py (working tree)
+@@ -11,7 +11,7 @@
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ from s3transfer.copies import CopyObjectTask, CopyPartTask
+-from tests import BaseTaskTest, RecordingSubscriber
++from __tests__ import BaseTaskTest, RecordingSubscriber
+
+
+ class BaseCopyTaskTest(BaseTaskTest):
+--- contrib/python/s3transfer/py3/tests/unit/test_crt.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_crt.py (working tree)
+@@ -15,7 +15,7 @@ from botocore.session import Session
+
+ from s3transfer.exceptions import TransferNotDoneError
+ from s3transfer.utils import CallArgs
+-from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest
++from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+ if HAS_CRT:
+ import awscrt.s3
+--- contrib/python/s3transfer/py3/tests/unit/test_delete.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_delete.py (working tree)
+@@ -11,7 +11,7 @@
+ # ANY KIND, either express or implied. See the License for the specific
+ # language governing permissions and limitations under the License.
+ from s3transfer.delete import DeleteObjectTask
+-from tests import BaseTaskTest
++from __tests__ import BaseTaskTest
+
+
+ class TestDeleteObjectTask(BaseTaskTest):
+--- contrib/python/s3transfer/py3/tests/unit/test_download.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_download.py (working tree)
+@@ -37,7 +37,7 @@ from s3transfer.download import (
+ from s3transfer.exceptions import RetriesExceededError
+ from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor
+ from s3transfer.utils import CallArgs, OSUtils
+-from tests import (
++from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileCreator,
+--- contrib/python/s3transfer/py3/tests/unit/test_futures.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_futures.py (working tree)
+@@ -37,7 +37,7 @@ from s3transfer.utils import (
+ NoResourcesAvailable,
+ TaskSemaphore,
+ )
+-from tests import (
++from __tests__ import (
+ RecordingExecutor,
+ TransferCoordinatorWithInterrupt,
+ mock,
+--- contrib/python/s3transfer/py3/tests/unit/test_manager.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_manager.py (working tree)
+@@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor
+ from s3transfer.exceptions import CancelledError, FatalError
+ from s3transfer.futures import TransferCoordinator
+ from s3transfer.manager import TransferConfig, TransferCoordinatorController
+-from tests import TransferCoordinatorWithInterrupt, unittest
++from __tests__ import TransferCoordinatorWithInterrupt, unittest
+
+
+ class FutureResultException(Exception):
+--- contrib/python/s3transfer/py3/tests/unit/test_processpool.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_processpool.py (working tree)
+@@ -39,7 +39,7 @@ from s3transfer.processpool import (
+ ignore_ctrl_c,
+ )
+ from s3transfer.utils import CallArgs, OSUtils
+-from tests import (
++from __tests__ import (
+ FileCreator,
+ StreamWithError,
+ StubbedClientTest,
+--- contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py (working tree)
+@@ -33,7 +33,7 @@ from s3transfer import (
+ random_file_extension,
+ )
+ from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
+-from tests import mock, unittest
++from __tests__ import mock, unittest
+
+
+ class InMemoryOSLayer(OSUtils):
+--- contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_subscribers.py (working tree)
+@@ -12,7 +12,7 @@
+ # language governing permissions and limitations under the License.
+ from s3transfer.exceptions import InvalidSubscriberMethodError
+ from s3transfer.subscribers import BaseSubscriber
+-from tests import unittest
++from __tests__ import unittest
+
+
+ class ExtraMethodsSubscriber(BaseSubscriber):
+--- contrib/python/s3transfer/py3/tests/unit/test_tasks.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_tasks.py (working tree)
+@@ -23,7 +23,7 @@ from s3transfer.tasks import (
+ Task,
+ )
+ from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks
+-from tests import (
++from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ RecordingSubscriber,
+--- contrib/python/s3transfer/py3/tests/unit/test_upload.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_upload.py (working tree)
+@@ -32,7 +32,7 @@ from s3transfer.upload import (
+ UploadSubmissionTask,
+ )
+ from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils
+-from tests import (
++from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileSizeProvider,
+--- contrib/python/s3transfer/py3/tests/unit/test_utils.py (index)
++++ contrib/python/s3transfer/py3/tests/unit/test_utils.py (working tree)
+@@ -43,7 +43,7 @@ from s3transfer.utils import (
+ invoke_progress_callbacks,
+ random_file_extension,
+ )
+-from tests import NonSeekableWriter, RecordingSubscriber, mock, unittest
++from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest
+
+
+ class TestGetCallbacks(unittest.TestCase):
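The patch above uniformly rewrites test imports from the upstream package name `tests` to the in-tree name `__tests__`. A hedged sketch of how such a mechanical rename could be reproduced over the test tree is shown below; the script is illustrative and not part of the repository, and the target directory is an assumption.

```python
# Sketch: rewrite "from tests import ..." style imports to "from __tests__ import ...",
# mirroring the pattern applied by 01-fix-tests.patch.
import pathlib
import re

PATTERN = re.compile(r"^from tests(\.|\s)", flags=re.MULTILINE)

def rewrite_imports(root: str) -> None:
    for path in pathlib.Path(root).rglob("*.py"):
        text = path.read_text()
        new_text = PATTERN.sub(lambda m: "from __tests__" + m.group(1), text)
        if new_text != text:
            path.write_text(new_text)

if __name__ == "__main__":
    # Placeholder path; adjust to the checkout being patched.
    rewrite_imports("contrib/python/s3transfer/py3/tests")
```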
diff --git a/contrib/python/s3transfer/py3/s3transfer/__init__.py b/contrib/python/s3transfer/py3/s3transfer/__init__.py
index cb2591e1cd..1a749c712e 100644
--- a/contrib/python/s3transfer/py3/s3transfer/__init__.py
+++ b/contrib/python/s3transfer/py3/s3transfer/__init__.py
@@ -72,7 +72,7 @@ client operation. Here are a few examples using ``upload_file``::
extra_args={'ContentType': "application/json"})
-The ``S3Transfer`` class also supports progress callbacks so you can
+The ``S3Transfer`` class also supports progress callbacks so you can
provide transfer progress to users. Both the ``upload_file`` and
``download_file`` methods take an optional ``callback`` parameter.
Here's an example of how to print a simple progress percentage
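The docstring fragment above mentions an example of printing a progress percentage, but the example itself falls outside this hunk's context. A minimal sketch of such a callback is given below; the ProgressPercentage name and the usage comment are assumptions consistent with the docstring, not text from this diff.

```python
# Sketch: a progress callback suitable for S3Transfer's optional `callback` parameter.
import os
import sys
import threading

class ProgressPercentage:
    def __init__(self, filename):
        self._filename = filename
        self._size = float(os.path.getsize(filename))
        self._seen_so_far = 0
        self._lock = threading.Lock()  # callbacks may arrive from several threads

    def __call__(self, bytes_amount):
        with self._lock:
            self._seen_so_far += bytes_amount
            percentage = (self._seen_so_far / self._size) * 100
            sys.stdout.write(
                "\r%s  %s / %s  (%.2f%%)"
                % (self._filename, self._seen_so_far, self._size, percentage)
            )
            sys.stdout.flush()

# Usage (assuming an S3Transfer instance named `transfer`):
# transfer.upload_file('/tmp/myfile', 'bucket', 'key',
#                      callback=ProgressPercentage('/tmp/myfile'))
```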
@@ -123,28 +123,28 @@ transfer. For example:
"""
-import concurrent.futures
+import concurrent.futures
import functools
import logging
-import math
-import os
-import queue
-import random
+import math
+import os
+import queue
+import random
import socket
-import string
+import string
import threading
-from botocore.compat import six # noqa: F401
+from botocore.compat import six # noqa: F401
from botocore.exceptions import IncompleteReadError
-from botocore.vendored.requests.packages.urllib3.exceptions import (
- ReadTimeoutError,
-)
+from botocore.vendored.requests.packages.urllib3.exceptions import (
+ ReadTimeoutError,
+)
import s3transfer.compat
from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
__author__ = 'Amazon Web Services'
-__version__ = '0.5.1'
+__version__ = '0.5.1'
class NullHandler(logging.Handler):
@@ -164,16 +164,16 @@ def random_file_extension(num_digits=8):
def disable_upload_callbacks(request, operation_name, **kwargs):
- if operation_name in ['PutObject', 'UploadPart'] and hasattr(
- request.body, 'disable_callback'
- ):
+ if operation_name in ['PutObject', 'UploadPart'] and hasattr(
+ request.body, 'disable_callback'
+ ):
request.body.disable_callback()
def enable_upload_callbacks(request, operation_name, **kwargs):
- if operation_name in ['PutObject', 'UploadPart'] and hasattr(
- request.body, 'enable_callback'
- ):
+ if operation_name in ['PutObject', 'UploadPart'] and hasattr(
+ request.body, 'enable_callback'
+ ):
request.body.enable_callback()
@@ -181,16 +181,16 @@ class QueueShutdownError(Exception):
pass
-class ReadFileChunk:
- def __init__(
- self,
- fileobj,
- start_byte,
- chunk_size,
- full_file_size,
- callback=None,
- enable_callback=True,
- ):
+class ReadFileChunk:
+ def __init__(
+ self,
+ fileobj,
+ start_byte,
+ chunk_size,
+ full_file_size,
+ callback=None,
+ enable_callback=True,
+ ):
"""
Given a file object shown below:
@@ -222,25 +222,25 @@ class ReadFileChunk:
self._fileobj = fileobj
self._start_byte = start_byte
self._size = self._calculate_file_size(
- self._fileobj,
- requested_size=chunk_size,
- start_byte=start_byte,
- actual_file_size=full_file_size,
- )
+ self._fileobj,
+ requested_size=chunk_size,
+ start_byte=start_byte,
+ actual_file_size=full_file_size,
+ )
self._fileobj.seek(self._start_byte)
self._amount_read = 0
self._callback = callback
self._callback_enabled = enable_callback
@classmethod
- def from_filename(
- cls,
- filename,
- start_byte,
- chunk_size,
- callback=None,
- enable_callback=True,
- ):
+ def from_filename(
+ cls,
+ filename,
+ start_byte,
+ chunk_size,
+ callback=None,
+ enable_callback=True,
+ ):
"""Convenience factory function to create from a filename.
:type start_byte: int
@@ -268,13 +268,13 @@ class ReadFileChunk:
"""
f = open(filename, 'rb')
file_size = os.fstat(f.fileno()).st_size
- return cls(
- f, start_byte, chunk_size, file_size, callback, enable_callback
- )
+ return cls(
+ f, start_byte, chunk_size, file_size, callback, enable_callback
+ )
- def _calculate_file_size(
- self, fileobj, requested_size, start_byte, actual_file_size
- ):
+ def _calculate_file_size(
+ self, fileobj, requested_size, start_byte, actual_file_size
+ ):
max_chunk_size = actual_file_size - start_byte
return min(max_chunk_size, requested_size)
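The hunks above cover ReadFileChunk's constructor and its from_filename factory. A brief usage sketch, consistent with the signature shown in the diff, follows; the file path and chunk size are placeholders.

```python
# Sketch: reading one 1 MiB slice of a file through ReadFileChunk.from_filename.
from s3transfer import ReadFileChunk

chunk = ReadFileChunk.from_filename(
    "large-input.bin",       # placeholder path
    start_byte=0,
    chunk_size=1024 * 1024,
    callback=None,           # a progress callable could be passed here
)
data = chunk.read()          # reads at most chunk_size bytes starting at start_byte
chunk.close()
```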
@@ -331,9 +331,9 @@ class ReadFileChunk:
return iter([])
-class StreamReaderProgress:
+class StreamReaderProgress:
"""Wrapper for a read only stream that adds progress callbacks."""
-
+
def __init__(self, stream, callback=None):
self._stream = stream
self._callback = callback
@@ -345,14 +345,14 @@ class StreamReaderProgress:
return value
-class OSUtils:
+class OSUtils:
def get_file_size(self, filename):
return os.path.getsize(filename)
def open_file_chunk_reader(self, filename, start_byte, size, callback):
- return ReadFileChunk.from_filename(
- filename, start_byte, size, callback, enable_callback=False
- )
+ return ReadFileChunk.from_filename(
+ filename, start_byte, size, callback, enable_callback=False
+ )
def open(self, filename, mode):
return open(filename, mode)
@@ -370,7 +370,7 @@ class OSUtils:
s3transfer.compat.rename_file(current_filename, new_filename)
-class MultipartUploader:
+class MultipartUploader:
# These are the extra_args that need to be forwarded onto
# subsequent upload_parts.
UPLOAD_PART_ARGS = [
@@ -380,13 +380,13 @@ class MultipartUploader:
'RequestPayer',
]
- def __init__(
- self,
- client,
- config,
- osutil,
- executor_cls=concurrent.futures.ThreadPoolExecutor,
- ):
+ def __init__(
+ self,
+ client,
+ config,
+ osutil,
+ executor_cls=concurrent.futures.ThreadPoolExecutor,
+ ):
self._client = client
self._config = config
self._os = osutil
@@ -402,83 +402,83 @@ class MultipartUploader:
return upload_parts_args
def upload_file(self, filename, bucket, key, callback, extra_args):
- response = self._client.create_multipart_upload(
- Bucket=bucket, Key=key, **extra_args
- )
+ response = self._client.create_multipart_upload(
+ Bucket=bucket, Key=key, **extra_args
+ )
upload_id = response['UploadId']
try:
- parts = self._upload_parts(
- upload_id, filename, bucket, key, callback, extra_args
- )
+ parts = self._upload_parts(
+ upload_id, filename, bucket, key, callback, extra_args
+ )
except Exception as e:
- logger.debug(
- "Exception raised while uploading parts, "
- "aborting multipart upload.",
- exc_info=True,
- )
+ logger.debug(
+ "Exception raised while uploading parts, "
+ "aborting multipart upload.",
+ exc_info=True,
+ )
self._client.abort_multipart_upload(
- Bucket=bucket, Key=key, UploadId=upload_id
- )
+ Bucket=bucket, Key=key, UploadId=upload_id
+ )
raise S3UploadFailedError(
- "Failed to upload {} to {}: {}".format(
- filename, '/'.join([bucket, key]), e
- )
- )
+ "Failed to upload {} to {}: {}".format(
+ filename, '/'.join([bucket, key]), e
+ )
+ )
self._client.complete_multipart_upload(
- Bucket=bucket,
- Key=key,
- UploadId=upload_id,
- MultipartUpload={'Parts': parts},
- )
-
- def _upload_parts(
- self, upload_id, filename, bucket, key, callback, extra_args
- ):
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ MultipartUpload={'Parts': parts},
+ )
+
+ def _upload_parts(
+ self, upload_id, filename, bucket, key, callback, extra_args
+ ):
upload_parts_extra_args = self._extra_upload_part_args(extra_args)
parts = []
part_size = self._config.multipart_chunksize
num_parts = int(
- math.ceil(self._os.get_file_size(filename) / float(part_size))
- )
+ math.ceil(self._os.get_file_size(filename) / float(part_size))
+ )
max_workers = self._config.max_concurrency
with self._executor_cls(max_workers=max_workers) as executor:
upload_partial = functools.partial(
- self._upload_one_part,
- filename,
- bucket,
- key,
- upload_id,
- part_size,
- upload_parts_extra_args,
- callback,
- )
+ self._upload_one_part,
+ filename,
+ bucket,
+ key,
+ upload_id,
+ part_size,
+ upload_parts_extra_args,
+ callback,
+ )
for part in executor.map(upload_partial, range(1, num_parts + 1)):
parts.append(part)
return parts
- def _upload_one_part(
- self,
- filename,
- bucket,
- key,
- upload_id,
- part_size,
- extra_args,
- callback,
- part_number,
- ):
+ def _upload_one_part(
+ self,
+ filename,
+ bucket,
+ key,
+ upload_id,
+ part_size,
+ extra_args,
+ callback,
+ part_number,
+ ):
open_chunk_reader = self._os.open_file_chunk_reader
- with open_chunk_reader(
- filename, part_size * (part_number - 1), part_size, callback
- ) as body:
+ with open_chunk_reader(
+ filename, part_size * (part_number - 1), part_size, callback
+ ) as body:
response = self._client.upload_part(
- Bucket=bucket,
- Key=key,
- UploadId=upload_id,
- PartNumber=part_number,
- Body=body,
- **extra_args,
- )
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ PartNumber=part_number,
+ Body=body,
+ **extra_args,
+ )
etag = response['ETag']
return {'ETag': etag, 'PartNumber': part_number}
@@ -494,7 +494,7 @@ class ShutdownQueue(queue.Queue):
to be a drop in replacement for ``queue.Queue``.
"""
-
+
def _init(self, maxsize):
self._shutdown = False
self._shutdown_lock = threading.Lock()
@@ -511,50 +511,50 @@ class ShutdownQueue(queue.Queue):
# Need to hook into the condition vars used by this class.
with self._shutdown_lock:
if self._shutdown:
- raise QueueShutdownError(
- "Cannot put item to queue when " "queue has been shutdown."
- )
+ raise QueueShutdownError(
+ "Cannot put item to queue when " "queue has been shutdown."
+ )
return queue.Queue.put(self, item)
-class MultipartDownloader:
- def __init__(
- self,
- client,
- config,
- osutil,
- executor_cls=concurrent.futures.ThreadPoolExecutor,
- ):
+class MultipartDownloader:
+ def __init__(
+ self,
+ client,
+ config,
+ osutil,
+ executor_cls=concurrent.futures.ThreadPoolExecutor,
+ ):
self._client = client
self._config = config
self._os = osutil
self._executor_cls = executor_cls
self._ioqueue = ShutdownQueue(self._config.max_io_queue)
- def download_file(
- self, bucket, key, filename, object_size, extra_args, callback=None
- ):
+ def download_file(
+ self, bucket, key, filename, object_size, extra_args, callback=None
+ ):
with self._executor_cls(max_workers=2) as controller:
# 1 thread for the future that manages the uploading of files
# 1 thread for the future that manages IO writes.
download_parts_handler = functools.partial(
self._download_file_as_future,
- bucket,
- key,
- filename,
- object_size,
- callback,
- )
+ bucket,
+ key,
+ filename,
+ object_size,
+ callback,
+ )
parts_future = controller.submit(download_parts_handler)
io_writes_handler = functools.partial(
- self._perform_io_writes, filename
- )
+ self._perform_io_writes, filename
+ )
io_future = controller.submit(io_writes_handler)
results = concurrent.futures.wait(
[parts_future, io_future],
- return_when=concurrent.futures.FIRST_EXCEPTION,
- )
+ return_when=concurrent.futures.FIRST_EXCEPTION,
+ )
self._process_future_results(results)
def _process_future_results(self, futures):
@@ -562,21 +562,21 @@ class MultipartDownloader:
for future in finished:
future.result()
- def _download_file_as_future(
- self, bucket, key, filename, object_size, callback
- ):
+ def _download_file_as_future(
+ self, bucket, key, filename, object_size, callback
+ ):
part_size = self._config.multipart_chunksize
num_parts = int(math.ceil(object_size / float(part_size)))
max_workers = self._config.max_concurrency
download_partial = functools.partial(
- self._download_range,
- bucket,
- key,
- filename,
- part_size,
- num_parts,
- callback,
- )
+ self._download_range,
+ bucket,
+ key,
+ filename,
+ part_size,
+ num_parts,
+ callback,
+ )
try:
with self._executor_cls(max_workers=max_workers) as executor:
list(executor.map(download_partial, range(num_parts)))
@@ -589,16 +589,16 @@ class MultipartDownloader:
end_range = ''
else:
end_range = start_range + part_size - 1
- range_param = f'bytes={start_range}-{end_range}'
+ range_param = f'bytes={start_range}-{end_range}'
return range_param
- def _download_range(
- self, bucket, key, filename, part_size, num_parts, callback, part_index
- ):
+ def _download_range(
+ self, bucket, key, filename, part_size, num_parts, callback, part_index
+ ):
try:
range_param = self._calculate_range_param(
- part_size, part_index, num_parts
- )
+ part_size, part_index, num_parts
+ )
max_attempts = self._config.num_download_attempts
last_exception = None
@@ -606,33 +606,33 @@ class MultipartDownloader:
try:
logger.debug("Making get_object call.")
response = self._client.get_object(
- Bucket=bucket, Key=key, Range=range_param
- )
+ Bucket=bucket, Key=key, Range=range_param
+ )
streaming_body = StreamReaderProgress(
- response['Body'], callback
- )
+ response['Body'], callback
+ )
buffer_size = 1024 * 16
current_index = part_size * part_index
- for chunk in iter(
- lambda: streaming_body.read(buffer_size), b''
- ):
+ for chunk in iter(
+ lambda: streaming_body.read(buffer_size), b''
+ ):
self._ioqueue.put((current_index, chunk))
current_index += len(chunk)
return
- except (
- socket.timeout,
- OSError,
- ReadTimeoutError,
- IncompleteReadError,
- ) as e:
- logger.debug(
- "Retrying exception caught (%s), "
- "retrying request, (attempt %s / %s)",
- e,
- i,
- max_attempts,
- exc_info=True,
- )
+ except (
+ socket.timeout,
+ OSError,
+ ReadTimeoutError,
+ IncompleteReadError,
+ ) as e:
+ logger.debug(
+ "Retrying exception caught (%s), "
+ "retrying request, (attempt %s / %s)",
+ e,
+ i,
+ max_attempts,
+ exc_info=True,
+ )
last_exception = e
continue
raise RetriesExceededError(last_exception)
@@ -644,10 +644,10 @@ class MultipartDownloader:
while True:
task = self._ioqueue.get()
if task is SHUTDOWN_SENTINEL:
- logger.debug(
- "Shutdown sentinel received in IO handler, "
- "shutting down IO handler."
- )
+ logger.debug(
+ "Shutdown sentinel received in IO handler, "
+ "shutting down IO handler."
+ )
return
else:
try:
@@ -655,24 +655,24 @@ class MultipartDownloader:
f.seek(offset)
f.write(data)
except Exception as e:
- logger.debug(
- "Caught exception in IO thread: %s",
- e,
- exc_info=True,
- )
+ logger.debug(
+ "Caught exception in IO thread: %s",
+ e,
+ exc_info=True,
+ )
self._ioqueue.trigger_shutdown()
raise
-class TransferConfig:
- def __init__(
- self,
- multipart_threshold=8 * MB,
- max_concurrency=10,
- multipart_chunksize=8 * MB,
- num_download_attempts=5,
- max_io_queue=100,
- ):
+class TransferConfig:
+ def __init__(
+ self,
+ multipart_threshold=8 * MB,
+ max_concurrency=10,
+ multipart_chunksize=8 * MB,
+ num_download_attempts=5,
+ max_io_queue=100,
+ ):
self.multipart_threshold = multipart_threshold
self.max_concurrency = max_concurrency
self.multipart_chunksize = multipart_chunksize
@@ -680,7 +680,7 @@ class TransferConfig:
self.max_io_queue = max_io_queue
-class S3Transfer:
+class S3Transfer:
ALLOWED_DOWNLOAD_ARGS = [
'VersionId',
@@ -710,8 +710,8 @@ class S3Transfer:
'SSECustomerKey',
'SSECustomerKeyMD5',
'SSEKMSKeyId',
- 'SSEKMSEncryptionContext',
- 'Tagging',
+ 'SSEKMSEncryptionContext',
+ 'Tagging',
]
def __init__(self, client, config=None, osutil=None):
@@ -723,9 +723,9 @@ class S3Transfer:
osutil = OSUtils()
self._osutil = osutil
- def upload_file(
- self, filename, bucket, key, callback=None, extra_args=None
- ):
+ def upload_file(
+ self, filename, bucket, key, callback=None, extra_args=None
+ ):
"""Upload a file to an S3 object.
Variants have also been injected into S3 client, Bucket and Object.
@@ -735,20 +735,20 @@ class S3Transfer:
extra_args = {}
self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS)
events = self._client.meta.events
- events.register_first(
- 'request-created.s3',
- disable_upload_callbacks,
- unique_id='s3upload-callback-disable',
- )
- events.register_last(
- 'request-created.s3',
- enable_upload_callbacks,
- unique_id='s3upload-callback-enable',
- )
- if (
- self._osutil.get_file_size(filename)
- >= self._config.multipart_threshold
- ):
+ events.register_first(
+ 'request-created.s3',
+ disable_upload_callbacks,
+ unique_id='s3upload-callback-disable',
+ )
+ events.register_last(
+ 'request-created.s3',
+ enable_upload_callbacks,
+ unique_id='s3upload-callback-enable',
+ )
+ if (
+ self._osutil.get_file_size(filename)
+ >= self._config.multipart_threshold
+ ):
self._multipart_upload(filename, bucket, key, callback, extra_args)
else:
self._put_object(filename, bucket, key, callback, extra_args)
@@ -757,19 +757,19 @@ class S3Transfer:
# We're using open_file_chunk_reader so we can take advantage of the
# progress callback functionality.
open_chunk_reader = self._osutil.open_file_chunk_reader
- with open_chunk_reader(
- filename,
- 0,
- self._osutil.get_file_size(filename),
- callback=callback,
- ) as body:
- self._client.put_object(
- Bucket=bucket, Key=key, Body=body, **extra_args
- )
-
- def download_file(
- self, bucket, key, filename, extra_args=None, callback=None
- ):
+ with open_chunk_reader(
+ filename,
+ 0,
+ self._osutil.get_file_size(filename),
+ callback=callback,
+ ) as body:
+ self._client.put_object(
+ Bucket=bucket, Key=key, Body=body, **extra_args
+ )
+
+ def download_file(
+ self, bucket, key, filename, extra_args=None, callback=None
+ ):
"""Download an S3 object to a file.
Variants have also been injected into S3 client, Bucket and Object.
@@ -784,28 +784,28 @@ class S3Transfer:
object_size = self._object_size(bucket, key, extra_args)
temp_filename = filename + os.extsep + random_file_extension()
try:
- self._download_file(
- bucket, key, temp_filename, object_size, extra_args, callback
- )
+ self._download_file(
+ bucket, key, temp_filename, object_size, extra_args, callback
+ )
except Exception:
- logger.debug(
- "Exception caught in download_file, removing partial "
- "file: %s",
- temp_filename,
- exc_info=True,
- )
+ logger.debug(
+ "Exception caught in download_file, removing partial "
+ "file: %s",
+ temp_filename,
+ exc_info=True,
+ )
self._osutil.remove_file(temp_filename)
raise
else:
self._osutil.rename_file(temp_filename, filename)
- def _download_file(
- self, bucket, key, filename, object_size, extra_args, callback
- ):
+ def _download_file(
+ self, bucket, key, filename, object_size, extra_args, callback
+ ):
if object_size >= self._config.multipart_threshold:
- self._ranged_download(
- bucket, key, filename, object_size, extra_args, callback
- )
+ self._ranged_download(
+ bucket, key, filename, object_size, extra_args, callback
+ )
else:
self._get_object(bucket, key, filename, extra_args, callback)
@@ -814,18 +814,18 @@ class S3Transfer:
if kwarg not in allowed:
raise ValueError(
"Invalid extra_args key '%s', "
- "must be one of: %s" % (kwarg, ', '.join(allowed))
- )
-
- def _ranged_download(
- self, bucket, key, filename, object_size, extra_args, callback
- ):
- downloader = MultipartDownloader(
- self._client, self._config, self._osutil
- )
- downloader.download_file(
- bucket, key, filename, object_size, extra_args, callback
- )
+ "must be one of: %s" % (kwarg, ', '.join(allowed))
+ )
+
+ def _ranged_download(
+ self, bucket, key, filename, object_size, extra_args, callback
+ ):
+ downloader = MultipartDownloader(
+ self._client, self._config, self._osutil
+ )
+ downloader.download_file(
+ bucket, key, filename, object_size, extra_args, callback
+ )
def _get_object(self, bucket, key, filename, extra_args, callback):
# precondition: num_download_attempts > 0
@@ -833,42 +833,42 @@ class S3Transfer:
last_exception = None
for i in range(max_attempts):
try:
- return self._do_get_object(
- bucket, key, filename, extra_args, callback
- )
- except (
- socket.timeout,
- OSError,
- ReadTimeoutError,
- IncompleteReadError,
- ) as e:
+ return self._do_get_object(
+ bucket, key, filename, extra_args, callback
+ )
+ except (
+ socket.timeout,
+ OSError,
+ ReadTimeoutError,
+ IncompleteReadError,
+ ) as e:
# TODO: we need a way to reset the callback if the
# download failed.
- logger.debug(
- "Retrying exception caught (%s), "
- "retrying request, (attempt %s / %s)",
- e,
- i,
- max_attempts,
- exc_info=True,
- )
+ logger.debug(
+ "Retrying exception caught (%s), "
+ "retrying request, (attempt %s / %s)",
+ e,
+ i,
+ max_attempts,
+ exc_info=True,
+ )
last_exception = e
continue
raise RetriesExceededError(last_exception)
def _do_get_object(self, bucket, key, filename, extra_args, callback):
- response = self._client.get_object(
- Bucket=bucket, Key=key, **extra_args
- )
- streaming_body = StreamReaderProgress(response['Body'], callback)
+ response = self._client.get_object(
+ Bucket=bucket, Key=key, **extra_args
+ )
+ streaming_body = StreamReaderProgress(response['Body'], callback)
with self._osutil.open(filename, 'wb') as f:
for chunk in iter(lambda: streaming_body.read(8192), b''):
f.write(chunk)
def _object_size(self, bucket, key, extra_args):
- return self._client.head_object(Bucket=bucket, Key=key, **extra_args)[
- 'ContentLength'
- ]
+ return self._client.head_object(Bucket=bucket, Key=key, **extra_args)[
+ 'ContentLength'
+ ]
def _multipart_upload(self, filename, bucket, key, callback, extra_args):
uploader = MultipartUploader(self._client, self._config, self._osutil)
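The module restored above still exposes the legacy entry points (TransferConfig, S3Transfer.upload_file / download_file) with per-chunk progress callbacks. As a rough usage sketch under assumed names (the bucket, key, and file path below are placeholders, and boto3 is assumed to be available; this is not taken from the diff itself):

import os

import boto3

from s3transfer import S3Transfer, TransferConfig


def progress_printer(total_bytes):
    seen = {'bytes': 0}

    def callback(amount):
        # The legacy callbacks receive the bytes transferred since the
        # previous invocation, not a running total.
        seen['bytes'] += amount
        print(f"\r{seen['bytes'] / total_bytes:7.2%}", end='', flush=True)

    return callback


client = boto3.client('s3')
config = TransferConfig(
    multipart_threshold=8 * 1024 * 1024,  # same 8 MB default as in the diff
    max_concurrency=10,
)
transfer = S3Transfer(client, config)

filename = 'local-file.bin'              # placeholder path
transfer.upload_file(
    filename,
    'example-bucket',                    # placeholder bucket
    'example-key',                       # placeholder key
    callback=progress_printer(os.path.getsize(filename)),
)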
diff --git a/contrib/python/s3transfer/py3/s3transfer/bandwidth.py b/contrib/python/s3transfer/py3/s3transfer/bandwidth.py
index 957049dffc..9bac5885e1 100644
--- a/contrib/python/s3transfer/py3/s3transfer/bandwidth.py
+++ b/contrib/python/s3transfer/py3/s3transfer/bandwidth.py
@@ -10,7 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
-import threading
+import threading
import time
@@ -30,19 +30,19 @@ class RequestExceededException(Exception):
"""
self.requested_amt = requested_amt
self.retry_time = retry_time
- msg = 'Request amount {} exceeded the amount available. Retry in {}'.format(
- requested_amt, retry_time
+ msg = 'Request amount {} exceeded the amount available. Retry in {}'.format(
+ requested_amt, retry_time
)
- super().__init__(msg)
+ super().__init__(msg)
-class RequestToken:
+class RequestToken:
"""A token to pass as an identifier when consuming from the LeakyBucket"""
-
+
pass
-class TimeUtils:
+class TimeUtils:
def time(self):
"""Get the current time back
@@ -60,7 +60,7 @@ class TimeUtils:
return time.sleep(value)
-class BandwidthLimiter:
+class BandwidthLimiter:
def __init__(self, leaky_bucket, time_utils=None):
"""Limits bandwidth for shared S3 transfers
@@ -75,9 +75,9 @@ class BandwidthLimiter:
if time_utils is None:
self._time_utils = TimeUtils()
- def get_bandwith_limited_stream(
- self, fileobj, transfer_coordinator, enabled=True
- ):
+ def get_bandwith_limited_stream(
+ self, fileobj, transfer_coordinator, enabled=True
+ ):
"""Wraps a fileobj in a bandwidth limited stream wrapper
:type fileobj: file-like obj
@@ -91,22 +91,22 @@ class BandwidthLimiter:
:param enabled: Whether bandwidth limiting should be enabled to start
"""
stream = BandwidthLimitedStream(
- fileobj, self._leaky_bucket, transfer_coordinator, self._time_utils
- )
+ fileobj, self._leaky_bucket, transfer_coordinator, self._time_utils
+ )
if not enabled:
stream.disable_bandwidth_limiting()
return stream
-class BandwidthLimitedStream:
- def __init__(
- self,
- fileobj,
- leaky_bucket,
- transfer_coordinator,
- time_utils=None,
- bytes_threshold=256 * 1024,
- ):
+class BandwidthLimitedStream:
+ def __init__(
+ self,
+ fileobj,
+ leaky_bucket,
+ transfer_coordinator,
+ time_utils=None,
+ bytes_threshold=256 * 1024,
+ ):
"""Limits bandwidth for reads on a wrapped stream
:type fileobj: file-like object
@@ -163,7 +163,7 @@ class BandwidthLimitedStream:
return self._fileobj.read(amount)
def _consume_through_leaky_bucket(self):
- # NOTE: If the read amount on the stream are high, it will result
+ # NOTE: If the read amount on the stream are high, it will result
# in large bursty behavior as there is not an interface for partial
# reads. However given the read's on this abstraction are at most 256KB
# (via downloads), it reduces the burstiness to be small KB bursts at
@@ -171,8 +171,8 @@ class BandwidthLimitedStream:
while not self._transfer_coordinator.exception:
try:
self._leaky_bucket.consume(
- self._bytes_seen, self._request_token
- )
+ self._bytes_seen, self._request_token
+ )
self._bytes_seen = 0
return
except RequestExceededException as e:
@@ -188,8 +188,8 @@ class BandwidthLimitedStream:
"""Signal that data being read is not being transferred to S3"""
self.disable_bandwidth_limiting()
- def seek(self, where, whence=0):
- self._fileobj.seek(where, whence)
+ def seek(self, where, whence=0):
+ self._fileobj.seek(where, whence)
def tell(self):
return self._fileobj.tell()
@@ -211,14 +211,14 @@ class BandwidthLimitedStream:
self.close()
-class LeakyBucket:
- def __init__(
- self,
- max_rate,
- time_utils=None,
- rate_tracker=None,
- consumption_scheduler=None,
- ):
+class LeakyBucket:
+ def __init__(
+ self,
+ max_rate,
+ time_utils=None,
+ rate_tracker=None,
+ consumption_scheduler=None,
+ ):
"""A leaky bucket abstraction to limit bandwidth consumption
:type rate: int
@@ -269,12 +269,12 @@ class LeakyBucket:
time_now = self._time_utils.time()
if self._consumption_scheduler.is_scheduled(request_token):
return self._release_requested_amt_for_scheduled_request(
- amt, request_token, time_now
- )
+ amt, request_token, time_now
+ )
elif self._projected_to_exceed_max_rate(amt, time_now):
self._raise_request_exceeded_exception(
- amt, request_token, time_now
- )
+ amt, request_token, time_now
+ )
else:
return self._release_requested_amt(amt, time_now)
@@ -282,29 +282,29 @@ class LeakyBucket:
projected_rate = self._rate_tracker.get_projected_rate(amt, time_now)
return projected_rate > self._max_rate
- def _release_requested_amt_for_scheduled_request(
- self, amt, request_token, time_now
- ):
+ def _release_requested_amt_for_scheduled_request(
+ self, amt, request_token, time_now
+ ):
self._consumption_scheduler.process_scheduled_consumption(
- request_token
- )
+ request_token
+ )
return self._release_requested_amt(amt, time_now)
def _raise_request_exceeded_exception(self, amt, request_token, time_now):
- allocated_time = amt / float(self._max_rate)
+ allocated_time = amt / float(self._max_rate)
retry_time = self._consumption_scheduler.schedule_consumption(
- amt, request_token, allocated_time
- )
+ amt, request_token, allocated_time
+ )
raise RequestExceededException(
- requested_amt=amt, retry_time=retry_time
- )
+ requested_amt=amt, retry_time=retry_time
+ )
def _release_requested_amt(self, amt, time_now):
self._rate_tracker.record_consumption_rate(amt, time_now)
return amt
-class ConsumptionScheduler:
+class ConsumptionScheduler:
def __init__(self):
"""Schedules when to consume a desired amount"""
self._tokens_to_scheduled_consumption = {}
@@ -354,11 +354,11 @@ class ConsumptionScheduler:
"""
scheduled_retry = self._tokens_to_scheduled_consumption.pop(token)
self._total_wait = max(
- self._total_wait - scheduled_retry['time_to_consume'], 0
- )
+ self._total_wait - scheduled_retry['time_to_consume'], 0
+ )
-class BandwidthRateTracker:
+class BandwidthRateTracker:
def __init__(self, alpha=0.8):
"""Tracks the rate of bandwidth consumption
@@ -401,8 +401,8 @@ class BandwidthRateTracker:
if self._last_time is None:
return 0.0
return self._calculate_exponential_moving_average_rate(
- amt, time_at_consumption
- )
+ amt, time_at_consumption
+ )
def record_consumption_rate(self, amt, time_at_consumption):
"""Record the consumption rate based off amount and time point
@@ -418,22 +418,22 @@ class BandwidthRateTracker:
self._current_rate = 0.0
return
self._current_rate = self._calculate_exponential_moving_average_rate(
- amt, time_at_consumption
- )
+ amt, time_at_consumption
+ )
self._last_time = time_at_consumption
def _calculate_rate(self, amt, time_at_consumption):
time_delta = time_at_consumption - self._last_time
if time_delta <= 0:
- # While it is really unlikely to see this in an actual transfer,
+ # While it is really unlikely to see this in an actual transfer,
# we do not want to be returning back a negative rate or try to
# divide the amount by zero. So instead return back an infinite
# rate as the time delta is infinitesimally small.
return float('inf')
return amt / (time_delta)
- def _calculate_exponential_moving_average_rate(
- self, amt, time_at_consumption
- ):
+ def _calculate_exponential_moving_average_rate(
+ self, amt, time_at_consumption
+ ):
new_rate = self._calculate_rate(amt, time_at_consumption)
return self._alpha * new_rate + (1 - self._alpha) * self._current_rate
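To make the rate math above concrete, here is a small self-contained sketch of the exponential moving average that BandwidthRateTracker applies per consumption sample (the helper name and the sample data are illustrative, not part of the library):

def ema_rate(samples, alpha=0.8):
    """samples: ordered (amount_in_bytes, timestamp_in_seconds) pairs."""
    current_rate = 0.0
    last_time = None
    for amt, ts in samples:
        if last_time is None:
            # The first sample only establishes the reference time,
            # mirroring record_consumption_rate above.
            last_time = ts
            continue
        delta = ts - last_time
        new_rate = float('inf') if delta <= 0 else amt / delta
        current_rate = alpha * new_rate + (1 - alpha) * current_rate
        last_time = ts
    return current_rate


# 1 MiB observed every 0.5 s converges toward roughly 2 MiB/s.
print(ema_rate([(1024 * 1024, 0.5 * i) for i in range(10)]))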
diff --git a/contrib/python/s3transfer/py3/s3transfer/compat.py b/contrib/python/s3transfer/py3/s3transfer/compat.py
index 5746025983..68267ad0e2 100644
--- a/contrib/python/s3transfer/py3/s3transfer/compat.py
+++ b/contrib/python/s3transfer/py3/s3transfer/compat.py
@@ -10,11 +10,11 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
-import errno
+import errno
import inspect
import os
import socket
-import sys
+import sys
from botocore.compat import six
@@ -34,18 +34,18 @@ else:
rename_file = os.rename
-def accepts_kwargs(func):
- return inspect.getfullargspec(func)[2]
+def accepts_kwargs(func):
+ return inspect.getfullargspec(func)[2]
-# In python 3, socket.error is OSError, which is too general
-# for what we want (i.e FileNotFoundError is a subclass of OSError).
-# In python 3, all the socket related errors are in a newly created
-# ConnectionError.
-SOCKET_ERROR = ConnectionError
-MAXINT = None
+# In python 3, socket.error is OSError, which is too general
+# for what we want (i.e FileNotFoundError is a subclass of OSError).
+# In python 3, all the socket related errors are in a newly created
+# ConnectionError.
+SOCKET_ERROR = ConnectionError
+MAXINT = None
+
-
def seekable(fileobj):
"""Backwards compat function to determine if a fileobj is seekable
@@ -63,7 +63,7 @@ def seekable(fileobj):
try:
fileobj.seek(0, 1)
return True
- except OSError:
+ except OSError:
# If an io related error was thrown then it is not seekable.
return False
# Else, the fileobj is not seekable
@@ -81,14 +81,14 @@ def readable(fileobj):
return fileobj.readable()
return hasattr(fileobj, 'read')
-
-
-def fallocate(fileobj, size):
- if hasattr(os, 'posix_fallocate'):
- os.posix_fallocate(fileobj.fileno(), 0, size)
- else:
- fileobj.truncate(size)
-
-
-# Import at end of file to avoid circular dependencies
-from multiprocessing.managers import BaseManager # noqa: F401,E402
+
+
+def fallocate(fileobj, size):
+ if hasattr(os, 'posix_fallocate'):
+ os.posix_fallocate(fileobj.fileno(), 0, size)
+ else:
+ fileobj.truncate(size)
+
+
+# Import at end of file to avoid circular dependencies
+from multiprocessing.managers import BaseManager # noqa: F401,E402
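The two helpers restored at the bottom of compat.py are small but load-bearing: fallocate preallocates a download target (posix_fallocate where the platform has it, truncate otherwise) and seekable gates ranged writes. A hedged sketch of using them together, with a made-up object size and offset:

import tempfile

from s3transfer.compat import fallocate, seekable

expected_size = 32 * 1024 * 1024         # assumed object size for illustration

with tempfile.NamedTemporaryFile() as f:
    fallocate(f, expected_size)           # reserve space up front
    assert seekable(f)                    # safe to seek to each part's offset
    f.seek(8 * 1024 * 1024)
    f.write(b'part data')                 # a downloader would write real chunks here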
diff --git a/contrib/python/s3transfer/py3/s3transfer/constants.py b/contrib/python/s3transfer/py3/s3transfer/constants.py
index 29da842982..ba35bc72e9 100644
--- a/contrib/python/s3transfer/py3/s3transfer/constants.py
+++ b/contrib/python/s3transfer/py3/s3transfer/constants.py
@@ -1,29 +1,29 @@
-# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import s3transfer
-
-KB = 1024
-MB = KB * KB
-GB = MB * KB
-
-ALLOWED_DOWNLOAD_ARGS = [
- 'VersionId',
- 'SSECustomerAlgorithm',
- 'SSECustomerKey',
- 'SSECustomerKeyMD5',
- 'RequestPayer',
- 'ExpectedBucketOwner',
-]
-
-USER_AGENT = 's3transfer/%s' % s3transfer.__version__
-PROCESS_USER_AGENT = '%s processpool' % USER_AGENT
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import s3transfer
+
+KB = 1024
+MB = KB * KB
+GB = MB * KB
+
+ALLOWED_DOWNLOAD_ARGS = [
+ 'VersionId',
+ 'SSECustomerAlgorithm',
+ 'SSECustomerKey',
+ 'SSECustomerKeyMD5',
+ 'RequestPayer',
+ 'ExpectedBucketOwner',
+]
+
+USER_AGENT = 's3transfer/%s' % s3transfer.__version__
+PROCESS_USER_AGENT = '%s processpool' % USER_AGENT
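The size constants above feed directly into the ceil(size / chunksize) part-count arithmetic used throughout the package; a quick check with assumed sizes:

import math

from s3transfer.constants import MB

object_size = 100 * MB                     # assumed object size
part_size = 8 * MB                         # default chunksize used elsewhere in the diff
print(math.ceil(object_size / part_size))  # 13: twelve 8 MB parts plus a 4 MB tail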
diff --git a/contrib/python/s3transfer/py3/s3transfer/copies.py b/contrib/python/s3transfer/py3/s3transfer/copies.py
index 8daa710281..a1dfdc8ba3 100644
--- a/contrib/python/s3transfer/py3/s3transfer/copies.py
+++ b/contrib/python/s3transfer/py3/s3transfer/copies.py
@@ -13,18 +13,18 @@
import copy
import math
-from s3transfer.tasks import (
- CompleteMultipartUploadTask,
- CreateMultipartUploadTask,
- SubmissionTask,
- Task,
-)
-from s3transfer.utils import (
- ChunksizeAdjuster,
- calculate_range_parameter,
- get_callbacks,
- get_filtered_dict,
-)
+from s3transfer.tasks import (
+ CompleteMultipartUploadTask,
+ CreateMultipartUploadTask,
+ SubmissionTask,
+ Task,
+)
+from s3transfer.utils import (
+ ChunksizeAdjuster,
+ calculate_range_parameter,
+ get_callbacks,
+ get_filtered_dict,
+)
class CopySubmissionTask(SubmissionTask):
@@ -38,8 +38,8 @@ class CopySubmissionTask(SubmissionTask):
'CopySourceSSECustomerKey': 'SSECustomerKey',
'CopySourceSSECustomerAlgorithm': 'SSECustomerAlgorithm',
'CopySourceSSECustomerKeyMD5': 'SSECustomerKeyMD5',
- 'RequestPayer': 'RequestPayer',
- 'ExpectedBucketOwner': 'ExpectedBucketOwner',
+ 'RequestPayer': 'RequestPayer',
+ 'ExpectedBucketOwner': 'ExpectedBucketOwner',
}
UPLOAD_PART_COPY_ARGS = [
@@ -54,7 +54,7 @@ class CopySubmissionTask(SubmissionTask):
'SSECustomerAlgorithm',
'SSECustomerKeyMD5',
'RequestPayer',
- 'ExpectedBucketOwner',
+ 'ExpectedBucketOwner',
]
CREATE_MULTIPART_ARGS_BLACKLIST = [
@@ -65,15 +65,15 @@ class CopySubmissionTask(SubmissionTask):
'CopySourceSSECustomerKey',
'CopySourceSSECustomerAlgorithm',
'CopySourceSSECustomerKeyMD5',
- 'MetadataDirective',
- 'TaggingDirective',
+ 'MetadataDirective',
+ 'TaggingDirective',
]
- COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']
+ COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']
- def _submit(
- self, client, config, osutil, request_executor, transfer_future
- ):
+ def _submit(
+ self, client, config, osutil, request_executor, transfer_future
+ ):
"""
:param client: The client associated with the transfer manager
@@ -100,11 +100,11 @@ class CopySubmissionTask(SubmissionTask):
# of the client, they may have to provide the file size themselves
# with a completely new client.
call_args = transfer_future.meta.call_args
- head_object_request = (
+ head_object_request = (
self._get_head_object_request_from_copy_source(
- call_args.copy_source
- )
- )
+ call_args.copy_source
+ )
+ )
extra_args = call_args.extra_args
# Map any values that may be used in the head object that is
@@ -112,30 +112,30 @@ class CopySubmissionTask(SubmissionTask):
for param, value in extra_args.items():
if param in self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING:
head_object_request[
- self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING[param]
- ] = value
+ self.EXTRA_ARGS_TO_HEAD_ARGS_MAPPING[param]
+ ] = value
response = call_args.source_client.head_object(
- **head_object_request
- )
+ **head_object_request
+ )
transfer_future.meta.provide_transfer_size(
- response['ContentLength']
- )
+ response['ContentLength']
+ )
# If it is greater than threshold do a multipart copy, otherwise
# do a regular copy object.
if transfer_future.meta.size < config.multipart_threshold:
self._submit_copy_request(
- client, config, osutil, request_executor, transfer_future
- )
+ client, config, osutil, request_executor, transfer_future
+ )
else:
self._submit_multipart_request(
- client, config, osutil, request_executor, transfer_future
- )
+ client, config, osutil, request_executor, transfer_future
+ )
- def _submit_copy_request(
- self, client, config, osutil, request_executor, transfer_future
- ):
+ def _submit_copy_request(
+ self, client, config, osutil, request_executor, transfer_future
+ ):
call_args = transfer_future.meta.call_args
# Get the needed progress callbacks for the task
@@ -153,15 +153,15 @@ class CopySubmissionTask(SubmissionTask):
'key': call_args.key,
'extra_args': call_args.extra_args,
'callbacks': progress_callbacks,
- 'size': transfer_future.meta.size,
+ 'size': transfer_future.meta.size,
},
- is_final=True,
- ),
+ is_final=True,
+ ),
)
- def _submit_multipart_request(
- self, client, config, osutil, request_executor, transfer_future
- ):
+ def _submit_multipart_request(
+ self, client, config, osutil, request_executor, transfer_future
+ ):
call_args = transfer_future.meta.call_args
# Submit the request to create a multipart upload and make sure it
@@ -180,8 +180,8 @@ class CopySubmissionTask(SubmissionTask):
'bucket': call_args.bucket,
'key': call_args.key,
'extra_args': create_multipart_extra_args,
- },
- ),
+ },
+ ),
)
# Determine how many parts are needed based on filesize and
@@ -189,11 +189,11 @@ class CopySubmissionTask(SubmissionTask):
part_size = config.multipart_chunksize
adjuster = ChunksizeAdjuster()
part_size = adjuster.adjust_chunksize(
- part_size, transfer_future.meta.size
- )
+ part_size, transfer_future.meta.size
+ )
num_parts = int(
- math.ceil(transfer_future.meta.size / float(part_size))
- )
+ math.ceil(transfer_future.meta.size / float(part_size))
+ )
# Submit requests to upload the parts of the file.
part_futures = []
@@ -201,24 +201,24 @@ class CopySubmissionTask(SubmissionTask):
for part_number in range(1, num_parts + 1):
extra_part_args = self._extra_upload_part_args(
- call_args.extra_args
- )
+ call_args.extra_args
+ )
# The part number for upload part starts at 1 while the
# range parameter starts at zero, so just subtract 1 off of
# the part number
extra_part_args['CopySourceRange'] = calculate_range_parameter(
- part_size,
- part_number - 1,
- num_parts,
- transfer_future.meta.size,
- )
+ part_size,
+ part_number - 1,
+ num_parts,
+ transfer_future.meta.size,
+ )
# Get the size of the part copy as well for the progress
# callbacks.
size = self._get_transfer_size(
- part_size,
- part_number - 1,
- num_parts,
- transfer_future.meta.size,
+ part_size,
+ part_number - 1,
+ num_parts,
+ transfer_future.meta.size,
)
part_futures.append(
self._transfer_coordinator.submit(
@@ -233,18 +233,18 @@ class CopySubmissionTask(SubmissionTask):
'part_number': part_number,
'extra_args': extra_part_args,
'callbacks': progress_callbacks,
- 'size': size,
+ 'size': size,
},
pending_main_kwargs={
'upload_id': create_multipart_future
- },
- ),
+ },
+ ),
)
)
complete_multipart_extra_args = self._extra_complete_multipart_args(
- call_args.extra_args
- )
+ call_args.extra_args
+ )
# Submit the request to complete the multipart upload.
self._transfer_coordinator.submit(
request_executor,
@@ -258,10 +258,10 @@ class CopySubmissionTask(SubmissionTask):
},
pending_main_kwargs={
'upload_id': create_multipart_future,
- 'parts': part_futures,
+ 'parts': part_futures,
},
- is_final=True,
- ),
+ is_final=True,
+ ),
)
def _get_head_object_request_from_copy_source(self, copy_source):
@@ -271,7 +271,7 @@ class CopySubmissionTask(SubmissionTask):
raise TypeError(
'Expecting dictionary formatted: '
'{"Bucket": bucket_name, "Key": key} '
- 'but got %s or type %s.' % (copy_source, type(copy_source))
+ 'but got %s or type %s.' % (copy_source, type(copy_source))
)
def _extra_upload_part_args(self, extra_args):
@@ -282,9 +282,9 @@ class CopySubmissionTask(SubmissionTask):
def _extra_complete_multipart_args(self, extra_args):
return get_filtered_dict(extra_args, self.COMPLETE_MULTIPART_ARGS)
- def _get_transfer_size(
- self, part_size, part_index, num_parts, total_transfer_size
- ):
+ def _get_transfer_size(
+ self, part_size, part_index, num_parts, total_transfer_size
+ ):
if part_index == num_parts - 1:
# The last part may be different in size then the rest of the
# parts.
@@ -294,10 +294,10 @@ class CopySubmissionTask(SubmissionTask):
class CopyObjectTask(Task):
"""Task to do a nonmultipart copy"""
-
- def _main(
- self, client, copy_source, bucket, key, extra_args, callbacks, size
- ):
+
+ def _main(
+ self, client, copy_source, bucket, key, extra_args, callbacks, size
+ ):
"""
:param client: The client to use when calling PutObject
:param copy_source: The CopySource parameter to use
@@ -311,27 +311,27 @@ class CopyObjectTask(Task):
"""
client.copy_object(
- CopySource=copy_source, Bucket=bucket, Key=key, **extra_args
- )
+ CopySource=copy_source, Bucket=bucket, Key=key, **extra_args
+ )
for callback in callbacks:
callback(bytes_transferred=size)
class CopyPartTask(Task):
"""Task to upload a part in a multipart copy"""
-
- def _main(
- self,
- client,
- copy_source,
- bucket,
- key,
- upload_id,
- part_number,
- extra_args,
- callbacks,
- size,
- ):
+
+ def _main(
+ self,
+ client,
+ copy_source,
+ bucket,
+ key,
+ upload_id,
+ part_number,
+ extra_args,
+ callbacks,
+ size,
+ ):
"""
:param client: The client to use when calling PutObject
:param copy_source: The CopySource parameter to use
@@ -355,13 +355,13 @@ class CopyPartTask(Task):
the multipart upload.
"""
response = client.upload_part_copy(
- CopySource=copy_source,
- Bucket=bucket,
- Key=key,
- UploadId=upload_id,
- PartNumber=part_number,
- **extra_args
- )
+ CopySource=copy_source,
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ PartNumber=part_number,
+ **extra_args
+ )
for callback in callbacks:
callback(bytes_transferred=size)
etag = response['CopyPartResult']['ETag']
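CopyPartTask above receives a precomputed CopySourceRange per part; the sketch below mirrors how CopySubmissionTask derives those ranges. It reimplements the intent of s3transfer.utils.calculate_range_parameter for illustration rather than importing it, and the sizes are made up:

def copy_source_range(part_size, part_index, num_parts, total_size):
    start = part_size * part_index
    # The final part extends to the last byte of the source object.
    end = total_size - 1 if part_index == num_parts - 1 else start + part_size - 1
    return f'bytes={start}-{end}'


total_size = 20 * 1024 * 1024              # assumed 20 MiB source object
part_size = 8 * 1024 * 1024
num_parts = 3
for i in range(num_parts):
    print(copy_source_range(part_size, i, num_parts, total_size))
# bytes=0-8388607 / bytes=8388608-16777215 / bytes=16777216-20971519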
diff --git a/contrib/python/s3transfer/py3/s3transfer/crt.py b/contrib/python/s3transfer/py3/s3transfer/crt.py
index b1d573c1b5..7b5d130136 100644
--- a/contrib/python/s3transfer/py3/s3transfer/crt.py
+++ b/contrib/python/s3transfer/py3/s3transfer/crt.py
@@ -1,644 +1,644 @@
-# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import logging
-import threading
-from io import BytesIO
-
-import awscrt.http
-import botocore.awsrequest
-import botocore.session
-from awscrt.auth import AwsCredentials, AwsCredentialsProvider
-from awscrt.io import (
- ClientBootstrap,
- ClientTlsContext,
- DefaultHostResolver,
- EventLoopGroup,
- TlsContextOptions,
-)
-from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType
-from botocore import UNSIGNED
-from botocore.compat import urlsplit
-from botocore.config import Config
-from botocore.exceptions import NoCredentialsError
-
-from s3transfer.constants import GB, MB
-from s3transfer.exceptions import TransferNotDoneError
-from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
-from s3transfer.utils import CallArgs, OSUtils, get_callbacks
-
-logger = logging.getLogger(__name__)
-
-
-class CRTCredentialProviderAdapter:
- def __init__(self, botocore_credential_provider):
- self._botocore_credential_provider = botocore_credential_provider
- self._loaded_credentials = None
- self._lock = threading.Lock()
-
- def __call__(self):
- credentials = self._get_credentials().get_frozen_credentials()
- return AwsCredentials(
- credentials.access_key, credentials.secret_key, credentials.token
- )
-
- def _get_credentials(self):
- with self._lock:
- if self._loaded_credentials is None:
- loaded_creds = (
- self._botocore_credential_provider.load_credentials()
- )
- if loaded_creds is None:
- raise NoCredentialsError()
- self._loaded_credentials = loaded_creds
- return self._loaded_credentials
-
-
-def create_s3_crt_client(
- region,
- botocore_credential_provider=None,
- num_threads=None,
- target_throughput=5 * GB / 8,
- part_size=8 * MB,
- use_ssl=True,
- verify=None,
-):
- """
- :type region: str
- :param region: The region used for signing
-
- :type botocore_credential_provider:
- Optional[botocore.credentials.CredentialResolver]
- :param botocore_credential_provider: Provide credentials for CRT
- to sign the request if not set, the request will not be signed
-
- :type num_threads: Optional[int]
- :param num_threads: Number of worker threads generated. Default
- is the number of processors in the machine.
-
- :type target_throughput: Optional[int]
- :param target_throughput: Throughput target in Bytes.
- Default is 0.625 GB/s (which translates to 5 Gb/s).
-
- :type part_size: Optional[int]
- :param part_size: Size, in Bytes, of parts that files will be downloaded
- or uploaded in.
-
- :type use_ssl: boolean
- :param use_ssl: Whether or not to use SSL. By default, SSL is used.
- Note that not all services support non-ssl connections.
-
- :type verify: Optional[boolean/string]
- :param verify: Whether or not to verify SSL certificates.
- By default SSL certificates are verified. You can provide the
- following values:
-
- * False - do not validate SSL certificates. SSL will still be
- used (unless use_ssl is False), but SSL certificates
- will not be verified.
- * path/to/cert/bundle.pem - A filename of the CA cert bundle to
- use. Specify this argument if you want to use a custom CA cert
- bundle instead of the default one on your system.
- """
-
- event_loop_group = EventLoopGroup(num_threads)
- host_resolver = DefaultHostResolver(event_loop_group)
- bootstrap = ClientBootstrap(event_loop_group, host_resolver)
- provider = None
- tls_connection_options = None
-
- tls_mode = (
- S3RequestTlsMode.ENABLED if use_ssl else S3RequestTlsMode.DISABLED
- )
- if verify is not None:
- tls_ctx_options = TlsContextOptions()
- if verify:
- tls_ctx_options.override_default_trust_store_from_path(
- ca_filepath=verify
- )
- else:
- tls_ctx_options.verify_peer = False
- client_tls_option = ClientTlsContext(tls_ctx_options)
- tls_connection_options = client_tls_option.new_connection_options()
- if botocore_credential_provider:
- credentails_provider_adapter = CRTCredentialProviderAdapter(
- botocore_credential_provider
- )
- provider = AwsCredentialsProvider.new_delegate(
- credentails_provider_adapter
- )
-
- target_gbps = target_throughput * 8 / GB
- return S3Client(
- bootstrap=bootstrap,
- region=region,
- credential_provider=provider,
- part_size=part_size,
- tls_mode=tls_mode,
- tls_connection_options=tls_connection_options,
- throughput_target_gbps=target_gbps,
- )
-
-
-class CRTTransferManager:
- def __init__(self, crt_s3_client, crt_request_serializer, osutil=None):
- """A transfer manager interface for Amazon S3 on CRT s3 client.
-
- :type crt_s3_client: awscrt.s3.S3Client
- :param crt_s3_client: The CRT s3 client, handling all the
- HTTP requests and functions under then hood
-
- :type crt_request_serializer: s3transfer.crt.BaseCRTRequestSerializer
- :param crt_request_serializer: Serializer, generates unsigned crt HTTP
- request.
-
- :type osutil: s3transfer.utils.OSUtils
- :param osutil: OSUtils object to use for os-related behavior when
- using with transfer manager.
- """
- if osutil is None:
- self._osutil = OSUtils()
- self._crt_s3_client = crt_s3_client
- self._s3_args_creator = S3ClientArgsCreator(
- crt_request_serializer, self._osutil
- )
- self._future_coordinators = []
- self._semaphore = threading.Semaphore(128) # not configurable
- # A counter to create unique id's for each transfer submitted.
- self._id_counter = 0
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, *args):
- cancel = False
- if exc_type:
- cancel = True
- self._shutdown(cancel)
-
- def download(
- self, bucket, key, fileobj, extra_args=None, subscribers=None
- ):
- if extra_args is None:
- extra_args = {}
- if subscribers is None:
- subscribers = {}
- callargs = CallArgs(
- bucket=bucket,
- key=key,
- fileobj=fileobj,
- extra_args=extra_args,
- subscribers=subscribers,
- )
- return self._submit_transfer("get_object", callargs)
-
- def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
- if extra_args is None:
- extra_args = {}
- if subscribers is None:
- subscribers = {}
- callargs = CallArgs(
- bucket=bucket,
- key=key,
- fileobj=fileobj,
- extra_args=extra_args,
- subscribers=subscribers,
- )
- return self._submit_transfer("put_object", callargs)
-
- def delete(self, bucket, key, extra_args=None, subscribers=None):
- if extra_args is None:
- extra_args = {}
- if subscribers is None:
- subscribers = {}
- callargs = CallArgs(
- bucket=bucket,
- key=key,
- extra_args=extra_args,
- subscribers=subscribers,
- )
- return self._submit_transfer("delete_object", callargs)
-
- def shutdown(self, cancel=False):
- self._shutdown(cancel)
-
- def _cancel_transfers(self):
- for coordinator in self._future_coordinators:
- if not coordinator.done():
- coordinator.cancel()
-
- def _finish_transfers(self):
- for coordinator in self._future_coordinators:
- coordinator.result()
-
- def _wait_transfers_done(self):
- for coordinator in self._future_coordinators:
- coordinator.wait_until_on_done_callbacks_complete()
-
- def _shutdown(self, cancel=False):
- if cancel:
- self._cancel_transfers()
- try:
- self._finish_transfers()
-
- except KeyboardInterrupt:
- self._cancel_transfers()
- except Exception:
- pass
- finally:
- self._wait_transfers_done()
-
- def _release_semaphore(self, **kwargs):
- self._semaphore.release()
-
- def _submit_transfer(self, request_type, call_args):
- on_done_after_calls = [self._release_semaphore]
- coordinator = CRTTransferCoordinator(transfer_id=self._id_counter)
- components = {
- 'meta': CRTTransferMeta(self._id_counter, call_args),
- 'coordinator': coordinator,
- }
- future = CRTTransferFuture(**components)
- afterdone = AfterDoneHandler(coordinator)
- on_done_after_calls.append(afterdone)
-
- try:
- self._semaphore.acquire()
- on_queued = self._s3_args_creator.get_crt_callback(
- future, 'queued'
- )
- on_queued()
- crt_callargs = self._s3_args_creator.get_make_request_args(
- request_type,
- call_args,
- coordinator,
- future,
- on_done_after_calls,
- )
- crt_s3_request = self._crt_s3_client.make_request(**crt_callargs)
- except Exception as e:
- coordinator.set_exception(e, True)
- on_done = self._s3_args_creator.get_crt_callback(
- future, 'done', after_subscribers=on_done_after_calls
- )
- on_done(error=e)
- else:
- coordinator.set_s3_request(crt_s3_request)
- self._future_coordinators.append(coordinator)
-
- self._id_counter += 1
- return future
-
-
-class CRTTransferMeta(BaseTransferMeta):
- """Holds metadata about the CRTTransferFuture"""
-
- def __init__(self, transfer_id=None, call_args=None):
- self._transfer_id = transfer_id
- self._call_args = call_args
- self._user_context = {}
-
- @property
- def call_args(self):
- return self._call_args
-
- @property
- def transfer_id(self):
- return self._transfer_id
-
- @property
- def user_context(self):
- return self._user_context
-
-
-class CRTTransferFuture(BaseTransferFuture):
- def __init__(self, meta=None, coordinator=None):
- """The future associated to a submitted transfer request via CRT S3 client
-
- :type meta: s3transfer.crt.CRTTransferMeta
- :param meta: The metadata associated to the transfer future.
-
- :type coordinator: s3transfer.crt.CRTTransferCoordinator
- :param coordinator: The coordinator associated to the transfer future.
- """
- self._meta = meta
- if meta is None:
- self._meta = CRTTransferMeta()
- self._coordinator = coordinator
-
- @property
- def meta(self):
- return self._meta
-
- def done(self):
- return self._coordinator.done()
-
- def result(self, timeout=None):
- self._coordinator.result(timeout)
-
- def cancel(self):
- self._coordinator.cancel()
-
- def set_exception(self, exception):
- """Sets the exception on the future."""
- if not self.done():
- raise TransferNotDoneError(
- 'set_exception can only be called once the transfer is '
- 'complete.'
- )
- self._coordinator.set_exception(exception, override=True)
-
-
-class BaseCRTRequestSerializer:
- def serialize_http_request(self, transfer_type, future):
- """Serialize CRT HTTP requests.
-
- :type transfer_type: string
- :param transfer_type: the type of transfer made,
- e.g 'put_object', 'get_object', 'delete_object'
-
- :type future: s3transfer.crt.CRTTransferFuture
-
- :rtype: awscrt.http.HttpRequest
- :returns: An unsigned HTTP request to be used for the CRT S3 client
- """
- raise NotImplementedError('serialize_http_request()')
-
-
-class BotocoreCRTRequestSerializer(BaseCRTRequestSerializer):
- def __init__(self, session, client_kwargs=None):
- """Serialize CRT HTTP request using botocore logic
- It also takes into account configuration from both the session
- and any keyword arguments that could be passed to
- `Session.create_client()` when serializing the request.
-
- :type session: botocore.session.Session
-
- :type client_kwargs: Optional[Dict[str, str]])
- :param client_kwargs: The kwargs for the botocore
- s3 client initialization.
- """
- self._session = session
- if client_kwargs is None:
- client_kwargs = {}
- self._resolve_client_config(session, client_kwargs)
- self._client = session.create_client(**client_kwargs)
- self._client.meta.events.register(
- 'request-created.s3.*', self._capture_http_request
- )
- self._client.meta.events.register(
- 'after-call.s3.*', self._change_response_to_serialized_http_request
- )
- self._client.meta.events.register(
- 'before-send.s3.*', self._make_fake_http_response
- )
-
- def _resolve_client_config(self, session, client_kwargs):
- user_provided_config = None
- if session.get_default_client_config():
- user_provided_config = session.get_default_client_config()
- if 'config' in client_kwargs:
- user_provided_config = client_kwargs['config']
-
- client_config = Config(signature_version=UNSIGNED)
- if user_provided_config:
- client_config = user_provided_config.merge(client_config)
- client_kwargs['config'] = client_config
- client_kwargs["service_name"] = "s3"
-
- def _crt_request_from_aws_request(self, aws_request):
- url_parts = urlsplit(aws_request.url)
- crt_path = url_parts.path
- if url_parts.query:
- crt_path = f'{crt_path}?{url_parts.query}'
- headers_list = []
- for name, value in aws_request.headers.items():
- if isinstance(value, str):
- headers_list.append((name, value))
- else:
- headers_list.append((name, str(value, 'utf-8')))
-
- crt_headers = awscrt.http.HttpHeaders(headers_list)
- # CRT requires body (if it exists) to be an I/O stream.
- crt_body_stream = None
- if aws_request.body:
- if hasattr(aws_request.body, 'seek'):
- crt_body_stream = aws_request.body
- else:
- crt_body_stream = BytesIO(aws_request.body)
-
- crt_request = awscrt.http.HttpRequest(
- method=aws_request.method,
- path=crt_path,
- headers=crt_headers,
- body_stream=crt_body_stream,
- )
- return crt_request
-
- def _convert_to_crt_http_request(self, botocore_http_request):
- # Logic that does CRTUtils.crt_request_from_aws_request
- crt_request = self._crt_request_from_aws_request(botocore_http_request)
- if crt_request.headers.get("host") is None:
- # If host is not set, set it for the request before using CRT s3
- url_parts = urlsplit(botocore_http_request.url)
- crt_request.headers.set("host", url_parts.netloc)
- if crt_request.headers.get('Content-MD5') is not None:
- crt_request.headers.remove("Content-MD5")
- return crt_request
-
- def _capture_http_request(self, request, **kwargs):
- request.context['http_request'] = request
-
- def _change_response_to_serialized_http_request(
- self, context, parsed, **kwargs
- ):
- request = context['http_request']
- parsed['HTTPRequest'] = request.prepare()
-
- def _make_fake_http_response(self, request, **kwargs):
- return botocore.awsrequest.AWSResponse(
- None,
- 200,
- {},
- FakeRawResponse(b""),
- )
-
- def _get_botocore_http_request(self, client_method, call_args):
- return getattr(self._client, client_method)(
- Bucket=call_args.bucket, Key=call_args.key, **call_args.extra_args
- )['HTTPRequest']
-
- def serialize_http_request(self, transfer_type, future):
- botocore_http_request = self._get_botocore_http_request(
- transfer_type, future.meta.call_args
- )
- crt_request = self._convert_to_crt_http_request(botocore_http_request)
- return crt_request
-
-
-class FakeRawResponse(BytesIO):
- def stream(self, amt=1024, decode_content=None):
- while True:
- chunk = self.read(amt)
- if not chunk:
- break
- yield chunk
-
-
-class CRTTransferCoordinator:
- """A helper class for managing CRTTransferFuture"""
-
- def __init__(self, transfer_id=None, s3_request=None):
- self.transfer_id = transfer_id
- self._s3_request = s3_request
- self._lock = threading.Lock()
- self._exception = None
- self._crt_future = None
- self._done_event = threading.Event()
-
- @property
- def s3_request(self):
- return self._s3_request
-
- def set_done_callbacks_complete(self):
- self._done_event.set()
-
- def wait_until_on_done_callbacks_complete(self, timeout=None):
- self._done_event.wait(timeout)
-
- def set_exception(self, exception, override=False):
- with self._lock:
- if not self.done() or override:
- self._exception = exception
-
- def cancel(self):
- if self._s3_request:
- self._s3_request.cancel()
-
- def result(self, timeout=None):
- if self._exception:
- raise self._exception
- try:
- self._crt_future.result(timeout)
- except KeyboardInterrupt:
- self.cancel()
- raise
- finally:
- if self._s3_request:
- self._s3_request = None
- self._crt_future.result(timeout)
-
- def done(self):
- if self._crt_future is None:
- return False
- return self._crt_future.done()
-
- def set_s3_request(self, s3_request):
- self._s3_request = s3_request
- self._crt_future = self._s3_request.finished_future
-
-
-class S3ClientArgsCreator:
- def __init__(self, crt_request_serializer, os_utils):
- self._request_serializer = crt_request_serializer
- self._os_utils = os_utils
-
- def get_make_request_args(
- self, request_type, call_args, coordinator, future, on_done_after_calls
- ):
- recv_filepath = None
- send_filepath = None
- s3_meta_request_type = getattr(
- S3RequestType, request_type.upper(), S3RequestType.DEFAULT
- )
- on_done_before_calls = []
- if s3_meta_request_type == S3RequestType.GET_OBJECT:
- final_filepath = call_args.fileobj
- recv_filepath = self._os_utils.get_temp_filename(final_filepath)
- file_ondone_call = RenameTempFileHandler(
- coordinator, final_filepath, recv_filepath, self._os_utils
- )
- on_done_before_calls.append(file_ondone_call)
- elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
- send_filepath = call_args.fileobj
- data_len = self._os_utils.get_file_size(send_filepath)
- call_args.extra_args["ContentLength"] = data_len
-
- crt_request = self._request_serializer.serialize_http_request(
- request_type, future
- )
-
- return {
- 'request': crt_request,
- 'type': s3_meta_request_type,
- 'recv_filepath': recv_filepath,
- 'send_filepath': send_filepath,
- 'on_done': self.get_crt_callback(
- future, 'done', on_done_before_calls, on_done_after_calls
- ),
- 'on_progress': self.get_crt_callback(future, 'progress'),
- }
-
- def get_crt_callback(
- self,
- future,
- callback_type,
- before_subscribers=None,
- after_subscribers=None,
- ):
- def invoke_all_callbacks(*args, **kwargs):
- callbacks_list = []
- if before_subscribers is not None:
- callbacks_list += before_subscribers
- callbacks_list += get_callbacks(future, callback_type)
- if after_subscribers is not None:
- callbacks_list += after_subscribers
- for callback in callbacks_list:
- # The get_callbacks helper will set the first augment
- # by keyword, the other augments need to be set by keyword
- # as well
- if callback_type == "progress":
- callback(bytes_transferred=args[0])
- else:
- callback(*args, **kwargs)
-
- return invoke_all_callbacks
-
-
-class RenameTempFileHandler:
- def __init__(self, coordinator, final_filename, temp_filename, osutil):
- self._coordinator = coordinator
- self._final_filename = final_filename
- self._temp_filename = temp_filename
- self._osutil = osutil
-
- def __call__(self, **kwargs):
- error = kwargs['error']
- if error:
- self._osutil.remove_file(self._temp_filename)
- else:
- try:
- self._osutil.rename_file(
- self._temp_filename, self._final_filename
- )
- except Exception as e:
- self._osutil.remove_file(self._temp_filename)
- # the CRT future has done already at this point
- self._coordinator.set_exception(e)
-
-
-class AfterDoneHandler:
- def __init__(self, coordinator):
- self._coordinator = coordinator
-
- def __call__(self, **kwargs):
- self._coordinator.set_done_callbacks_complete()
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import logging
+import threading
+from io import BytesIO
+
+import awscrt.http
+import botocore.awsrequest
+import botocore.session
+from awscrt.auth import AwsCredentials, AwsCredentialsProvider
+from awscrt.io import (
+ ClientBootstrap,
+ ClientTlsContext,
+ DefaultHostResolver,
+ EventLoopGroup,
+ TlsContextOptions,
+)
+from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType
+from botocore import UNSIGNED
+from botocore.compat import urlsplit
+from botocore.config import Config
+from botocore.exceptions import NoCredentialsError
+
+from s3transfer.constants import GB, MB
+from s3transfer.exceptions import TransferNotDoneError
+from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
+from s3transfer.utils import CallArgs, OSUtils, get_callbacks
+
+logger = logging.getLogger(__name__)
+
+
+class CRTCredentialProviderAdapter:
+ def __init__(self, botocore_credential_provider):
+ self._botocore_credential_provider = botocore_credential_provider
+ self._loaded_credentials = None
+ self._lock = threading.Lock()
+
+ def __call__(self):
+ credentials = self._get_credentials().get_frozen_credentials()
+ return AwsCredentials(
+ credentials.access_key, credentials.secret_key, credentials.token
+ )
+
+ def _get_credentials(self):
+ with self._lock:
+ if self._loaded_credentials is None:
+ loaded_creds = (
+ self._botocore_credential_provider.load_credentials()
+ )
+ if loaded_creds is None:
+ raise NoCredentialsError()
+ self._loaded_credentials = loaded_creds
+ return self._loaded_credentials
+
+
+def create_s3_crt_client(
+ region,
+ botocore_credential_provider=None,
+ num_threads=None,
+ target_throughput=5 * GB / 8,
+ part_size=8 * MB,
+ use_ssl=True,
+ verify=None,
+):
+ """
+ :type region: str
+ :param region: The region used for signing
+
+ :type botocore_credential_provider:
+ Optional[botocore.credentials.CredentialResolver]
+    :param botocore_credential_provider: Provide credentials for CRT
+        to sign the request. If not set, the request will not be signed.
+
+ :type num_threads: Optional[int]
+ :param num_threads: Number of worker threads generated. Default
+ is the number of processors in the machine.
+
+ :type target_throughput: Optional[int]
+    :param target_throughput: Throughput target in bytes per second.
+ Default is 0.625 GB/s (which translates to 5 Gb/s).
+
+ :type part_size: Optional[int]
+ :param part_size: Size, in Bytes, of parts that files will be downloaded
+ or uploaded in.
+
+ :type use_ssl: boolean
+ :param use_ssl: Whether or not to use SSL. By default, SSL is used.
+ Note that not all services support non-ssl connections.
+
+ :type verify: Optional[boolean/string]
+ :param verify: Whether or not to verify SSL certificates.
+ By default SSL certificates are verified. You can provide the
+ following values:
+
+ * False - do not validate SSL certificates. SSL will still be
+ used (unless use_ssl is False), but SSL certificates
+ will not be verified.
+ * path/to/cert/bundle.pem - A filename of the CA cert bundle to
+ use. Specify this argument if you want to use a custom CA cert
+ bundle instead of the default one on your system.
+ """
+
+ event_loop_group = EventLoopGroup(num_threads)
+ host_resolver = DefaultHostResolver(event_loop_group)
+ bootstrap = ClientBootstrap(event_loop_group, host_resolver)
+ provider = None
+ tls_connection_options = None
+
+ tls_mode = (
+ S3RequestTlsMode.ENABLED if use_ssl else S3RequestTlsMode.DISABLED
+ )
+ if verify is not None:
+ tls_ctx_options = TlsContextOptions()
+ if verify:
+ tls_ctx_options.override_default_trust_store_from_path(
+ ca_filepath=verify
+ )
+ else:
+ tls_ctx_options.verify_peer = False
+ client_tls_option = ClientTlsContext(tls_ctx_options)
+ tls_connection_options = client_tls_option.new_connection_options()
+ if botocore_credential_provider:
+        credentials_provider_adapter = CRTCredentialProviderAdapter(
+            botocore_credential_provider
+        )
+        provider = AwsCredentialsProvider.new_delegate(
+            credentials_provider_adapter
+ )
+
+ target_gbps = target_throughput * 8 / GB
+ return S3Client(
+ bootstrap=bootstrap,
+ region=region,
+ credential_provider=provider,
+ part_size=part_size,
+ tls_mode=tls_mode,
+ tls_connection_options=tls_connection_options,
+ throughput_target_gbps=target_gbps,
+ )
+
+
+class CRTTransferManager:
+ def __init__(self, crt_s3_client, crt_request_serializer, osutil=None):
+ """A transfer manager interface for Amazon S3 on CRT s3 client.
+
+ :type crt_s3_client: awscrt.s3.S3Client
+ :param crt_s3_client: The CRT s3 client, handling all the
+        HTTP requests and functions under the hood
+
+ :type crt_request_serializer: s3transfer.crt.BaseCRTRequestSerializer
+ :param crt_request_serializer: Serializer, generates unsigned crt HTTP
+ request.
+
+ :type osutil: s3transfer.utils.OSUtils
+ :param osutil: OSUtils object to use for os-related behavior when
+ using with transfer manager.
+ """
+        self._osutil = osutil
+        if osutil is None:
+            self._osutil = OSUtils()
+ self._crt_s3_client = crt_s3_client
+ self._s3_args_creator = S3ClientArgsCreator(
+ crt_request_serializer, self._osutil
+ )
+ self._future_coordinators = []
+ self._semaphore = threading.Semaphore(128) # not configurable
+ # A counter to create unique id's for each transfer submitted.
+ self._id_counter = 0
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, *args):
+ cancel = False
+ if exc_type:
+ cancel = True
+ self._shutdown(cancel)
+
+ def download(
+ self, bucket, key, fileobj, extra_args=None, subscribers=None
+ ):
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = {}
+ callargs = CallArgs(
+ bucket=bucket,
+ key=key,
+ fileobj=fileobj,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ return self._submit_transfer("get_object", callargs)
+
+ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = {}
+ callargs = CallArgs(
+ bucket=bucket,
+ key=key,
+ fileobj=fileobj,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ return self._submit_transfer("put_object", callargs)
+
+ def delete(self, bucket, key, extra_args=None, subscribers=None):
+ if extra_args is None:
+ extra_args = {}
+ if subscribers is None:
+ subscribers = {}
+ callargs = CallArgs(
+ bucket=bucket,
+ key=key,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ )
+ return self._submit_transfer("delete_object", callargs)
+
+ def shutdown(self, cancel=False):
+ self._shutdown(cancel)
+
+ def _cancel_transfers(self):
+ for coordinator in self._future_coordinators:
+ if not coordinator.done():
+ coordinator.cancel()
+
+ def _finish_transfers(self):
+ for coordinator in self._future_coordinators:
+ coordinator.result()
+
+ def _wait_transfers_done(self):
+ for coordinator in self._future_coordinators:
+ coordinator.wait_until_on_done_callbacks_complete()
+
+ def _shutdown(self, cancel=False):
+ if cancel:
+ self._cancel_transfers()
+ try:
+ self._finish_transfers()
+
+ except KeyboardInterrupt:
+ self._cancel_transfers()
+ except Exception:
+ pass
+ finally:
+ self._wait_transfers_done()
+
+ def _release_semaphore(self, **kwargs):
+ self._semaphore.release()
+
+ def _submit_transfer(self, request_type, call_args):
+ on_done_after_calls = [self._release_semaphore]
+ coordinator = CRTTransferCoordinator(transfer_id=self._id_counter)
+ components = {
+ 'meta': CRTTransferMeta(self._id_counter, call_args),
+ 'coordinator': coordinator,
+ }
+ future = CRTTransferFuture(**components)
+ afterdone = AfterDoneHandler(coordinator)
+ on_done_after_calls.append(afterdone)
+
+ try:
+ self._semaphore.acquire()
+ on_queued = self._s3_args_creator.get_crt_callback(
+ future, 'queued'
+ )
+ on_queued()
+ crt_callargs = self._s3_args_creator.get_make_request_args(
+ request_type,
+ call_args,
+ coordinator,
+ future,
+ on_done_after_calls,
+ )
+ crt_s3_request = self._crt_s3_client.make_request(**crt_callargs)
+ except Exception as e:
+ coordinator.set_exception(e, True)
+ on_done = self._s3_args_creator.get_crt_callback(
+ future, 'done', after_subscribers=on_done_after_calls
+ )
+ on_done(error=e)
+ else:
+ coordinator.set_s3_request(crt_s3_request)
+ self._future_coordinators.append(coordinator)
+
+ self._id_counter += 1
+ return future
+
+
+class CRTTransferMeta(BaseTransferMeta):
+ """Holds metadata about the CRTTransferFuture"""
+
+ def __init__(self, transfer_id=None, call_args=None):
+ self._transfer_id = transfer_id
+ self._call_args = call_args
+ self._user_context = {}
+
+ @property
+ def call_args(self):
+ return self._call_args
+
+ @property
+ def transfer_id(self):
+ return self._transfer_id
+
+ @property
+ def user_context(self):
+ return self._user_context
+
+
+class CRTTransferFuture(BaseTransferFuture):
+ def __init__(self, meta=None, coordinator=None):
+ """The future associated to a submitted transfer request via CRT S3 client
+
+ :type meta: s3transfer.crt.CRTTransferMeta
+ :param meta: The metadata associated to the transfer future.
+
+ :type coordinator: s3transfer.crt.CRTTransferCoordinator
+ :param coordinator: The coordinator associated to the transfer future.
+ """
+ self._meta = meta
+ if meta is None:
+ self._meta = CRTTransferMeta()
+ self._coordinator = coordinator
+
+ @property
+ def meta(self):
+ return self._meta
+
+ def done(self):
+ return self._coordinator.done()
+
+ def result(self, timeout=None):
+ self._coordinator.result(timeout)
+
+ def cancel(self):
+ self._coordinator.cancel()
+
+ def set_exception(self, exception):
+ """Sets the exception on the future."""
+ if not self.done():
+ raise TransferNotDoneError(
+ 'set_exception can only be called once the transfer is '
+ 'complete.'
+ )
+ self._coordinator.set_exception(exception, override=True)
+
+
+class BaseCRTRequestSerializer:
+ def serialize_http_request(self, transfer_type, future):
+ """Serialize CRT HTTP requests.
+
+ :type transfer_type: string
+ :param transfer_type: the type of transfer made,
+        e.g. 'put_object', 'get_object', 'delete_object'
+
+ :type future: s3transfer.crt.CRTTransferFuture
+
+ :rtype: awscrt.http.HttpRequest
+ :returns: An unsigned HTTP request to be used for the CRT S3 client
+ """
+ raise NotImplementedError('serialize_http_request()')
+
+
+class BotocoreCRTRequestSerializer(BaseCRTRequestSerializer):
+ def __init__(self, session, client_kwargs=None):
+ """Serialize CRT HTTP request using botocore logic
+ It also takes into account configuration from both the session
+ and any keyword arguments that could be passed to
+ `Session.create_client()` when serializing the request.
+
+ :type session: botocore.session.Session
+
+    :type client_kwargs: Optional[Dict[str, str]]
+ :param client_kwargs: The kwargs for the botocore
+ s3 client initialization.
+ """
+ self._session = session
+ if client_kwargs is None:
+ client_kwargs = {}
+ self._resolve_client_config(session, client_kwargs)
+ self._client = session.create_client(**client_kwargs)
+ self._client.meta.events.register(
+ 'request-created.s3.*', self._capture_http_request
+ )
+ self._client.meta.events.register(
+ 'after-call.s3.*', self._change_response_to_serialized_http_request
+ )
+ self._client.meta.events.register(
+ 'before-send.s3.*', self._make_fake_http_response
+ )
+
+ def _resolve_client_config(self, session, client_kwargs):
+ user_provided_config = None
+ if session.get_default_client_config():
+ user_provided_config = session.get_default_client_config()
+ if 'config' in client_kwargs:
+ user_provided_config = client_kwargs['config']
+
+ client_config = Config(signature_version=UNSIGNED)
+ if user_provided_config:
+ client_config = user_provided_config.merge(client_config)
+ client_kwargs['config'] = client_config
+ client_kwargs["service_name"] = "s3"
+
+ def _crt_request_from_aws_request(self, aws_request):
+ url_parts = urlsplit(aws_request.url)
+ crt_path = url_parts.path
+ if url_parts.query:
+ crt_path = f'{crt_path}?{url_parts.query}'
+ headers_list = []
+ for name, value in aws_request.headers.items():
+ if isinstance(value, str):
+ headers_list.append((name, value))
+ else:
+ headers_list.append((name, str(value, 'utf-8')))
+
+ crt_headers = awscrt.http.HttpHeaders(headers_list)
+ # CRT requires body (if it exists) to be an I/O stream.
+ crt_body_stream = None
+ if aws_request.body:
+ if hasattr(aws_request.body, 'seek'):
+ crt_body_stream = aws_request.body
+ else:
+ crt_body_stream = BytesIO(aws_request.body)
+
+ crt_request = awscrt.http.HttpRequest(
+ method=aws_request.method,
+ path=crt_path,
+ headers=crt_headers,
+ body_stream=crt_body_stream,
+ )
+ return crt_request
+
+ def _convert_to_crt_http_request(self, botocore_http_request):
+ # Logic that does CRTUtils.crt_request_from_aws_request
+ crt_request = self._crt_request_from_aws_request(botocore_http_request)
+ if crt_request.headers.get("host") is None:
+ # If host is not set, set it for the request before using CRT s3
+ url_parts = urlsplit(botocore_http_request.url)
+ crt_request.headers.set("host", url_parts.netloc)
+ if crt_request.headers.get('Content-MD5') is not None:
+ crt_request.headers.remove("Content-MD5")
+ return crt_request
+
+ def _capture_http_request(self, request, **kwargs):
+ request.context['http_request'] = request
+
+ def _change_response_to_serialized_http_request(
+ self, context, parsed, **kwargs
+ ):
+ request = context['http_request']
+ parsed['HTTPRequest'] = request.prepare()
+
+ def _make_fake_http_response(self, request, **kwargs):
+ return botocore.awsrequest.AWSResponse(
+ None,
+ 200,
+ {},
+ FakeRawResponse(b""),
+ )
+
+ def _get_botocore_http_request(self, client_method, call_args):
+ return getattr(self._client, client_method)(
+ Bucket=call_args.bucket, Key=call_args.key, **call_args.extra_args
+ )['HTTPRequest']
+
+ def serialize_http_request(self, transfer_type, future):
+ botocore_http_request = self._get_botocore_http_request(
+ transfer_type, future.meta.call_args
+ )
+ crt_request = self._convert_to_crt_http_request(botocore_http_request)
+ return crt_request
+
+
+class FakeRawResponse(BytesIO):
+ def stream(self, amt=1024, decode_content=None):
+ while True:
+ chunk = self.read(amt)
+ if not chunk:
+ break
+ yield chunk
+
+
+class CRTTransferCoordinator:
+ """A helper class for managing CRTTransferFuture"""
+
+ def __init__(self, transfer_id=None, s3_request=None):
+ self.transfer_id = transfer_id
+ self._s3_request = s3_request
+ self._lock = threading.Lock()
+ self._exception = None
+ self._crt_future = None
+ self._done_event = threading.Event()
+
+ @property
+ def s3_request(self):
+ return self._s3_request
+
+ def set_done_callbacks_complete(self):
+ self._done_event.set()
+
+ def wait_until_on_done_callbacks_complete(self, timeout=None):
+ self._done_event.wait(timeout)
+
+ def set_exception(self, exception, override=False):
+ with self._lock:
+ if not self.done() or override:
+ self._exception = exception
+
+ def cancel(self):
+ if self._s3_request:
+ self._s3_request.cancel()
+
+ def result(self, timeout=None):
+ if self._exception:
+ raise self._exception
+ try:
+ self._crt_future.result(timeout)
+ except KeyboardInterrupt:
+ self.cancel()
+ raise
+ finally:
+ if self._s3_request:
+ self._s3_request = None
+ self._crt_future.result(timeout)
+
+ def done(self):
+ if self._crt_future is None:
+ return False
+ return self._crt_future.done()
+
+ def set_s3_request(self, s3_request):
+ self._s3_request = s3_request
+ self._crt_future = self._s3_request.finished_future
+
+
+class S3ClientArgsCreator:
+ def __init__(self, crt_request_serializer, os_utils):
+ self._request_serializer = crt_request_serializer
+ self._os_utils = os_utils
+
+ def get_make_request_args(
+ self, request_type, call_args, coordinator, future, on_done_after_calls
+ ):
+ recv_filepath = None
+ send_filepath = None
+ s3_meta_request_type = getattr(
+ S3RequestType, request_type.upper(), S3RequestType.DEFAULT
+ )
+ on_done_before_calls = []
+ if s3_meta_request_type == S3RequestType.GET_OBJECT:
+ final_filepath = call_args.fileobj
+ recv_filepath = self._os_utils.get_temp_filename(final_filepath)
+ file_ondone_call = RenameTempFileHandler(
+ coordinator, final_filepath, recv_filepath, self._os_utils
+ )
+ on_done_before_calls.append(file_ondone_call)
+ elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
+ send_filepath = call_args.fileobj
+ data_len = self._os_utils.get_file_size(send_filepath)
+ call_args.extra_args["ContentLength"] = data_len
+
+ crt_request = self._request_serializer.serialize_http_request(
+ request_type, future
+ )
+
+ return {
+ 'request': crt_request,
+ 'type': s3_meta_request_type,
+ 'recv_filepath': recv_filepath,
+ 'send_filepath': send_filepath,
+ 'on_done': self.get_crt_callback(
+ future, 'done', on_done_before_calls, on_done_after_calls
+ ),
+ 'on_progress': self.get_crt_callback(future, 'progress'),
+ }
+
+ def get_crt_callback(
+ self,
+ future,
+ callback_type,
+ before_subscribers=None,
+ after_subscribers=None,
+ ):
+ def invoke_all_callbacks(*args, **kwargs):
+ callbacks_list = []
+ if before_subscribers is not None:
+ callbacks_list += before_subscribers
+ callbacks_list += get_callbacks(future, callback_type)
+ if after_subscribers is not None:
+ callbacks_list += after_subscribers
+ for callback in callbacks_list:
+                # The get_callbacks helper will set the first argument
+                # by keyword, the other arguments need to be set by keyword
+                # as well
+ if callback_type == "progress":
+ callback(bytes_transferred=args[0])
+ else:
+ callback(*args, **kwargs)
+
+ return invoke_all_callbacks
+
+
+class RenameTempFileHandler:
+ def __init__(self, coordinator, final_filename, temp_filename, osutil):
+ self._coordinator = coordinator
+ self._final_filename = final_filename
+ self._temp_filename = temp_filename
+ self._osutil = osutil
+
+ def __call__(self, **kwargs):
+ error = kwargs['error']
+ if error:
+ self._osutil.remove_file(self._temp_filename)
+ else:
+ try:
+ self._osutil.rename_file(
+ self._temp_filename, self._final_filename
+ )
+ except Exception as e:
+ self._osutil.remove_file(self._temp_filename)
+                    # the CRT future has already completed at this point
+ self._coordinator.set_exception(e)
+
+
+class AfterDoneHandler:
+ def __init__(self, coordinator):
+ self._coordinator = coordinator
+
+ def __call__(self, **kwargs):
+ self._coordinator.set_done_callbacks_complete()
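The new s3transfer/crt.py above wires three pieces together: create_s3_crt_client builds the awscrt S3Client, BotocoreCRTRequestSerializer reuses botocore's serialization to produce the unsigned HTTP request, and CRTTransferManager coordinates futures and callbacks on top of both. A minimal usage sketch follows; it is not part of this commit, it assumes the awscrt dependency is installed, and the region, bucket, key, and file names are placeholders.

import botocore.session

from s3transfer.crt import (
    BotocoreCRTRequestSerializer,
    CRTTransferManager,
    create_s3_crt_client,
)

session = botocore.session.Session()
region = 'us-east-1'  # placeholder region
crt_client = create_s3_crt_client(
    region,
    botocore_credential_provider=session.get_component('credential_provider'),
)
serializer = BotocoreCRTRequestSerializer(session, {'region_name': region})

# The context manager calls shutdown() on exit and waits for on-done callbacks.
with CRTTransferManager(crt_client, serializer) as manager:
    future = manager.upload('local-file.bin', 'example-bucket', 'example-key')
    future.result()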
diff --git a/contrib/python/s3transfer/py3/s3transfer/delete.py b/contrib/python/s3transfer/py3/s3transfer/delete.py
index e179d45d4c..74084d312a 100644
--- a/contrib/python/s3transfer/py3/s3transfer/delete.py
+++ b/contrib/python/s3transfer/py3/s3transfer/delete.py
@@ -10,7 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
-from s3transfer.tasks import SubmissionTask, Task
+from s3transfer.tasks import SubmissionTask, Task
class DeleteSubmissionTask(SubmissionTask):
@@ -47,8 +47,8 @@ class DeleteSubmissionTask(SubmissionTask):
'key': call_args.key,
'extra_args': call_args.extra_args,
},
- is_final=True,
- ),
+ is_final=True,
+ ),
)
diff --git a/contrib/python/s3transfer/py3/s3transfer/download.py b/contrib/python/s3transfer/py3/s3transfer/download.py
index 911912cab5..dc8980d4ed 100644
--- a/contrib/python/s3transfer/py3/s3transfer/download.py
+++ b/contrib/python/s3transfer/py3/s3transfer/download.py
@@ -10,30 +10,30 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
-import heapq
+import heapq
import logging
import threading
from s3transfer.compat import seekable
from s3transfer.exceptions import RetriesExceededError
from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG
-from s3transfer.tasks import SubmissionTask, Task
-from s3transfer.utils import (
- S3_RETRYABLE_DOWNLOAD_ERRORS,
- CountCallbackInvoker,
- DeferredOpenFile,
- FunctionContainer,
- StreamReaderProgress,
- calculate_num_parts,
- calculate_range_parameter,
- get_callbacks,
- invoke_progress_callbacks,
-)
+from s3transfer.tasks import SubmissionTask, Task
+from s3transfer.utils import (
+ S3_RETRYABLE_DOWNLOAD_ERRORS,
+ CountCallbackInvoker,
+ DeferredOpenFile,
+ FunctionContainer,
+ StreamReaderProgress,
+ calculate_num_parts,
+ calculate_range_parameter,
+ get_callbacks,
+ invoke_progress_callbacks,
+)
logger = logging.getLogger(__name__)
-class DownloadOutputManager:
+class DownloadOutputManager:
"""Base manager class for handling various types of files for downloads
This class is typically used for the DownloadSubmissionTask class to help
@@ -46,7 +46,7 @@ class DownloadOutputManager:
that may be accepted. All implementations must subclass and override
public methods from this class.
"""
-
+
def __init__(self, osutil, transfer_coordinator, io_executor):
self._osutil = osutil
self._transfer_coordinator = transfer_coordinator
@@ -94,8 +94,8 @@ class DownloadOutputManager:
"""
self._transfer_coordinator.submit(
- self._io_executor, self.get_io_write_task(fileobj, data, offset)
- )
+ self._io_executor, self.get_io_write_task(fileobj, data, offset)
+ )
def get_io_write_task(self, fileobj, data, offset):
"""Get an IO write task for the requested set of data
@@ -120,7 +120,7 @@ class DownloadOutputManager:
'fileobj': fileobj,
'data': data,
'offset': offset,
- },
+ },
)
def get_final_io_task(self):
@@ -134,12 +134,12 @@ class DownloadOutputManager:
:rtype: s3transfer.tasks.Task
:returns: A final task to completed in the io executor
"""
- raise NotImplementedError('must implement get_final_io_task()')
+ raise NotImplementedError('must implement get_final_io_task()')
def _get_fileobj_from_filename(self, filename):
f = DeferredOpenFile(
- filename, mode='wb', open_function=self._osutil.open
- )
+ filename, mode='wb', open_function=self._osutil.open
+ )
# Make sure the file gets closed and we remove the temporary file
# if anything goes wrong during the process.
self._transfer_coordinator.add_failure_cleanup(f.close)
@@ -148,19 +148,19 @@ class DownloadOutputManager:
class DownloadFilenameOutputManager(DownloadOutputManager):
def __init__(self, osutil, transfer_coordinator, io_executor):
- super().__init__(osutil, transfer_coordinator, io_executor)
+ super().__init__(osutil, transfer_coordinator, io_executor)
self._final_filename = None
self._temp_filename = None
self._temp_fileobj = None
@classmethod
def is_compatible(cls, download_target, osutil):
- return isinstance(download_target, str)
+ return isinstance(download_target, str)
def get_fileobj_for_io_writes(self, transfer_future):
fileobj = transfer_future.meta.call_args.fileobj
self._final_filename = fileobj
- self._temp_filename = self._osutil.get_temp_filename(fileobj)
+ self._temp_filename = self._osutil.get_temp_filename(fileobj)
self._temp_fileobj = self._get_temp_fileobj()
return self._temp_fileobj
@@ -173,16 +173,16 @@ class DownloadFilenameOutputManager(DownloadOutputManager):
main_kwargs={
'fileobj': self._temp_fileobj,
'final_filename': self._final_filename,
- 'osutil': self._osutil,
+ 'osutil': self._osutil,
},
- is_final=True,
+ is_final=True,
)
def _get_temp_fileobj(self):
f = self._get_fileobj_from_filename(self._temp_filename)
self._transfer_coordinator.add_failure_cleanup(
- self._osutil.remove_file, self._temp_filename
- )
+ self._osutil.remove_file, self._temp_filename
+ )
return f
@@ -199,15 +199,15 @@ class DownloadSeekableOutputManager(DownloadOutputManager):
# This task will serve the purpose of signaling when all of the io
# writes have finished so done callbacks can be called.
return CompleteDownloadNOOPTask(
- transfer_coordinator=self._transfer_coordinator
- )
+ transfer_coordinator=self._transfer_coordinator
+ )
class DownloadNonSeekableOutputManager(DownloadOutputManager):
- def __init__(
- self, osutil, transfer_coordinator, io_executor, defer_queue=None
- ):
- super().__init__(osutil, transfer_coordinator, io_executor)
+ def __init__(
+ self, osutil, transfer_coordinator, io_executor, defer_queue=None
+ ):
+ super().__init__(osutil, transfer_coordinator, io_executor)
if defer_queue is None:
defer_queue = DeferQueue()
self._defer_queue = defer_queue
@@ -225,20 +225,20 @@ class DownloadNonSeekableOutputManager(DownloadOutputManager):
def get_final_io_task(self):
return CompleteDownloadNOOPTask(
- transfer_coordinator=self._transfer_coordinator
- )
+ transfer_coordinator=self._transfer_coordinator
+ )
def queue_file_io_task(self, fileobj, data, offset):
with self._io_submit_lock:
writes = self._defer_queue.request_writes(offset, data)
for write in writes:
data = write['data']
- logger.debug(
- "Queueing IO offset %s for fileobj: %s",
- write['offset'],
- fileobj,
- )
- super().queue_file_io_task(fileobj, data, offset)
+ logger.debug(
+ "Queueing IO offset %s for fileobj: %s",
+ write['offset'],
+ fileobj,
+ )
+ super().queue_file_io_task(fileobj, data, offset)
def get_io_write_task(self, fileobj, data, offset):
return IOStreamingWriteTask(
@@ -246,24 +246,24 @@ class DownloadNonSeekableOutputManager(DownloadOutputManager):
main_kwargs={
'fileobj': fileobj,
'data': data,
- },
+ },
)
class DownloadSpecialFilenameOutputManager(DownloadNonSeekableOutputManager):
- def __init__(
- self, osutil, transfer_coordinator, io_executor, defer_queue=None
- ):
- super().__init__(
- osutil, transfer_coordinator, io_executor, defer_queue
- )
+ def __init__(
+ self, osutil, transfer_coordinator, io_executor, defer_queue=None
+ ):
+ super().__init__(
+ osutil, transfer_coordinator, io_executor, defer_queue
+ )
self._fileobj = None
@classmethod
def is_compatible(cls, download_target, osutil):
- return isinstance(download_target, str) and osutil.is_special_file(
- download_target
- )
+ return isinstance(download_target, str) and osutil.is_special_file(
+ download_target
+ )
def get_fileobj_for_io_writes(self, transfer_future):
filename = transfer_future.meta.call_args.fileobj
@@ -275,8 +275,8 @@ class DownloadSpecialFilenameOutputManager(DownloadNonSeekableOutputManager):
return IOCloseTask(
transfer_coordinator=self._transfer_coordinator,
is_final=True,
- main_kwargs={'fileobj': self._fileobj},
- )
+ main_kwargs={'fileobj': self._fileobj},
+ )
class DownloadSubmissionTask(SubmissionTask):
@@ -307,21 +307,21 @@ class DownloadSubmissionTask(SubmissionTask):
if download_manager_cls.is_compatible(fileobj, osutil):
return download_manager_cls
raise RuntimeError(
- 'Output {} of type: {} is not supported.'.format(
- fileobj, type(fileobj)
- )
- )
-
- def _submit(
- self,
- client,
- config,
- osutil,
- request_executor,
- io_executor,
- transfer_future,
- bandwidth_limiter=None,
- ):
+ 'Output {} of type: {} is not supported.'.format(
+ fileobj, type(fileobj)
+ )
+ )
+
+ def _submit(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ transfer_future,
+ bandwidth_limiter=None,
+ ):
"""
:param client: The client associated with the transfer manager
@@ -354,59 +354,59 @@ class DownloadSubmissionTask(SubmissionTask):
response = client.head_object(
Bucket=transfer_future.meta.call_args.bucket,
Key=transfer_future.meta.call_args.key,
- **transfer_future.meta.call_args.extra_args,
+ **transfer_future.meta.call_args.extra_args,
)
transfer_future.meta.provide_transfer_size(
- response['ContentLength']
- )
+ response['ContentLength']
+ )
download_output_manager = self._get_download_output_manager_cls(
- transfer_future, osutil
- )(osutil, self._transfer_coordinator, io_executor)
+ transfer_future, osutil
+ )(osutil, self._transfer_coordinator, io_executor)
# If it is greater than threshold do a ranged download, otherwise
# do a regular GetObject download.
if transfer_future.meta.size < config.multipart_threshold:
self._submit_download_request(
- client,
- config,
- osutil,
- request_executor,
- io_executor,
- download_output_manager,
- transfer_future,
- bandwidth_limiter,
- )
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ download_output_manager,
+ transfer_future,
+ bandwidth_limiter,
+ )
else:
self._submit_ranged_download_request(
- client,
- config,
- osutil,
- request_executor,
- io_executor,
- download_output_manager,
- transfer_future,
- bandwidth_limiter,
- )
-
- def _submit_download_request(
- self,
- client,
- config,
- osutil,
- request_executor,
- io_executor,
- download_output_manager,
- transfer_future,
- bandwidth_limiter,
- ):
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ download_output_manager,
+ transfer_future,
+ bandwidth_limiter,
+ )
+
+ def _submit_download_request(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ download_output_manager,
+ transfer_future,
+ bandwidth_limiter,
+ ):
call_args = transfer_future.meta.call_args
# Get a handle to the file that will be used for writing downloaded
# contents
fileobj = download_output_manager.get_fileobj_for_io_writes(
- transfer_future
- )
+ transfer_future
+ )
# Get the needed callbacks for the task
progress_callbacks = get_callbacks(transfer_future, 'progress')
@@ -432,24 +432,24 @@ class DownloadSubmissionTask(SubmissionTask):
'max_attempts': config.num_download_attempts,
'download_output_manager': download_output_manager,
'io_chunksize': config.io_chunksize,
- 'bandwidth_limiter': bandwidth_limiter,
+ 'bandwidth_limiter': bandwidth_limiter,
},
- done_callbacks=[final_task],
+ done_callbacks=[final_task],
),
- tag=get_object_tag,
+ tag=get_object_tag,
)
- def _submit_ranged_download_request(
- self,
- client,
- config,
- osutil,
- request_executor,
- io_executor,
- download_output_manager,
- transfer_future,
- bandwidth_limiter,
- ):
+ def _submit_ranged_download_request(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ io_executor,
+ download_output_manager,
+ transfer_future,
+ bandwidth_limiter,
+ ):
call_args = transfer_future.meta.call_args
# Get the needed progress callbacks for the task
@@ -458,12 +458,12 @@ class DownloadSubmissionTask(SubmissionTask):
# Get a handle to the file that will be used for writing downloaded
# contents
fileobj = download_output_manager.get_fileobj_for_io_writes(
- transfer_future
- )
+ transfer_future
+ )
# Determine the number of parts
part_size = config.multipart_chunksize
- num_parts = calculate_num_parts(transfer_future.meta.size, part_size)
+ num_parts = calculate_num_parts(transfer_future.meta.size, part_size)
# Get any associated tags for the get object task.
get_object_tag = download_output_manager.get_download_task_tag()
@@ -478,8 +478,8 @@ class DownloadSubmissionTask(SubmissionTask):
for i in range(num_parts):
# Calculate the range parameter
range_parameter = calculate_range_parameter(
- part_size, i, num_parts
- )
+ part_size, i, num_parts
+ )
# Inject the Range parameter to the parameters to be passed in
# as extra args
@@ -502,21 +502,21 @@ class DownloadSubmissionTask(SubmissionTask):
'start_index': i * part_size,
'download_output_manager': download_output_manager,
'io_chunksize': config.io_chunksize,
- 'bandwidth_limiter': bandwidth_limiter,
+ 'bandwidth_limiter': bandwidth_limiter,
},
- done_callbacks=[finalize_download_invoker.decrement],
+ done_callbacks=[finalize_download_invoker.decrement],
),
- tag=get_object_tag,
+ tag=get_object_tag,
)
finalize_download_invoker.finalize()
- def _get_final_io_task_submission_callback(
- self, download_manager, io_executor
- ):
+ def _get_final_io_task_submission_callback(
+ self, download_manager, io_executor
+ ):
final_task = download_manager.get_final_io_task()
return FunctionContainer(
- self._transfer_coordinator.submit, io_executor, final_task
- )
+ self._transfer_coordinator.submit, io_executor, final_task
+ )
def _calculate_range_param(self, part_size, part_index, num_parts):
# Used to calculate the Range parameter
@@ -525,32 +525,32 @@ class DownloadSubmissionTask(SubmissionTask):
end_range = ''
else:
end_range = start_range + part_size - 1
- range_param = f'bytes={start_range}-{end_range}'
+ range_param = f'bytes={start_range}-{end_range}'
return range_param
class GetObjectTask(Task):
- def _main(
- self,
- client,
- bucket,
- key,
- fileobj,
- extra_args,
- callbacks,
- max_attempts,
- download_output_manager,
- io_chunksize,
- start_index=0,
- bandwidth_limiter=None,
- ):
+ def _main(
+ self,
+ client,
+ bucket,
+ key,
+ fileobj,
+ extra_args,
+ callbacks,
+ max_attempts,
+ download_output_manager,
+ io_chunksize,
+ start_index=0,
+ bandwidth_limiter=None,
+ ):
"""Downloads an object and places content into io queue
:param client: The client to use when calling GetObject
:param bucket: The bucket to download from
:param key: The key to download from
:param fileobj: The file handle to write content to
-        :param extra_args: Any extra arguments to include in GetObject request
+        :param extra_args: Any extra arguments to include in GetObject request
:param callbacks: List of progress callbacks to invoke on download
:param max_attempts: The number of retries to do when downloading
:param download_output_manager: The download output manager associated
@@ -565,19 +565,19 @@ class GetObjectTask(Task):
last_exception = None
for i in range(max_attempts):
try:
- current_index = start_index
+ current_index = start_index
response = client.get_object(
- Bucket=bucket, Key=key, **extra_args
- )
+ Bucket=bucket, Key=key, **extra_args
+ )
streaming_body = StreamReaderProgress(
- response['Body'], callbacks
- )
+ response['Body'], callbacks
+ )
if bandwidth_limiter:
- streaming_body = (
+ streaming_body = (
bandwidth_limiter.get_bandwith_limited_stream(
- streaming_body, self._transfer_coordinator
- )
- )
+ streaming_body, self._transfer_coordinator
+ )
+ )
chunks = DownloadChunkIterator(streaming_body, io_chunksize)
for chunk in chunks:
@@ -586,31 +586,31 @@ class GetObjectTask(Task):
# data to be written and break out of the download.
if not self._transfer_coordinator.done():
self._handle_io(
- download_output_manager,
- fileobj,
- chunk,
- current_index,
+ download_output_manager,
+ fileobj,
+ chunk,
+ current_index,
)
current_index += len(chunk)
else:
return
return
- except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
- logger.debug(
- "Retrying exception caught (%s), "
- "retrying request, (attempt %s / %s)",
- e,
- i,
- max_attempts,
- exc_info=True,
- )
+ except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
+ logger.debug(
+ "Retrying exception caught (%s), "
+ "retrying request, (attempt %s / %s)",
+ e,
+ i,
+ max_attempts,
+ exc_info=True,
+ )
last_exception = e
# Also invoke the progress callbacks to indicate that we
# are trying to download the stream again and all progress
# for this GetObject has been lost.
invoke_progress_callbacks(
- callbacks, start_index - current_index
- )
+ callbacks, start_index - current_index
+ )
continue
raise RetriesExceededError(last_exception)
@@ -625,7 +625,7 @@ class ImmediatelyWriteIOGetObjectTask(GetObjectTask):
downloading the object so there is no reason to go through the
overhead of using an IO queue and executor.
"""
-
+
def _handle_io(self, download_output_manager, fileobj, chunk, index):
task = download_output_manager.get_io_write_task(fileobj, chunk, index)
task()
@@ -649,7 +649,7 @@ class IOStreamingWriteTask(Task):
def _main(self, fileobj, data):
"""Write data to a fileobj.
- Data will be written directly to the fileobj without
+ Data will be written directly to the fileobj without
any prior seeking.
:param fileobj: The fileobj to write content to
@@ -667,7 +667,7 @@ class IORenameFileTask(Task):
upon completion of writing the contents.
:param osutil: OS utility
"""
-
+
def _main(self, fileobj, final_filename, osutil):
fileobj.close()
osutil.rename_file(fileobj.name, final_filename)
@@ -678,7 +678,7 @@ class IOCloseTask(Task):
:param fileobj: The fileobj to close.
"""
-
+
def _main(self, fileobj):
fileobj.close()
@@ -689,28 +689,28 @@ class CompleteDownloadNOOPTask(Task):
Note that the default for is_final is set to True because this should
always be the last task.
"""
-
- def __init__(
- self,
- transfer_coordinator,
- main_kwargs=None,
- pending_main_kwargs=None,
- done_callbacks=None,
- is_final=True,
- ):
- super().__init__(
+
+ def __init__(
+ self,
+ transfer_coordinator,
+ main_kwargs=None,
+ pending_main_kwargs=None,
+ done_callbacks=None,
+ is_final=True,
+ ):
+ super().__init__(
transfer_coordinator=transfer_coordinator,
main_kwargs=main_kwargs,
pending_main_kwargs=pending_main_kwargs,
done_callbacks=done_callbacks,
- is_final=is_final,
+ is_final=is_final,
)
def _main(self):
pass
-class DownloadChunkIterator:
+class DownloadChunkIterator:
def __init__(self, body, chunksize):
"""Iterator to chunk out a downloaded S3 stream
@@ -732,7 +732,7 @@ class DownloadChunkIterator:
elif self._num_reads == 1:
# Even though the response may have not had any
# content, we still want to account for an empty object's
- # existence so return the empty chunk for that initial
+ # existence so return the empty chunk for that initial
# read.
return chunk
raise StopIteration()
@@ -740,7 +740,7 @@ class DownloadChunkIterator:
next = __next__
-class DeferQueue:
+class DeferQueue:
"""IO queue that defers write requests until they are queued sequentially.
This class is used to track IO data for a *single* fileobj.
@@ -749,7 +749,7 @@ class DeferQueue:
until it has the next contiguous block available (starting at 0).
"""
-
+
def __init__(self):
self._writes = []
self._pending_offsets = set()
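Aside from re-indentation, the download.py hunks above preserve the ranged-download flow: when an object exceeds multipart_threshold, the submission task derives num_parts from the part size and injects a Range parameter per part, leaving the final range open-ended. A small sketch of that arithmetic, illustrative only and not from the patch:

part_size = 8 * 1024 * 1024          # 8 MB, the default multipart_chunksize
size = 20 * 1024 * 1024              # example object size
num_parts = -(-size // part_size)    # ceiling division, as calculate_num_parts does

for i in range(num_parts):
    start = i * part_size
    # The last part leaves the end of the range empty, mirroring
    # calculate_range_parameter / _calculate_range_param shown above.
    end = '' if i == num_parts - 1 else start + part_size - 1
    print(f'bytes={start}-{end}')
# bytes=0-8388607
# bytes=8388608-16777215
# bytes=16777216-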
diff --git a/contrib/python/s3transfer/py3/s3transfer/exceptions.py b/contrib/python/s3transfer/py3/s3transfer/exceptions.py
index 2b4ebea66c..6150fe650d 100644
--- a/contrib/python/s3transfer/py3/s3transfer/exceptions.py
+++ b/contrib/python/s3transfer/py3/s3transfer/exceptions.py
@@ -15,7 +15,7 @@ from concurrent.futures import CancelledError
class RetriesExceededError(Exception):
def __init__(self, last_exception, msg='Max Retries Exceeded'):
- super().__init__(msg)
+ super().__init__(msg)
self.last_exception = last_exception
@@ -33,5 +33,5 @@ class TransferNotDoneError(Exception):
class FatalError(CancelledError):
"""A CancelledError raised from an error in the TransferManager"""
-
+
pass
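These exception hunks are whitespace-only, but they cover the types the download path above surfaces: GetObjectTask wraps the last streaming error in RetriesExceededError once its retry budget is exhausted, and the transfer future re-raises it. A hedged sketch of handling it; the helper name is illustrative:

from s3transfer.exceptions import RetriesExceededError


def resolve_download(future):
    """Resolve a transfer future, reporting retry exhaustion."""
    try:
        return future.result()
    except RetriesExceededError as e:
        # last_exception holds the final socket/read error seen by GetObjectTask.
        print('download failed after retries:', e.last_exception)
        raise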
diff --git a/contrib/python/s3transfer/py3/s3transfer/futures.py b/contrib/python/s3transfer/py3/s3transfer/futures.py
index 5fe4a7c624..39e071fb60 100644
--- a/contrib/python/s3transfer/py3/s3transfer/futures.py
+++ b/contrib/python/s3transfer/py3/s3transfer/futures.py
@@ -14,61 +14,61 @@ import copy
import logging
import sys
import threading
-from collections import namedtuple
-from concurrent import futures
+from collections import namedtuple
+from concurrent import futures
from s3transfer.compat import MAXINT
from s3transfer.exceptions import CancelledError, TransferNotDoneError
-from s3transfer.utils import FunctionContainer, TaskSemaphore
+from s3transfer.utils import FunctionContainer, TaskSemaphore
logger = logging.getLogger(__name__)
-class BaseTransferFuture:
- @property
- def meta(self):
- """The metadata associated to the TransferFuture"""
- raise NotImplementedError('meta')
-
- def done(self):
- """Determines if a TransferFuture has completed
-
- :returns: True if completed. False, otherwise.
- """
- raise NotImplementedError('done()')
-
- def result(self):
- """Waits until TransferFuture is done and returns the result
-
- If the TransferFuture succeeded, it will return the result. If the
- TransferFuture failed, it will raise the exception associated to the
- failure.
- """
- raise NotImplementedError('result()')
-
- def cancel(self):
- """Cancels the request associated with the TransferFuture"""
- raise NotImplementedError('cancel()')
-
-
-class BaseTransferMeta:
- @property
- def call_args(self):
- """The call args used in the transfer request"""
- raise NotImplementedError('call_args')
-
- @property
- def transfer_id(self):
- """The unique id of the transfer"""
- raise NotImplementedError('transfer_id')
-
- @property
- def user_context(self):
- """A dictionary that requesters can store data in"""
- raise NotImplementedError('user_context')
-
-
-class TransferFuture(BaseTransferFuture):
+class BaseTransferFuture:
+ @property
+ def meta(self):
+ """The metadata associated to the TransferFuture"""
+ raise NotImplementedError('meta')
+
+ def done(self):
+ """Determines if a TransferFuture has completed
+
+ :returns: True if completed. False, otherwise.
+ """
+ raise NotImplementedError('done()')
+
+ def result(self):
+ """Waits until TransferFuture is done and returns the result
+
+ If the TransferFuture succeeded, it will return the result. If the
+ TransferFuture failed, it will raise the exception associated to the
+ failure.
+ """
+ raise NotImplementedError('result()')
+
+ def cancel(self):
+ """Cancels the request associated with the TransferFuture"""
+ raise NotImplementedError('cancel()')
+
+
+class BaseTransferMeta:
+ @property
+ def call_args(self):
+ """The call args used in the transfer request"""
+ raise NotImplementedError('call_args')
+
+ @property
+ def transfer_id(self):
+ """The unique id of the transfer"""
+ raise NotImplementedError('transfer_id')
+
+ @property
+ def user_context(self):
+ """A dictionary that requesters can store data in"""
+ raise NotImplementedError('user_context')
+
+
+class TransferFuture(BaseTransferFuture):
def __init__(self, meta=None, coordinator=None):
"""The future associated to a submitted transfer request
@@ -99,7 +99,7 @@ class TransferFuture(BaseTransferFuture):
try:
# Usually the result() method blocks until the transfer is done,
# however if a KeyboardInterrupt is raised we want to exit
- # out of this and propagate the exception.
+ # out of this and propagate the exception.
return self._coordinator.result()
except KeyboardInterrupt as e:
self.cancel()
@@ -113,14 +113,14 @@ class TransferFuture(BaseTransferFuture):
if not self.done():
raise TransferNotDoneError(
'set_exception can only be called once the transfer is '
- 'complete.'
- )
+ 'complete.'
+ )
self._coordinator.set_exception(exception, override=True)
-class TransferMeta(BaseTransferMeta):
+class TransferMeta(BaseTransferMeta):
"""Holds metadata about the TransferFuture"""
-
+
def __init__(self, call_args=None, transfer_id=None):
self._call_args = call_args
self._transfer_id = transfer_id
@@ -157,9 +157,9 @@ class TransferMeta(BaseTransferMeta):
self._size = size
-class TransferCoordinator:
+class TransferCoordinator:
"""A helper class for managing TransferFuture"""
-
+
def __init__(self, transfer_id=None):
self.transfer_id = transfer_id
self._status = 'not-started'
@@ -175,9 +175,9 @@ class TransferCoordinator:
self._failure_cleanups_lock = threading.Lock()
def __repr__(self):
- return '{}(transfer_id={})'.format(
- self.__class__.__name__, self.transfer_id
- )
+ return '{}(transfer_id={})'.format(
+ self.__class__.__name__, self.transfer_id
+ )
@property
def exception(self):
@@ -296,8 +296,8 @@ class TransferCoordinator:
if self.done():
raise RuntimeError(
'Unable to transition from done state %s to non-done '
- 'state %s.' % (self.status, desired_state)
- )
+ 'state %s.' % (self.status, desired_state)
+ )
self._status = desired_state
def submit(self, executor, task, tag=None):
@@ -316,17 +316,17 @@ class TransferCoordinator:
:returns: A future representing the submitted task
"""
logger.debug(
- "Submitting task {} to executor {} for transfer request: {}.".format(
- task, executor, self.transfer_id
- )
+ "Submitting task {} to executor {} for transfer request: {}.".format(
+ task, executor, self.transfer_id
+ )
)
future = executor.submit(task, tag=tag)
# Add this created future to the list of associated future just
# in case it is needed during cleanups.
self.add_associated_future(future)
future.add_done_callback(
- FunctionContainer(self.remove_associated_future, future)
- )
+ FunctionContainer(self.remove_associated_future, future)
+ )
return future
def done(self):
@@ -358,8 +358,8 @@ class TransferCoordinator:
"""Adds a callback to call upon failure"""
with self._failure_cleanups_lock:
self._failure_cleanups.append(
- FunctionContainer(function, *args, **kwargs)
- )
+ FunctionContainer(function, *args, **kwargs)
+ )
def announce_done(self):
"""Announce that future is done running and run associated callbacks
@@ -398,18 +398,18 @@ class TransferCoordinator:
try:
callback()
# We do not want a callback interrupting the process, especially
- # in the failure cleanups. So log and catch, the exception.
+ # in the failure cleanups. So log and catch, the exception.
except Exception:
logger.debug("Exception raised in %s." % callback, exc_info=True)
-class BoundedExecutor:
+class BoundedExecutor:
EXECUTOR_CLS = futures.ThreadPoolExecutor
- def __init__(
- self, max_size, max_num_threads, tag_semaphores=None, executor_cls=None
- ):
- """An executor implementation that has a maximum queued up tasks
+ def __init__(
+ self, max_size, max_num_threads, tag_semaphores=None, executor_cls=None
+ ):
+ """An executor implementation that has a maximum queued up tasks
The executor will block if the number of tasks that have been
submitted and is currently working on is past its maximum.
@@ -455,7 +455,7 @@ class BoundedExecutor:
False, if not to wait and raise an error if not able to submit
a task.
- :returns: The future associated to the submitted task
+ :returns: The future associated to the submitted task
"""
semaphore = self._semaphore
# If a tag was provided, use the semaphore associated to that
@@ -468,8 +468,8 @@ class BoundedExecutor:
# Create a callback to invoke when task is done in order to call
# release on the semaphore.
release_callback = FunctionContainer(
- semaphore.release, task.transfer_id, acquire_token
- )
+ semaphore.release, task.transfer_id, acquire_token
+ )
# Submit the task to the underlying executor.
future = ExecutorFuture(self._executor.submit(task))
# Add the Semaphore.release() callback to the future such that
@@ -481,7 +481,7 @@ class BoundedExecutor:
self._executor.shutdown(wait)
-class ExecutorFuture:
+class ExecutorFuture:
def __init__(self, future):
"""A future returned from the executor
@@ -501,7 +501,7 @@ class ExecutorFuture:
def add_done_callback(self, fn):
"""Adds a callback to be completed once future is done
- :param fn: A callable that takes no arguments. Note that is different
+ :param fn: A callable that takes no arguments. Note that is different
than concurrent.futures.Future.add_done_callback that requires
a single argument for the future.
"""
@@ -510,16 +510,16 @@ class ExecutorFuture:
# proper signature wrapper that will invoke the callback provided.
def done_callback(future_passed_to_callback):
return fn()
-
+
self._future.add_done_callback(done_callback)
def done(self):
return self._future.done()
-class BaseExecutor:
+class BaseExecutor:
"""Base Executor class implementation needed to work with s3transfer"""
-
+
def __init__(self, max_workers=None):
pass
@@ -532,7 +532,7 @@ class BaseExecutor:
class NonThreadedExecutor(BaseExecutor):
"""A drop-in replacement non-threaded version of ThreadPoolExecutor"""
-
+
def submit(self, fn, *args, **kwargs):
future = NonThreadedExecutorFuture()
try:
@@ -542,9 +542,9 @@ class NonThreadedExecutor(BaseExecutor):
e, tb = sys.exc_info()[1:]
logger.debug(
'Setting exception for %s to %s with traceback %s',
- future,
- e,
- tb,
+ future,
+ e,
+ tb,
)
future.set_exception_info(e, tb)
return future
@@ -553,13 +553,13 @@ class NonThreadedExecutor(BaseExecutor):
pass
-class NonThreadedExecutorFuture:
+class NonThreadedExecutorFuture:
"""The Future returned from NonThreadedExecutor
Note that this future is **not** thread-safe as it is being used
from the context of a non-threaded environment.
"""
-
+
def __init__(self):
self._result = None
self._exception = None
@@ -578,7 +578,7 @@ class NonThreadedExecutorFuture:
def result(self, timeout=None):
if self._exception:
- raise self._exception.with_traceback(self._traceback)
+ raise self._exception.with_traceback(self._traceback)
return self._result
def _set_done(self):
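The futures.py hunks mainly restore indentation on the abstract future and executor classes; the seam worth noting is executor_cls, which the manager.py diff below threads into each BoundedExecutor so that NonThreadedExecutor can replace the thread pool. A sketch of that wiring; the boto3 client and the bucket, key, and file names are placeholders, not part of the patch:

import boto3

from s3transfer.futures import NonThreadedExecutor
from s3transfer.manager import TransferManager

client = boto3.client('s3')
# Submitted tasks now run inline in the calling thread and return
# NonThreadedExecutorFuture objects instead of wrapped thread-pool futures.
manager = TransferManager(client, executor_cls=NonThreadedExecutor)
future = manager.download('example-bucket', 'example-key', 'local-file.bin')
future.result()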
diff --git a/contrib/python/s3transfer/py3/s3transfer/manager.py b/contrib/python/s3transfer/py3/s3transfer/manager.py
index 095917c426..ff6afa12c1 100644
--- a/contrib/python/s3transfer/py3/s3transfer/manager.py
+++ b/contrib/python/s3transfer/py3/s3transfer/manager.py
@@ -12,54 +12,54 @@
# language governing permissions and limitations under the License.
import copy
import logging
-import re
+import re
import threading
-from s3transfer.bandwidth import BandwidthLimiter, LeakyBucket
-from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, KB, MB
-from s3transfer.copies import CopySubmissionTask
-from s3transfer.delete import DeleteSubmissionTask
+from s3transfer.bandwidth import BandwidthLimiter, LeakyBucket
+from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, KB, MB
+from s3transfer.copies import CopySubmissionTask
+from s3transfer.delete import DeleteSubmissionTask
from s3transfer.download import DownloadSubmissionTask
-from s3transfer.exceptions import CancelledError, FatalError
-from s3transfer.futures import (
- IN_MEMORY_DOWNLOAD_TAG,
- IN_MEMORY_UPLOAD_TAG,
- BoundedExecutor,
- TransferCoordinator,
- TransferFuture,
- TransferMeta,
-)
+from s3transfer.exceptions import CancelledError, FatalError
+from s3transfer.futures import (
+ IN_MEMORY_DOWNLOAD_TAG,
+ IN_MEMORY_UPLOAD_TAG,
+ BoundedExecutor,
+ TransferCoordinator,
+ TransferFuture,
+ TransferMeta,
+)
from s3transfer.upload import UploadSubmissionTask
-from s3transfer.utils import (
- CallArgs,
- OSUtils,
- SlidingWindowSemaphore,
- TaskSemaphore,
- get_callbacks,
- signal_not_transferring,
- signal_transferring,
-)
+from s3transfer.utils import (
+ CallArgs,
+ OSUtils,
+ SlidingWindowSemaphore,
+ TaskSemaphore,
+ get_callbacks,
+ signal_not_transferring,
+ signal_transferring,
+)
logger = logging.getLogger(__name__)
-class TransferConfig:
- def __init__(
- self,
- multipart_threshold=8 * MB,
- multipart_chunksize=8 * MB,
- max_request_concurrency=10,
- max_submission_concurrency=5,
- max_request_queue_size=1000,
- max_submission_queue_size=1000,
- max_io_queue_size=1000,
- io_chunksize=256 * KB,
- num_download_attempts=5,
- max_in_memory_upload_chunks=10,
- max_in_memory_download_chunks=10,
- max_bandwidth=None,
- ):
- """Configurations for the transfer manager
+class TransferConfig:
+ def __init__(
+ self,
+ multipart_threshold=8 * MB,
+ multipart_chunksize=8 * MB,
+ max_request_concurrency=10,
+ max_submission_concurrency=5,
+ max_request_queue_size=1000,
+ max_submission_queue_size=1000,
+ max_io_queue_size=1000,
+ io_chunksize=256 * KB,
+ num_download_attempts=5,
+ max_in_memory_upload_chunks=10,
+ max_in_memory_download_chunks=10,
+ max_bandwidth=None,
+ ):
+ """Configurations for the transfer manager
:param multipart_threshold: The threshold for which multipart
transfers occur.
@@ -71,7 +71,7 @@ class TransferConfig:
processing a call to a TransferManager method. Processing a
call usually entails determining which S3 API requests that need
to be enqueued, but does **not** entail making any of the
- S3 API data transferring requests needed to perform the transfer.
+ S3 API data transferring requests needed to perform the transfer.
The threads controlled by ``max_request_concurrency`` is
responsible for that.
@@ -79,14 +79,14 @@ class TransferConfig:
becomes a multipart transfer.
:param max_request_queue_size: The maximum amount of S3 API requests
- that can be queued at a time.
+ that can be queued at a time.
:param max_submission_queue_size: The maximum amount of
- TransferManager method calls that can be queued at a time.
+ TransferManager method calls that can be queued at a time.
:param max_io_queue_size: The maximum amount of read parts that
- can be queued to be written to disk per download. The default
-        size for each element in this queue is 8 KB.
+ can be queued to be written to disk per download. The default
+        size for each element in this queue is 8 KB.
:param io_chunksize: The max size of each chunk in the io queue.
Currently, this is size used when reading from the downloaded
@@ -94,9 +94,9 @@ class TransferConfig:
:param num_download_attempts: The number of download attempts that
will be tried upon errors with downloading an object in S3. Note
- that these retries account for errors that occur when streaming
+ that these retries account for errors that occur when streaming
down the data from s3 (i.e. socket errors and read timeouts that
- occur after receiving an OK response from s3).
+ occur after receiving an OK response from s3).
Other retryable exceptions such as throttling errors and 5xx errors
are already retried by botocore (this default is 5). The
``num_download_attempts`` does not take into account the
@@ -120,7 +120,7 @@ class TransferConfig:
:param max_in_memory_download_chunks: The number of chunks that can
be buffered in memory and **not** in the io queue at a time for all
- ongoing download requests. This pertains specifically to file-like
+ ongoing download requests. This pertains specifically to file-like
objects that cannot be seeked. The total maximum memory footprint
due to in-memory download chunks is roughly equal to:
@@ -145,16 +145,16 @@ class TransferConfig:
self._validate_attrs_are_nonzero()
def _validate_attrs_are_nonzero(self):
- for attr, attr_val in self.__dict__.items():
+ for attr, attr_val in self.__dict__.items():
if attr_val is not None and attr_val <= 0:
raise ValueError(
'Provided parameter %s of value %s must be greater than '
- '0.' % (attr, attr_val)
- )
+ '0.' % (attr, attr_val)
+ )
-class TransferManager:
- ALLOWED_DOWNLOAD_ARGS = ALLOWED_DOWNLOAD_ARGS
+class TransferManager:
+ ALLOWED_DOWNLOAD_ARGS = ALLOWED_DOWNLOAD_ARGS
ALLOWED_UPLOAD_ARGS = [
'ACL',
@@ -163,7 +163,7 @@ class TransferManager:
'ContentEncoding',
'ContentLanguage',
'ContentType',
- 'ExpectedBucketOwner',
+ 'ExpectedBucketOwner',
'Expires',
'GrantFullControl',
'GrantRead',
@@ -177,9 +177,9 @@ class TransferManager:
'SSECustomerKey',
'SSECustomerKeyMD5',
'SSEKMSKeyId',
- 'SSEKMSEncryptionContext',
- 'Tagging',
- 'WebsiteRedirectLocation',
+ 'SSEKMSEncryptionContext',
+ 'Tagging',
+ 'WebsiteRedirectLocation',
]
ALLOWED_COPY_ARGS = ALLOWED_UPLOAD_ARGS + [
@@ -190,26 +190,26 @@ class TransferManager:
'CopySourceSSECustomerAlgorithm',
'CopySourceSSECustomerKey',
'CopySourceSSECustomerKeyMD5',
- 'MetadataDirective',
- 'TaggingDirective',
+ 'MetadataDirective',
+ 'TaggingDirective',
]
ALLOWED_DELETE_ARGS = [
'MFA',
'VersionId',
'RequestPayer',
- 'ExpectedBucketOwner',
+ 'ExpectedBucketOwner',
]
- VALIDATE_SUPPORTED_BUCKET_VALUES = True
-
- _UNSUPPORTED_BUCKET_PATTERNS = {
- 'S3 Object Lambda': re.compile(
- r'^arn:(aws).*:s3-object-lambda:[a-z\-0-9]+:[0-9]{12}:'
- r'accesspoint[/:][a-zA-Z0-9\-]{1,63}'
- ),
- }
-
+ VALIDATE_SUPPORTED_BUCKET_VALUES = True
+
+ _UNSUPPORTED_BUCKET_PATTERNS = {
+ 'S3 Object Lambda': re.compile(
+ r'^arn:(aws).*:s3-object-lambda:[a-z\-0-9]+:[0-9]{12}:'
+ r'accesspoint[/:][a-zA-Z0-9\-]{1,63}'
+ ),
+ }
+
def __init__(self, client, config=None, osutil=None, executor_cls=None):
"""A transfer manager interface for Amazon S3
@@ -239,13 +239,13 @@ class TransferManager:
max_num_threads=self._config.max_request_concurrency,
tag_semaphores={
IN_MEMORY_UPLOAD_TAG: TaskSemaphore(
- self._config.max_in_memory_upload_chunks
- ),
+ self._config.max_in_memory_upload_chunks
+ ),
IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore(
- self._config.max_in_memory_download_chunks
- ),
+ self._config.max_in_memory_download_chunks
+ ),
},
- executor_cls=executor_cls,
+ executor_cls=executor_cls,
)
# The executor responsible for submitting the necessary tasks to
@@ -253,7 +253,7 @@ class TransferManager:
self._submission_executor = BoundedExecutor(
max_size=self._config.max_submission_queue_size,
max_num_threads=self._config.max_submission_concurrency,
- executor_cls=executor_cls,
+ executor_cls=executor_cls,
)
# There is one thread available for writing to disk. It will handle
@@ -261,7 +261,7 @@ class TransferManager:
self._io_executor = BoundedExecutor(
max_size=self._config.max_io_queue_size,
max_num_threads=1,
- executor_cls=executor_cls,
+ executor_cls=executor_cls,
)
# The component responsible for limiting bandwidth usage if it
@@ -269,21 +269,21 @@ class TransferManager:
self._bandwidth_limiter = None
if self._config.max_bandwidth is not None:
logger.debug(
- 'Setting max_bandwidth to %s', self._config.max_bandwidth
- )
+ 'Setting max_bandwidth to %s', self._config.max_bandwidth
+ )
leaky_bucket = LeakyBucket(self._config.max_bandwidth)
self._bandwidth_limiter = BandwidthLimiter(leaky_bucket)
self._register_handlers()
- @property
- def client(self):
- return self._client
-
- @property
- def config(self):
- return self._config
-
+ @property
+ def client(self):
+ return self._client
+
+ @property
+ def config(self):
+ return self._config
+
def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None):
"""Uploads a file to S3
@@ -315,24 +315,24 @@ class TransferManager:
if subscribers is None:
subscribers = []
self._validate_all_known_args(extra_args, self.ALLOWED_UPLOAD_ARGS)
- self._validate_if_bucket_supported(bucket)
+ self._validate_if_bucket_supported(bucket)
call_args = CallArgs(
- fileobj=fileobj,
- bucket=bucket,
- key=key,
- extra_args=extra_args,
- subscribers=subscribers,
+ fileobj=fileobj,
+ bucket=bucket,
+ key=key,
+ extra_args=extra_args,
+ subscribers=subscribers,
)
extra_main_kwargs = {}
if self._bandwidth_limiter:
extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter
return self._submit_transfer(
- call_args, UploadSubmissionTask, extra_main_kwargs
- )
+ call_args, UploadSubmissionTask, extra_main_kwargs
+ )
- def download(
- self, bucket, key, fileobj, extra_args=None, subscribers=None
- ):
+ def download(
+ self, bucket, key, fileobj, extra_args=None, subscribers=None
+ ):
"""Downloads a file from S3
:type bucket: str
@@ -363,30 +363,30 @@ class TransferManager:
if subscribers is None:
subscribers = []
self._validate_all_known_args(extra_args, self.ALLOWED_DOWNLOAD_ARGS)
- self._validate_if_bucket_supported(bucket)
+ self._validate_if_bucket_supported(bucket)
call_args = CallArgs(
- bucket=bucket,
- key=key,
- fileobj=fileobj,
- extra_args=extra_args,
- subscribers=subscribers,
+ bucket=bucket,
+ key=key,
+ fileobj=fileobj,
+ extra_args=extra_args,
+ subscribers=subscribers,
)
extra_main_kwargs = {'io_executor': self._io_executor}
if self._bandwidth_limiter:
extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter
return self._submit_transfer(
- call_args, DownloadSubmissionTask, extra_main_kwargs
- )
-
- def copy(
- self,
- copy_source,
- bucket,
- key,
- extra_args=None,
- subscribers=None,
- source_client=None,
- ):
+ call_args, DownloadSubmissionTask, extra_main_kwargs
+ )
+
+ def copy(
+ self,
+ copy_source,
+ bucket,
+ key,
+ extra_args=None,
+ subscribers=None,
+ source_client=None,
+ ):
"""Copies a file in S3
:type copy_source: dict
@@ -428,16 +428,16 @@ class TransferManager:
if source_client is None:
source_client = self._client
self._validate_all_known_args(extra_args, self.ALLOWED_COPY_ARGS)
- if isinstance(copy_source, dict):
- self._validate_if_bucket_supported(copy_source.get('Bucket'))
- self._validate_if_bucket_supported(bucket)
+ if isinstance(copy_source, dict):
+ self._validate_if_bucket_supported(copy_source.get('Bucket'))
+ self._validate_if_bucket_supported(bucket)
call_args = CallArgs(
- copy_source=copy_source,
- bucket=bucket,
- key=key,
- extra_args=extra_args,
- subscribers=subscribers,
- source_client=source_client,
+ copy_source=copy_source,
+ bucket=bucket,
+ key=key,
+ extra_args=extra_args,
+ subscribers=subscribers,
+ source_client=source_client,
)
return self._submit_transfer(call_args, CopySubmissionTask)
@@ -468,46 +468,46 @@ class TransferManager:
if subscribers is None:
subscribers = []
self._validate_all_known_args(extra_args, self.ALLOWED_DELETE_ARGS)
- self._validate_if_bucket_supported(bucket)
+ self._validate_if_bucket_supported(bucket)
call_args = CallArgs(
- bucket=bucket,
- key=key,
- extra_args=extra_args,
- subscribers=subscribers,
+ bucket=bucket,
+ key=key,
+ extra_args=extra_args,
+ subscribers=subscribers,
)
return self._submit_transfer(call_args, DeleteSubmissionTask)
- def _validate_if_bucket_supported(self, bucket):
-        # S3 high-level operations don't support some resources
-        # (e.g. S3 Object Lambda); only direct API calls are available
-        # for such resources
- if self.VALIDATE_SUPPORTED_BUCKET_VALUES:
- for resource, pattern in self._UNSUPPORTED_BUCKET_PATTERNS.items():
- match = pattern.match(bucket)
- if match:
- raise ValueError(
- 'TransferManager methods do not support %s '
- 'resource. Use direct client calls instead.' % resource
- )
-
+ def _validate_if_bucket_supported(self, bucket):
+        # S3 high-level operations don't support some resources
+        # (e.g. S3 Object Lambda); only direct API calls are available
+        # for such resources
+ if self.VALIDATE_SUPPORTED_BUCKET_VALUES:
+ for resource, pattern in self._UNSUPPORTED_BUCKET_PATTERNS.items():
+ match = pattern.match(bucket)
+ if match:
+ raise ValueError(
+ 'TransferManager methods do not support %s '
+ 'resource. Use direct client calls instead.' % resource
+ )
+
def _validate_all_known_args(self, actual, allowed):
for kwarg in actual:
if kwarg not in allowed:
raise ValueError(
"Invalid extra_args key '%s', "
- "must be one of: %s" % (kwarg, ', '.join(allowed))
- )
+ "must be one of: %s" % (kwarg, ', '.join(allowed))
+ )
- def _submit_transfer(
- self, call_args, submission_task_cls, extra_main_kwargs=None
- ):
+ def _submit_transfer(
+ self, call_args, submission_task_cls, extra_main_kwargs=None
+ ):
if not extra_main_kwargs:
extra_main_kwargs = {}
# Create a TransferFuture to return back to the user
transfer_future, components = self._get_future_with_components(
- call_args
- )
+ call_args
+ )
# Add any provided done callbacks to the created transfer future
# to be invoked on the transfer future being complete.
@@ -516,15 +516,15 @@ class TransferManager:
# Get the main kwargs needed to instantiate the submission task
main_kwargs = self._get_submission_task_main_kwargs(
- transfer_future, extra_main_kwargs
- )
+ transfer_future, extra_main_kwargs
+ )
# Submit a SubmissionTask that will submit all of the necessary
# tasks needed to complete the S3 transfer.
self._submission_executor.submit(
submission_task_cls(
transfer_coordinator=components['coordinator'],
- main_kwargs=main_kwargs,
+ main_kwargs=main_kwargs,
)
)
@@ -539,30 +539,30 @@ class TransferManager:
transfer_coordinator = TransferCoordinator(transfer_id=transfer_id)
# Track the transfer coordinator for transfers to manage.
self._coordinator_controller.add_transfer_coordinator(
- transfer_coordinator
- )
+ transfer_coordinator
+ )
# Also make sure that the transfer coordinator is removed once
# the transfer completes so it does not stick around in memory.
transfer_coordinator.add_done_callback(
self._coordinator_controller.remove_transfer_coordinator,
- transfer_coordinator,
- )
+ transfer_coordinator,
+ )
components = {
'meta': TransferMeta(call_args, transfer_id=transfer_id),
- 'coordinator': transfer_coordinator,
+ 'coordinator': transfer_coordinator,
}
transfer_future = TransferFuture(**components)
return transfer_future, components
def _get_submission_task_main_kwargs(
- self, transfer_future, extra_main_kwargs
- ):
+ self, transfer_future, extra_main_kwargs
+ ):
main_kwargs = {
'client': self._client,
'config': self._config,
'osutil': self._osutil,
'request_executor': self._request_executor,
- 'transfer_future': transfer_future,
+ 'transfer_future': transfer_future,
}
main_kwargs.update(extra_main_kwargs)
return main_kwargs
@@ -571,13 +571,13 @@ class TransferManager:
# Register handlers to enable/disable callbacks on uploads.
event_name = 'request-created.s3'
self._client.meta.events.register_first(
- event_name,
- signal_not_transferring,
- unique_id='s3upload-not-transferring',
- )
+ event_name,
+ signal_not_transferring,
+ unique_id='s3upload-not-transferring',
+ )
self._client.meta.events.register_last(
- event_name, signal_transferring, unique_id='s3upload-transferring'
- )
+ event_name, signal_transferring, unique_id='s3upload-transferring'
+ )
def __enter__(self):
return self
@@ -590,7 +590,7 @@ class TransferManager:
# all of the inprogress futures in the shutdown.
if exc_type:
cancel = True
- cancel_msg = str(exc_value)
+ cancel_msg = str(exc_value)
if not cancel_msg:
cancel_msg = repr(exc_value)
# If it was a KeyboardInterrupt, the cancellation was initiated
@@ -641,7 +641,7 @@ class TransferManager:
self._io_executor.shutdown()
-class TransferCoordinatorController:
+class TransferCoordinatorController:
def __init__(self):
"""Abstraction to control all transfer coordinators
@@ -671,7 +671,7 @@ class TransferCoordinatorController:
self._tracked_transfer_coordinators.add(transfer_coordinator)
def remove_transfer_coordinator(self, transfer_coordinator):
- """Remove a transfer coordinator from cancellation consideration
+ """Remove a transfer coordinator from cancellation consideration
Typically, this method is invoked by the transfer coordinator itself
        to remove itself when it completes its transfer.
@@ -700,7 +700,7 @@ class TransferCoordinatorController:
def wait(self):
"""Wait until there are no more inprogress transfers
-        This will not stop when failures are encountered, nor will it propagate any
+        This will not stop when failures are encountered, nor will it propagate any
of these errors from failed transfers, but it can be interrupted with
a KeyboardInterrupt.
"""
@@ -716,8 +716,8 @@ class TransferCoordinatorController:
if transfer_coordinator:
logger.debug(
'On KeyboardInterrupt was waiting for %s',
- transfer_coordinator,
- )
+ transfer_coordinator,
+ )
raise
except Exception:
# A general exception could have been thrown because
diff --git a/contrib/python/s3transfer/py3/s3transfer/processpool.py b/contrib/python/s3transfer/py3/s3transfer/processpool.py
index 783518fd19..017eeb4499 100644
--- a/contrib/python/s3transfer/py3/s3transfer/processpool.py
+++ b/contrib/python/s3transfer/py3/s3transfer/processpool.py
@@ -1,1008 +1,1008 @@
-# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""Speeds up S3 throughput by using processes
-
-Getting Started
-===============
-
-The :class:`ProcessPoolDownloader` can be used to download a single file by
-calling :meth:`ProcessPoolDownloader.download_file`:
-
-.. code:: python
-
- from s3transfer.processpool import ProcessPoolDownloader
-
- with ProcessPoolDownloader() as downloader:
- downloader.download_file('mybucket', 'mykey', 'myfile')
-
-
-This snippet downloads the S3 object located in the bucket ``mybucket`` at the
-key ``mykey`` to the local file ``myfile``. Any errors encountered during the
-transfer are not propagated. To determine if a transfer succeeded or
-failed, use the `Futures`_ interface.
-
-
-The :class:`ProcessPoolDownloader` can be used to download multiple files as
-well:
-
-.. code:: python
-
- from s3transfer.processpool import ProcessPoolDownloader
-
- with ProcessPoolDownloader() as downloader:
- downloader.download_file('mybucket', 'mykey', 'myfile')
- downloader.download_file('mybucket', 'myotherkey', 'myotherfile')
-
-
-When running this snippet, the downloads of ``mykey`` and ``myotherkey``
-happen in parallel. The first ``download_file`` call does not block the
-second ``download_file`` call. The snippet only blocks when exiting
-the context manager, waiting until both downloads are complete.
-
-Alternatively, the ``ProcessPoolDownloader`` can be instantiated
-and explicitly be shutdown using :meth:`ProcessPoolDownloader.shutdown`:
-
-.. code:: python
-
- from s3transfer.processpool import ProcessPoolDownloader
-
- downloader = ProcessPoolDownloader()
- downloader.download_file('mybucket', 'mykey', 'myfile')
- downloader.download_file('mybucket', 'myotherkey', 'myotherfile')
- downloader.shutdown()
-
-
-For this code snippet, the call to ``shutdown`` blocks until both
-downloads are complete.
-
-
-Additional Parameters
-=====================
-
-Additional parameters can be provided to the ``download_file`` method:
-
-* ``extra_args``: A dictionary containing any additional client arguments
- to include in the
- `GetObject <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object>`_
- API request. For example:
-
- .. code:: python
-
- from s3transfer.processpool import ProcessPoolDownloader
-
- with ProcessPoolDownloader() as downloader:
- downloader.download_file(
- 'mybucket', 'mykey', 'myfile',
- extra_args={'VersionId': 'myversion'})
-
-
-* ``expected_size``: By default, the downloader will make a HeadObject
- call to determine the size of the object. To opt-out of this additional
- API call, you can provide the size of the object in bytes:
-
- .. code:: python
-
- from s3transfer.processpool import ProcessPoolDownloader
-
- MB = 1024 * 1024
- with ProcessPoolDownloader() as downloader:
- downloader.download_file(
- 'mybucket', 'mykey', 'myfile', expected_size=2 * MB)
-
-
-Futures
-=======
-
-When ``download_file`` is called, it immediately returns a
-:class:`ProcessPoolTransferFuture`. The future can be used to poll the state
-of a particular transfer. To get the result of the download,
-call :meth:`ProcessPoolTransferFuture.result`. The method blocks
-until the transfer completes, whether it succeeds or fails. For example:
-
-.. code:: python
-
- from s3transfer.processpool import ProcessPoolDownloader
-
- with ProcessPoolDownloader() as downloader:
- future = downloader.download_file('mybucket', 'mykey', 'myfile')
- print(future.result())
-
-
-If the download succeeds, the future returns ``None``:
-
-.. code:: python
-
- None
-
-
-If the download fails, the exception causing the failure is raised. For
-example, if ``mykey`` did not exist, the following error would be raised
-
-
-.. code:: python
-
- botocore.exceptions.ClientError: An error occurred (404) when calling the HeadObject operation: Not Found
-
-
-.. note::
-
- :meth:`ProcessPoolTransferFuture.result` can only be called while the
- ``ProcessPoolDownloader`` is running (e.g. before calling ``shutdown`` or
- inside the context manager).
-
-
-Process Pool Configuration
-==========================
-
-By default, the downloader has the following configuration options:
-
-* ``multipart_threshold``: The threshold size for performing ranged downloads
- in bytes. By default, ranged downloads happen for S3 objects that are
- greater than or equal to 8 MB in size.
-
-* ``multipart_chunksize``: The size of each ranged download in bytes. By
- default, the size of each ranged download is 8 MB.
-
-* ``max_request_processes``: The maximum number of processes used to download
- S3 objects. By default, the maximum is 10 processes.
-
-
-To change the default configuration, use the :class:`ProcessTransferConfig`:
-
-.. code:: python
-
- from s3transfer.processpool import ProcessPoolDownloader
- from s3transfer.processpool import ProcessTransferConfig
-
- config = ProcessTransferConfig(
- multipart_threshold=64 * 1024 * 1024, # 64 MB
- max_request_processes=50
- )
- downloader = ProcessPoolDownloader(config=config)
-
-
-Client Configuration
-====================
-
-The process pool downloader creates ``botocore`` clients on your behalf. In
-order to affect how the client is created, pass the keyword arguments
-that would have been used in the :meth:`botocore.Session.create_client` call:
-
-.. code:: python
-
-
- from s3transfer.processpool import ProcessPoolDownloader
- from s3transfer.processpool import ProcessTransferConfig
-
- downloader = ProcessPoolDownloader(
- client_kwargs={'region_name': 'us-west-2'})
-
-
-This snippet ensures that all clients created by the ``ProcessPoolDownloader``
-are using ``us-west-2`` as their region.
-
-"""
-import collections
-import contextlib
-import logging
-import multiprocessing
-import signal
-import threading
-from copy import deepcopy
-
-import botocore.session
-from botocore.config import Config
-
-from s3transfer.compat import MAXINT, BaseManager
-from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, MB, PROCESS_USER_AGENT
-from s3transfer.exceptions import CancelledError, RetriesExceededError
-from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
-from s3transfer.utils import (
- S3_RETRYABLE_DOWNLOAD_ERRORS,
- CallArgs,
- OSUtils,
- calculate_num_parts,
- calculate_range_parameter,
-)
-
-logger = logging.getLogger(__name__)
-
-SHUTDOWN_SIGNAL = 'SHUTDOWN'
-
-# The DownloadFileRequest tuple is submitted from the ProcessPoolDownloader
-# to the GetObjectSubmitter in order for the submitter to begin submitting
-# GetObjectJobs to the GetObjectWorkers.
-DownloadFileRequest = collections.namedtuple(
- 'DownloadFileRequest',
- [
- 'transfer_id', # The unique id for the transfer
- 'bucket', # The bucket to download the object from
- 'key', # The key to download the object from
- 'filename', # The user-requested download location
- 'extra_args', # Extra arguments to provide to client calls
- 'expected_size', # The user-provided expected size of the download
- ],
-)
-
-# The GetObjectJob tuple is submitted from the GetObjectSubmitter
-# to the GetObjectWorkers to download the file or parts of the file.
-GetObjectJob = collections.namedtuple(
- 'GetObjectJob',
- [
- 'transfer_id', # The unique id for the transfer
- 'bucket', # The bucket to download the object from
- 'key', # The key to download the object from
- 'temp_filename', # The temporary file to write the content to via
- # completed GetObject calls.
- 'extra_args', # Extra arguments to provide to the GetObject call
- 'offset', # The offset to write the content for the temp file.
- 'filename', # The user-requested download location. The worker
- # of final GetObjectJob will move the file located at
- # temp_filename to the location of filename.
- ],
-)
-
-
-@contextlib.contextmanager
-def ignore_ctrl_c():
- original_handler = _add_ignore_handler_for_interrupts()
- yield
- signal.signal(signal.SIGINT, original_handler)
-
-
-def _add_ignore_handler_for_interrupts():
- # Windows is unable to pickle signal.signal directly so it needs to
- # be wrapped in a function defined at the module level
- return signal.signal(signal.SIGINT, signal.SIG_IGN)
-
-
-class ProcessTransferConfig:
- def __init__(
- self,
- multipart_threshold=8 * MB,
- multipart_chunksize=8 * MB,
- max_request_processes=10,
- ):
- """Configuration for the ProcessPoolDownloader
-
- :param multipart_threshold: The threshold for which ranged downloads
- occur.
-
- :param multipart_chunksize: The chunk size of each ranged download.
-
- :param max_request_processes: The maximum number of processes that
- will be making S3 API transfer-related requests at a time.
- """
- self.multipart_threshold = multipart_threshold
- self.multipart_chunksize = multipart_chunksize
- self.max_request_processes = max_request_processes
-
-
-class ProcessPoolDownloader:
- def __init__(self, client_kwargs=None, config=None):
- """Downloads S3 objects using process pools
-
- :type client_kwargs: dict
- :param client_kwargs: The keyword arguments to provide when
- instantiating S3 clients. The arguments must match the keyword
- arguments provided to the
- `botocore.session.Session.create_client()` method.
-
- :type config: ProcessTransferConfig
- :param config: Configuration for the downloader
- """
- if client_kwargs is None:
- client_kwargs = {}
- self._client_factory = ClientFactory(client_kwargs)
-
- self._transfer_config = config
- if config is None:
- self._transfer_config = ProcessTransferConfig()
-
- self._download_request_queue = multiprocessing.Queue(1000)
- self._worker_queue = multiprocessing.Queue(1000)
- self._osutil = OSUtils()
-
- self._started = False
- self._start_lock = threading.Lock()
-
- # These below are initialized in the start() method
- self._manager = None
- self._transfer_monitor = None
- self._submitter = None
- self._workers = []
-
- def download_file(
- self, bucket, key, filename, extra_args=None, expected_size=None
- ):
- """Downloads the object's contents to a file
-
- :type bucket: str
- :param bucket: The name of the bucket to download from
-
- :type key: str
- :param key: The name of the key to download from
-
- :type filename: str
- :param filename: The name of a file to download to.
-
- :type extra_args: dict
- :param extra_args: Extra arguments that may be passed to the
- client operation
-
- :type expected_size: int
- :param expected_size: The expected size in bytes of the download. If
- provided, the downloader will not call HeadObject to determine the
-            object's size and will use the provided value instead. The size is
- needed to determine whether to do a multipart download.
-
- :rtype: s3transfer.futures.TransferFuture
- :returns: Transfer future representing the download
- """
- self._start_if_needed()
- if extra_args is None:
- extra_args = {}
- self._validate_all_known_args(extra_args)
- transfer_id = self._transfer_monitor.notify_new_transfer()
- download_file_request = DownloadFileRequest(
- transfer_id=transfer_id,
- bucket=bucket,
- key=key,
- filename=filename,
- extra_args=extra_args,
- expected_size=expected_size,
- )
- logger.debug(
- 'Submitting download file request: %s.', download_file_request
- )
- self._download_request_queue.put(download_file_request)
- call_args = CallArgs(
- bucket=bucket,
- key=key,
- filename=filename,
- extra_args=extra_args,
- expected_size=expected_size,
- )
- future = self._get_transfer_future(transfer_id, call_args)
- return future
-
- def shutdown(self):
- """Shutdown the downloader
-
- It will wait till all downloads are complete before returning.
- """
- self._shutdown_if_needed()
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_value, *args):
- if isinstance(exc_value, KeyboardInterrupt):
- if self._transfer_monitor is not None:
- self._transfer_monitor.notify_cancel_all_in_progress()
- self.shutdown()
-
- def _start_if_needed(self):
- with self._start_lock:
- if not self._started:
- self._start()
-
- def _start(self):
- self._start_transfer_monitor_manager()
- self._start_submitter()
- self._start_get_object_workers()
- self._started = True
-
- def _validate_all_known_args(self, provided):
- for kwarg in provided:
- if kwarg not in ALLOWED_DOWNLOAD_ARGS:
- download_args = ', '.join(ALLOWED_DOWNLOAD_ARGS)
- raise ValueError(
- f"Invalid extra_args key '{kwarg}', "
- f"must be one of: {download_args}"
- )
-
- def _get_transfer_future(self, transfer_id, call_args):
- meta = ProcessPoolTransferMeta(
- call_args=call_args, transfer_id=transfer_id
- )
- future = ProcessPoolTransferFuture(
- monitor=self._transfer_monitor, meta=meta
- )
- return future
-
- def _start_transfer_monitor_manager(self):
- logger.debug('Starting the TransferMonitorManager.')
- self._manager = TransferMonitorManager()
- # We do not want Ctrl-C's to cause the manager to shutdown immediately
- # as worker processes will still need to communicate with it when they
- # are shutting down. So instead we ignore Ctrl-C and let the manager
- # be explicitly shutdown when shutting down the downloader.
- self._manager.start(_add_ignore_handler_for_interrupts)
- self._transfer_monitor = self._manager.TransferMonitor()
-
- def _start_submitter(self):
- logger.debug('Starting the GetObjectSubmitter.')
- self._submitter = GetObjectSubmitter(
- transfer_config=self._transfer_config,
- client_factory=self._client_factory,
- transfer_monitor=self._transfer_monitor,
- osutil=self._osutil,
- download_request_queue=self._download_request_queue,
- worker_queue=self._worker_queue,
- )
- self._submitter.start()
-
- def _start_get_object_workers(self):
- logger.debug(
- 'Starting %s GetObjectWorkers.',
- self._transfer_config.max_request_processes,
- )
- for _ in range(self._transfer_config.max_request_processes):
- worker = GetObjectWorker(
- queue=self._worker_queue,
- client_factory=self._client_factory,
- transfer_monitor=self._transfer_monitor,
- osutil=self._osutil,
- )
- worker.start()
- self._workers.append(worker)
-
- def _shutdown_if_needed(self):
- with self._start_lock:
- if self._started:
- self._shutdown()
-
- def _shutdown(self):
- self._shutdown_submitter()
- self._shutdown_get_object_workers()
- self._shutdown_transfer_monitor_manager()
- self._started = False
-
- def _shutdown_transfer_monitor_manager(self):
- logger.debug('Shutting down the TransferMonitorManager.')
- self._manager.shutdown()
-
- def _shutdown_submitter(self):
- logger.debug('Shutting down the GetObjectSubmitter.')
- self._download_request_queue.put(SHUTDOWN_SIGNAL)
- self._submitter.join()
-
- def _shutdown_get_object_workers(self):
- logger.debug('Shutting down the GetObjectWorkers.')
- for _ in self._workers:
- self._worker_queue.put(SHUTDOWN_SIGNAL)
- for worker in self._workers:
- worker.join()
-
-
-class ProcessPoolTransferFuture(BaseTransferFuture):
- def __init__(self, monitor, meta):
- """The future associated to a submitted process pool transfer request
-
- :type monitor: TransferMonitor
- :param monitor: The monitor associated to the process pool downloader
-
- :type meta: ProcessPoolTransferMeta
- :param meta: The metadata associated to the request. This object
- is visible to the requester.
- """
- self._monitor = monitor
- self._meta = meta
-
- @property
- def meta(self):
- return self._meta
-
- def done(self):
- return self._monitor.is_done(self._meta.transfer_id)
-
- def result(self):
- try:
- return self._monitor.poll_for_result(self._meta.transfer_id)
- except KeyboardInterrupt:
- # For the multiprocessing Manager, a thread is given a single
- # connection to reuse in communicating between the thread in the
- # main process and the Manager's process. If a Ctrl-C happens when
- # polling for the result, it will make the main thread stop trying
- # to receive from the connection, but the Manager process will not
- # know that the main process has stopped trying to receive and
- # will not close the connection. As a result if another message is
- # sent to the Manager process, the listener in the Manager
-            # process will not process the new message as it is still trying
-            # to process the previous message (that was Ctrl-C'd), thus
-            # causing the thread in the main process to hang on its send.
- # The only way around this is to create a new connection and send
- # messages from that new connection instead.
- self._monitor._connect()
- self.cancel()
- raise
-
- def cancel(self):
- self._monitor.notify_exception(
- self._meta.transfer_id, CancelledError()
- )
-
-
-class ProcessPoolTransferMeta(BaseTransferMeta):
- """Holds metadata about the ProcessPoolTransferFuture"""
-
- def __init__(self, transfer_id, call_args):
- self._transfer_id = transfer_id
- self._call_args = call_args
- self._user_context = {}
-
- @property
- def call_args(self):
- return self._call_args
-
- @property
- def transfer_id(self):
- return self._transfer_id
-
- @property
- def user_context(self):
- return self._user_context
-
-
-class ClientFactory:
- def __init__(self, client_kwargs=None):
- """Creates S3 clients for processes
-
- Botocore sessions and clients are not pickleable so they cannot be
- inherited across Process boundaries. Instead, they must be instantiated
- once a process is running.
- """
- self._client_kwargs = client_kwargs
- if self._client_kwargs is None:
- self._client_kwargs = {}
-
- client_config = deepcopy(self._client_kwargs.get('config', Config()))
- if not client_config.user_agent_extra:
- client_config.user_agent_extra = PROCESS_USER_AGENT
- else:
- client_config.user_agent_extra += " " + PROCESS_USER_AGENT
- self._client_kwargs['config'] = client_config
-
- def create_client(self):
- """Create a botocore S3 client"""
- return botocore.session.Session().create_client(
- 's3', **self._client_kwargs
- )
-
-
-class TransferMonitor:
- def __init__(self):
- """Monitors transfers for cross-process communication
-
- Notifications can be sent to the monitor and information can be
- retrieved from the monitor for a particular transfer. This abstraction
-        is run in a ``multiprocessing.managers.BaseManager`` in order to be
- shared across processes.
- """
- # TODO: Add logic that removes the TransferState if the transfer is
- # marked as done and the reference to the future is no longer being
- # held onto. Without this logic, this dictionary will continue to
- # grow in size with no limit.
- self._transfer_states = {}
- self._id_count = 0
- self._init_lock = threading.Lock()
-
- def notify_new_transfer(self):
- with self._init_lock:
- transfer_id = self._id_count
- self._transfer_states[transfer_id] = TransferState()
- self._id_count += 1
- return transfer_id
-
- def is_done(self, transfer_id):
-        """Determine whether a particular transfer is complete
-
- :param transfer_id: Unique identifier for the transfer
- :return: True, if done. False, otherwise.
- """
- return self._transfer_states[transfer_id].done
-
- def notify_done(self, transfer_id):
- """Notify a particular transfer is complete
-
- :param transfer_id: Unique identifier for the transfer
- """
- self._transfer_states[transfer_id].set_done()
-
- def poll_for_result(self, transfer_id):
- """Poll for the result of a transfer
-
- :param transfer_id: Unique identifier for the transfer
- :return: If the transfer succeeded, it will return the result. If the
- transfer failed, it will raise the exception associated to the
- failure.
- """
- self._transfer_states[transfer_id].wait_till_done()
- exception = self._transfer_states[transfer_id].exception
- if exception:
- raise exception
- return None
-
- def notify_exception(self, transfer_id, exception):
- """Notify an exception was encountered for a transfer
-
- :param transfer_id: Unique identifier for the transfer
- :param exception: The exception encountered for that transfer
- """
- # TODO: Not all exceptions are pickleable so if we are running
- # this in a multiprocessing.BaseManager we will want to
- # make sure to update this signature to ensure pickleability of the
- # arguments or have the ProxyObject do the serialization.
- self._transfer_states[transfer_id].exception = exception
-
- def notify_cancel_all_in_progress(self):
- for transfer_state in self._transfer_states.values():
- if not transfer_state.done:
- transfer_state.exception = CancelledError()
-
- def get_exception(self, transfer_id):
- """Retrieve the exception encountered for the transfer
-
- :param transfer_id: Unique identifier for the transfer
- :return: The exception encountered for that transfer. Otherwise
- if there were no exceptions, returns None.
- """
- return self._transfer_states[transfer_id].exception
-
- def notify_expected_jobs_to_complete(self, transfer_id, num_jobs):
- """Notify the amount of jobs expected for a transfer
-
- :param transfer_id: Unique identifier for the transfer
- :param num_jobs: The number of jobs to complete the transfer
- """
- self._transfer_states[transfer_id].jobs_to_complete = num_jobs
-
- def notify_job_complete(self, transfer_id):
- """Notify that a single job is completed for a transfer
-
- :param transfer_id: Unique identifier for the transfer
- :return: The number of jobs remaining to complete the transfer
- """
- return self._transfer_states[transfer_id].decrement_jobs_to_complete()
-
-
-class TransferState:
- """Represents the current state of an individual transfer"""
-
- # NOTE: Ideally the TransferState object would be used directly by the
- # various different abstractions in the ProcessPoolDownloader and remove
- # the need for the TransferMonitor. However, it would then impose the
- # constraint that two hops are required to make or get any changes in the
- # state of a transfer across processes: one hop to get a proxy object for
- # the TransferState and then a second hop to communicate calling the
- # specific TransferState method.
- def __init__(self):
- self._exception = None
- self._done_event = threading.Event()
- self._job_lock = threading.Lock()
- self._jobs_to_complete = 0
-
- @property
- def done(self):
- return self._done_event.is_set()
-
- def set_done(self):
- self._done_event.set()
-
- def wait_till_done(self):
- self._done_event.wait(MAXINT)
-
- @property
- def exception(self):
- return self._exception
-
- @exception.setter
- def exception(self, val):
- self._exception = val
-
- @property
- def jobs_to_complete(self):
- return self._jobs_to_complete
-
- @jobs_to_complete.setter
- def jobs_to_complete(self, val):
- self._jobs_to_complete = val
-
- def decrement_jobs_to_complete(self):
- with self._job_lock:
- self._jobs_to_complete -= 1
- return self._jobs_to_complete
-
-
-class TransferMonitorManager(BaseManager):
- pass
-
-
-TransferMonitorManager.register('TransferMonitor', TransferMonitor)
-
-
-class BaseS3TransferProcess(multiprocessing.Process):
- def __init__(self, client_factory):
- super().__init__()
- self._client_factory = client_factory
- self._client = None
-
- def run(self):
- # Clients are not pickleable so their instantiation cannot happen
- # in the __init__ for processes that are created under the
- # spawn method.
- self._client = self._client_factory.create_client()
- with ignore_ctrl_c():
-            # By default these processes are run as child processes to the
-            # main process. Any Ctrl-C encountered in the main process is
-            # propagated to the child process and can interrupt it at any time.
- # To avoid any potentially bad states caused from an interrupt
- # (i.e. a transfer failing to notify its done or making the
- # communication protocol become out of sync with the
- # TransferMonitor), we ignore all Ctrl-C's and allow the main
- # process to notify these child processes when to stop processing
- # jobs.
- self._do_run()
-
- def _do_run(self):
- raise NotImplementedError('_do_run()')
-
-
-class GetObjectSubmitter(BaseS3TransferProcess):
- def __init__(
- self,
- transfer_config,
- client_factory,
- transfer_monitor,
- osutil,
- download_request_queue,
- worker_queue,
- ):
- """Submit GetObjectJobs to fulfill a download file request
-
- :param transfer_config: Configuration for transfers.
- :param client_factory: ClientFactory for creating S3 clients.
- :param transfer_monitor: Monitor for notifying and retrieving state
- of transfer.
- :param osutil: OSUtils object to use for os-related behavior when
- performing the transfer.
- :param download_request_queue: Queue to retrieve download file
- requests.
- :param worker_queue: Queue to submit GetObjectJobs for workers
- to perform.
- """
- super().__init__(client_factory)
- self._transfer_config = transfer_config
- self._transfer_monitor = transfer_monitor
- self._osutil = osutil
- self._download_request_queue = download_request_queue
- self._worker_queue = worker_queue
-
- def _do_run(self):
- while True:
- download_file_request = self._download_request_queue.get()
- if download_file_request == SHUTDOWN_SIGNAL:
- logger.debug('Submitter shutdown signal received.')
- return
- try:
- self._submit_get_object_jobs(download_file_request)
- except Exception as e:
- logger.debug(
- 'Exception caught when submitting jobs for '
- 'download file request %s: %s',
- download_file_request,
- e,
- exc_info=True,
- )
- self._transfer_monitor.notify_exception(
- download_file_request.transfer_id, e
- )
- self._transfer_monitor.notify_done(
- download_file_request.transfer_id
- )
-
- def _submit_get_object_jobs(self, download_file_request):
- size = self._get_size(download_file_request)
- temp_filename = self._allocate_temp_file(download_file_request, size)
- if size < self._transfer_config.multipart_threshold:
- self._submit_single_get_object_job(
- download_file_request, temp_filename
- )
- else:
- self._submit_ranged_get_object_jobs(
- download_file_request, temp_filename, size
- )
-
- def _get_size(self, download_file_request):
- expected_size = download_file_request.expected_size
- if expected_size is None:
- expected_size = self._client.head_object(
- Bucket=download_file_request.bucket,
- Key=download_file_request.key,
- **download_file_request.extra_args,
- )['ContentLength']
- return expected_size
-
- def _allocate_temp_file(self, download_file_request, size):
- temp_filename = self._osutil.get_temp_filename(
- download_file_request.filename
- )
- self._osutil.allocate(temp_filename, size)
- return temp_filename
-
- def _submit_single_get_object_job(
- self, download_file_request, temp_filename
- ):
- self._notify_jobs_to_complete(download_file_request.transfer_id, 1)
- self._submit_get_object_job(
- transfer_id=download_file_request.transfer_id,
- bucket=download_file_request.bucket,
- key=download_file_request.key,
- temp_filename=temp_filename,
- offset=0,
- extra_args=download_file_request.extra_args,
- filename=download_file_request.filename,
- )
-
- def _submit_ranged_get_object_jobs(
- self, download_file_request, temp_filename, size
- ):
- part_size = self._transfer_config.multipart_chunksize
- num_parts = calculate_num_parts(size, part_size)
- self._notify_jobs_to_complete(
- download_file_request.transfer_id, num_parts
- )
- for i in range(num_parts):
- offset = i * part_size
- range_parameter = calculate_range_parameter(
- part_size, i, num_parts
- )
- get_object_kwargs = {'Range': range_parameter}
- get_object_kwargs.update(download_file_request.extra_args)
- self._submit_get_object_job(
- transfer_id=download_file_request.transfer_id,
- bucket=download_file_request.bucket,
- key=download_file_request.key,
- temp_filename=temp_filename,
- offset=offset,
- extra_args=get_object_kwargs,
- filename=download_file_request.filename,
- )
-
- def _submit_get_object_job(self, **get_object_job_kwargs):
- self._worker_queue.put(GetObjectJob(**get_object_job_kwargs))
-
- def _notify_jobs_to_complete(self, transfer_id, jobs_to_complete):
- logger.debug(
- 'Notifying %s job(s) to complete for transfer_id %s.',
- jobs_to_complete,
- transfer_id,
- )
- self._transfer_monitor.notify_expected_jobs_to_complete(
- transfer_id, jobs_to_complete
- )
-
-
-class GetObjectWorker(BaseS3TransferProcess):
- # TODO: It may make sense to expose these class variables as configuration
- # options if users want to tweak them.
- _MAX_ATTEMPTS = 5
- _IO_CHUNKSIZE = 2 * MB
-
- def __init__(self, queue, client_factory, transfer_monitor, osutil):
- """Fulfills GetObjectJobs
-
- Downloads the S3 object, writes it to the specified file, and
- renames the file to its final location if it completes the final
- job for a particular transfer.
-
- :param queue: Queue for retrieving GetObjectJob's
- :param client_factory: ClientFactory for creating S3 clients
- :param transfer_monitor: Monitor for notifying
- :param osutil: OSUtils object to use for os-related behavior when
- performing the transfer.
- """
- super().__init__(client_factory)
- self._queue = queue
- self._client_factory = client_factory
- self._transfer_monitor = transfer_monitor
- self._osutil = osutil
-
- def _do_run(self):
- while True:
- job = self._queue.get()
- if job == SHUTDOWN_SIGNAL:
- logger.debug('Worker shutdown signal received.')
- return
- if not self._transfer_monitor.get_exception(job.transfer_id):
- self._run_get_object_job(job)
- else:
- logger.debug(
- 'Skipping get object job %s because there was a previous '
- 'exception.',
- job,
- )
- remaining = self._transfer_monitor.notify_job_complete(
- job.transfer_id
- )
- logger.debug(
- '%s jobs remaining for transfer_id %s.',
- remaining,
- job.transfer_id,
- )
- if not remaining:
- self._finalize_download(
- job.transfer_id, job.temp_filename, job.filename
- )
-
- def _run_get_object_job(self, job):
- try:
- self._do_get_object(
- bucket=job.bucket,
- key=job.key,
- temp_filename=job.temp_filename,
- extra_args=job.extra_args,
- offset=job.offset,
- )
- except Exception as e:
- logger.debug(
- 'Exception caught when downloading object for '
- 'get object job %s: %s',
- job,
- e,
- exc_info=True,
- )
- self._transfer_monitor.notify_exception(job.transfer_id, e)
-
- def _do_get_object(self, bucket, key, extra_args, temp_filename, offset):
- last_exception = None
- for i in range(self._MAX_ATTEMPTS):
- try:
- response = self._client.get_object(
- Bucket=bucket, Key=key, **extra_args
- )
- self._write_to_file(temp_filename, offset, response['Body'])
- return
- except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
- logger.debug(
- 'Retrying exception caught (%s), '
- 'retrying request, (attempt %s / %s)',
- e,
- i + 1,
- self._MAX_ATTEMPTS,
- exc_info=True,
- )
- last_exception = e
- raise RetriesExceededError(last_exception)
-
- def _write_to_file(self, filename, offset, body):
- with open(filename, 'rb+') as f:
- f.seek(offset)
- chunks = iter(lambda: body.read(self._IO_CHUNKSIZE), b'')
- for chunk in chunks:
- f.write(chunk)
-
- def _finalize_download(self, transfer_id, temp_filename, filename):
- if self._transfer_monitor.get_exception(transfer_id):
- self._osutil.remove_file(temp_filename)
- else:
- self._do_file_rename(transfer_id, temp_filename, filename)
- self._transfer_monitor.notify_done(transfer_id)
-
- def _do_file_rename(self, transfer_id, temp_filename, filename):
- try:
- self._osutil.rename_file(temp_filename, filename)
- except Exception as e:
- self._transfer_monitor.notify_exception(transfer_id, e)
- self._osutil.remove_file(temp_filename)
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Speeds up S3 throughput by using processes
+
+Getting Started
+===============
+
+The :class:`ProcessPoolDownloader` can be used to download a single file by
+calling :meth:`ProcessPoolDownloader.download_file`:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ with ProcessPoolDownloader() as downloader:
+ downloader.download_file('mybucket', 'mykey', 'myfile')
+
+
+This snippet downloads the S3 object located in the bucket ``mybucket`` at the
+key ``mykey`` to the local file ``myfile``. Any errors encountered during the
+transfer are not propagated. To determine if a transfer succeeded or
+failed, use the `Futures`_ interface.
+
+
+The :class:`ProcessPoolDownloader` can be used to download multiple files as
+well:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ with ProcessPoolDownloader() as downloader:
+ downloader.download_file('mybucket', 'mykey', 'myfile')
+ downloader.download_file('mybucket', 'myotherkey', 'myotherfile')
+
+
+When running this snippet, the downloads of ``mykey`` and ``myotherkey``
+happen in parallel. The first ``download_file`` call does not block the
+second ``download_file`` call. The snippet only blocks when exiting
+the context manager, waiting until both downloads are complete.
+
+Alternatively, the ``ProcessPoolDownloader`` can be instantiated
+and explicitly be shutdown using :meth:`ProcessPoolDownloader.shutdown`:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ downloader = ProcessPoolDownloader()
+ downloader.download_file('mybucket', 'mykey', 'myfile')
+ downloader.download_file('mybucket', 'myotherkey', 'myotherfile')
+ downloader.shutdown()
+
+
+For this code snippet, the call to ``shutdown`` blocks until both
+downloads are complete.
+
+
+Additional Parameters
+=====================
+
+Additional parameters can be provided to the ``download_file`` method:
+
+* ``extra_args``: A dictionary containing any additional client arguments
+ to include in the
+ `GetObject <https://botocore.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.get_object>`_
+ API request. For example:
+
+ .. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ with ProcessPoolDownloader() as downloader:
+ downloader.download_file(
+ 'mybucket', 'mykey', 'myfile',
+ extra_args={'VersionId': 'myversion'})
+
+
+* ``expected_size``: By default, the downloader will make a HeadObject
+ call to determine the size of the object. To opt-out of this additional
+ API call, you can provide the size of the object in bytes:
+
+ .. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ MB = 1024 * 1024
+ with ProcessPoolDownloader() as downloader:
+ downloader.download_file(
+ 'mybucket', 'mykey', 'myfile', expected_size=2 * MB)
+
+
+Futures
+=======
+
+When ``download_file`` is called, it immediately returns a
+:class:`ProcessPoolTransferFuture`. The future can be used to poll the state
+of a particular transfer. To get the result of the download,
+call :meth:`ProcessPoolTransferFuture.result`. The method blocks
+until the transfer completes, whether it succeeds or fails. For example:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+
+ with ProcessPoolDownloader() as downloader:
+ future = downloader.download_file('mybucket', 'mykey', 'myfile')
+ print(future.result())
+
+
+If the download succeeds, the future returns ``None``:
+
+.. code:: python
+
+ None
+
+
+If the download fails, the exception causing the failure is raised. For
+example, if ``mykey`` did not exist, the following error would be raised
+
+
+.. code:: python
+
+ botocore.exceptions.ClientError: An error occurred (404) when calling the HeadObject operation: Not Found
+
+
+.. note::
+
+ :meth:`ProcessPoolTransferFuture.result` can only be called while the
+ ``ProcessPoolDownloader`` is running (e.g. before calling ``shutdown`` or
+ inside the context manager).
+
+
+Process Pool Configuration
+==========================
+
+By default, the downloader has the following configuration options:
+
+* ``multipart_threshold``: The threshold size for performing ranged downloads
+ in bytes. By default, ranged downloads happen for S3 objects that are
+ greater than or equal to 8 MB in size.
+
+* ``multipart_chunksize``: The size of each ranged download in bytes. By
+ default, the size of each ranged download is 8 MB.
+
+* ``max_request_processes``: The maximum number of processes used to download
+ S3 objects. By default, the maximum is 10 processes.
+
+
+To change the default configuration, use the :class:`ProcessTransferConfig`:
+
+.. code:: python
+
+ from s3transfer.processpool import ProcessPoolDownloader
+ from s3transfer.processpool import ProcessTransferConfig
+
+ config = ProcessTransferConfig(
+ multipart_threshold=64 * 1024 * 1024, # 64 MB
+ max_request_processes=50
+ )
+ downloader = ProcessPoolDownloader(config=config)
+
+
+Client Configuration
+====================
+
+The process pool downloader creates ``botocore`` clients on your behalf. In
+order to affect how the client is created, pass the keyword arguments
+that would have been used in the :meth:`botocore.Session.create_client` call:
+
+.. code:: python
+
+
+ from s3transfer.processpool import ProcessPoolDownloader
+ from s3transfer.processpool import ProcessTransferConfig
+
+ downloader = ProcessPoolDownloader(
+ client_kwargs={'region_name': 'us-west-2'})
+
+
+This snippet ensures that all clients created by the ``ProcessPoolDownloader``
+are using ``us-west-2`` as their region.
+
+"""
+import collections
+import contextlib
+import logging
+import multiprocessing
+import signal
+import threading
+from copy import deepcopy
+
+import botocore.session
+from botocore.config import Config
+
+from s3transfer.compat import MAXINT, BaseManager
+from s3transfer.constants import ALLOWED_DOWNLOAD_ARGS, MB, PROCESS_USER_AGENT
+from s3transfer.exceptions import CancelledError, RetriesExceededError
+from s3transfer.futures import BaseTransferFuture, BaseTransferMeta
+from s3transfer.utils import (
+ S3_RETRYABLE_DOWNLOAD_ERRORS,
+ CallArgs,
+ OSUtils,
+ calculate_num_parts,
+ calculate_range_parameter,
+)
+
+logger = logging.getLogger(__name__)
+
+SHUTDOWN_SIGNAL = 'SHUTDOWN'
+
+# The DownloadFileRequest tuple is submitted from the ProcessPoolDownloader
+# to the GetObjectSubmitter in order for the submitter to begin submitting
+# GetObjectJobs to the GetObjectWorkers.
+DownloadFileRequest = collections.namedtuple(
+ 'DownloadFileRequest',
+ [
+ 'transfer_id', # The unique id for the transfer
+ 'bucket', # The bucket to download the object from
+ 'key', # The key to download the object from
+ 'filename', # The user-requested download location
+ 'extra_args', # Extra arguments to provide to client calls
+ 'expected_size', # The user-provided expected size of the download
+ ],
+)
+
+# The GetObjectJob tuple is submitted from the GetObjectSubmitter
+# to the GetObjectWorkers to download the file or parts of the file.
+GetObjectJob = collections.namedtuple(
+ 'GetObjectJob',
+ [
+ 'transfer_id', # The unique id for the transfer
+ 'bucket', # The bucket to download the object from
+ 'key', # The key to download the object from
+ 'temp_filename', # The temporary file to write the content to via
+ # completed GetObject calls.
+ 'extra_args', # Extra arguments to provide to the GetObject call
+ 'offset', # The offset to write the content for the temp file.
+ 'filename', # The user-requested download location. The worker
+ # of final GetObjectJob will move the file located at
+ # temp_filename to the location of filename.
+ ],
+)
+
+
+@contextlib.contextmanager
+def ignore_ctrl_c():
+ original_handler = _add_ignore_handler_for_interrupts()
+ yield
+ signal.signal(signal.SIGINT, original_handler)
+
+
+def _add_ignore_handler_for_interrupts():
+ # Windows is unable to pickle signal.signal directly so it needs to
+ # be wrapped in a function defined at the module level
+ return signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+
+class ProcessTransferConfig:
+ def __init__(
+ self,
+ multipart_threshold=8 * MB,
+ multipart_chunksize=8 * MB,
+ max_request_processes=10,
+ ):
+ """Configuration for the ProcessPoolDownloader
+
+ :param multipart_threshold: The threshold for which ranged downloads
+ occur.
+
+ :param multipart_chunksize: The chunk size of each ranged download.
+
+ :param max_request_processes: The maximum number of processes that
+ will be making S3 API transfer-related requests at a time.
+ """
+ self.multipart_threshold = multipart_threshold
+ self.multipart_chunksize = multipart_chunksize
+ self.max_request_processes = max_request_processes
+
+
+class ProcessPoolDownloader:
+ def __init__(self, client_kwargs=None, config=None):
+ """Downloads S3 objects using process pools
+
+ :type client_kwargs: dict
+ :param client_kwargs: The keyword arguments to provide when
+ instantiating S3 clients. The arguments must match the keyword
+ arguments provided to the
+ `botocore.session.Session.create_client()` method.
+
+ :type config: ProcessTransferConfig
+ :param config: Configuration for the downloader
+ """
+ if client_kwargs is None:
+ client_kwargs = {}
+ self._client_factory = ClientFactory(client_kwargs)
+
+ self._transfer_config = config
+ if config is None:
+ self._transfer_config = ProcessTransferConfig()
+
+ self._download_request_queue = multiprocessing.Queue(1000)
+ self._worker_queue = multiprocessing.Queue(1000)
+ self._osutil = OSUtils()
+
+ self._started = False
+ self._start_lock = threading.Lock()
+
+ # These below are initialized in the start() method
+ self._manager = None
+ self._transfer_monitor = None
+ self._submitter = None
+ self._workers = []
+
+ def download_file(
+ self, bucket, key, filename, extra_args=None, expected_size=None
+ ):
+ """Downloads the object's contents to a file
+
+ :type bucket: str
+ :param bucket: The name of the bucket to download from
+
+ :type key: str
+ :param key: The name of the key to download from
+
+ :type filename: str
+ :param filename: The name of a file to download to.
+
+ :type extra_args: dict
+ :param extra_args: Extra arguments that may be passed to the
+ client operation
+
+ :type expected_size: int
+ :param expected_size: The expected size in bytes of the download. If
+ provided, the downloader will not call HeadObject to determine the
+            object's size and will use the provided value instead. The size is
+ needed to determine whether to do a multipart download.
+
+ :rtype: s3transfer.futures.TransferFuture
+ :returns: Transfer future representing the download
+ """
+ self._start_if_needed()
+ if extra_args is None:
+ extra_args = {}
+ self._validate_all_known_args(extra_args)
+ transfer_id = self._transfer_monitor.notify_new_transfer()
+ download_file_request = DownloadFileRequest(
+ transfer_id=transfer_id,
+ bucket=bucket,
+ key=key,
+ filename=filename,
+ extra_args=extra_args,
+ expected_size=expected_size,
+ )
+ logger.debug(
+ 'Submitting download file request: %s.', download_file_request
+ )
+ self._download_request_queue.put(download_file_request)
+ call_args = CallArgs(
+ bucket=bucket,
+ key=key,
+ filename=filename,
+ extra_args=extra_args,
+ expected_size=expected_size,
+ )
+ future = self._get_transfer_future(transfer_id, call_args)
+ return future
+
+ def shutdown(self):
+ """Shutdown the downloader
+
+ It will wait till all downloads are complete before returning.
+ """
+ self._shutdown_if_needed()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, *args):
+ if isinstance(exc_value, KeyboardInterrupt):
+ if self._transfer_monitor is not None:
+ self._transfer_monitor.notify_cancel_all_in_progress()
+ self.shutdown()
+
+ def _start_if_needed(self):
+ with self._start_lock:
+ if not self._started:
+ self._start()
+
+ def _start(self):
+ self._start_transfer_monitor_manager()
+ self._start_submitter()
+ self._start_get_object_workers()
+ self._started = True
+
+ def _validate_all_known_args(self, provided):
+ for kwarg in provided:
+ if kwarg not in ALLOWED_DOWNLOAD_ARGS:
+ download_args = ', '.join(ALLOWED_DOWNLOAD_ARGS)
+ raise ValueError(
+ f"Invalid extra_args key '{kwarg}', "
+ f"must be one of: {download_args}"
+ )
+
+ def _get_transfer_future(self, transfer_id, call_args):
+ meta = ProcessPoolTransferMeta(
+ call_args=call_args, transfer_id=transfer_id
+ )
+ future = ProcessPoolTransferFuture(
+ monitor=self._transfer_monitor, meta=meta
+ )
+ return future
+
+ def _start_transfer_monitor_manager(self):
+ logger.debug('Starting the TransferMonitorManager.')
+ self._manager = TransferMonitorManager()
+ # We do not want Ctrl-C's to cause the manager to shutdown immediately
+ # as worker processes will still need to communicate with it when they
+ # are shutting down. So instead we ignore Ctrl-C and let the manager
+ # be explicitly shutdown when shutting down the downloader.
+ self._manager.start(_add_ignore_handler_for_interrupts)
+ self._transfer_monitor = self._manager.TransferMonitor()
+
+ def _start_submitter(self):
+ logger.debug('Starting the GetObjectSubmitter.')
+ self._submitter = GetObjectSubmitter(
+ transfer_config=self._transfer_config,
+ client_factory=self._client_factory,
+ transfer_monitor=self._transfer_monitor,
+ osutil=self._osutil,
+ download_request_queue=self._download_request_queue,
+ worker_queue=self._worker_queue,
+ )
+ self._submitter.start()
+
+ def _start_get_object_workers(self):
+ logger.debug(
+ 'Starting %s GetObjectWorkers.',
+ self._transfer_config.max_request_processes,
+ )
+ for _ in range(self._transfer_config.max_request_processes):
+ worker = GetObjectWorker(
+ queue=self._worker_queue,
+ client_factory=self._client_factory,
+ transfer_monitor=self._transfer_monitor,
+ osutil=self._osutil,
+ )
+ worker.start()
+ self._workers.append(worker)
+
+ def _shutdown_if_needed(self):
+ with self._start_lock:
+ if self._started:
+ self._shutdown()
+
+ def _shutdown(self):
+ self._shutdown_submitter()
+ self._shutdown_get_object_workers()
+ self._shutdown_transfer_monitor_manager()
+ self._started = False
+
+ def _shutdown_transfer_monitor_manager(self):
+ logger.debug('Shutting down the TransferMonitorManager.')
+ self._manager.shutdown()
+
+ def _shutdown_submitter(self):
+ logger.debug('Shutting down the GetObjectSubmitter.')
+ self._download_request_queue.put(SHUTDOWN_SIGNAL)
+ self._submitter.join()
+
+ def _shutdown_get_object_workers(self):
+ logger.debug('Shutting down the GetObjectWorkers.')
+ for _ in self._workers:
+ self._worker_queue.put(SHUTDOWN_SIGNAL)
+ for worker in self._workers:
+ worker.join()
+
+
+class ProcessPoolTransferFuture(BaseTransferFuture):
+ def __init__(self, monitor, meta):
+ """The future associated to a submitted process pool transfer request
+
+ :type monitor: TransferMonitor
+ :param monitor: The monitor associated to the process pool downloader
+
+ :type meta: ProcessPoolTransferMeta
+ :param meta: The metadata associated to the request. This object
+ is visible to the requester.
+ """
+ self._monitor = monitor
+ self._meta = meta
+
+ @property
+ def meta(self):
+ return self._meta
+
+ def done(self):
+ return self._monitor.is_done(self._meta.transfer_id)
+
+ def result(self):
+ try:
+ return self._monitor.poll_for_result(self._meta.transfer_id)
+ except KeyboardInterrupt:
+ # For the multiprocessing Manager, a thread is given a single
+ # connection to reuse in communicating between the thread in the
+ # main process and the Manager's process. If a Ctrl-C happens when
+ # polling for the result, it will make the main thread stop trying
+ # to receive from the connection, but the Manager process will not
+ # know that the main process has stopped trying to receive and
+ # will not close the connection. As a result, if another message is
+ # sent to the Manager process, the listener in the Manager
+ # process will not process the new message as it is still trying
+ # to process the previous message (that was Ctrl-C'd), and this
+ # causes the thread in the main process to hang on its send.
+ # The only way around this is to create a new connection and send
+ # messages from that new connection instead.
+ self._monitor._connect()
+ self.cancel()
+ raise
+
+ def cancel(self):
+ self._monitor.notify_exception(
+ self._meta.transfer_id, CancelledError()
+ )
+
+
+class ProcessPoolTransferMeta(BaseTransferMeta):
+ """Holds metadata about the ProcessPoolTransferFuture"""
+
+ def __init__(self, transfer_id, call_args):
+ self._transfer_id = transfer_id
+ self._call_args = call_args
+ self._user_context = {}
+
+ @property
+ def call_args(self):
+ return self._call_args
+
+ @property
+ def transfer_id(self):
+ return self._transfer_id
+
+ @property
+ def user_context(self):
+ return self._user_context
+
+
+class ClientFactory:
+ def __init__(self, client_kwargs=None):
+ """Creates S3 clients for processes
+
+ Botocore sessions and clients are not pickleable so they cannot be
+ inherited across Process boundaries. Instead, they must be instantiated
+ once a process is running.
+ """
+ self._client_kwargs = client_kwargs
+ if self._client_kwargs is None:
+ self._client_kwargs = {}
+
+ client_config = deepcopy(self._client_kwargs.get('config', Config()))
+ if not client_config.user_agent_extra:
+ client_config.user_agent_extra = PROCESS_USER_AGENT
+ else:
+ client_config.user_agent_extra += " " + PROCESS_USER_AGENT
+ self._client_kwargs['config'] = client_config
+
+ def create_client(self):
+ """Create a botocore S3 client"""
+ return botocore.session.Session().create_client(
+ 's3', **self._client_kwargs
+ )
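Because the factory only stores plain keyword arguments, it can be pickled into a child process and asked for a client after that process has started. A small sketch of the intended use; the region, bucket, and key names are made up for illustration:

    factory = ClientFactory(client_kwargs={'region_name': 'us-east-1'})
    # ... pass `factory` to a multiprocessing.Process ...
    # Inside the child process, after it has started:
    client = factory.create_client()
    # client.head_object(Bucket='example-bucket', Key='example-key')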
+
+
+class TransferMonitor:
+ def __init__(self):
+ """Monitors transfers for cross-process communication
+
+ Notifications can be sent to the monitor and information can be
+ retrieved from the monitor for a particular transfer. This abstraction
+ is run in a ``multiprocessing.managers.BaseManager`` in order to be
+ shared across processes.
+ """
+ # TODO: Add logic that removes the TransferState if the transfer is
+ # marked as done and the reference to the future is no longer being
+ # held onto. Without this logic, this dictionary will continue to
+ # grow in size with no limit.
+ self._transfer_states = {}
+ self._id_count = 0
+ self._init_lock = threading.Lock()
+
+ def notify_new_transfer(self):
+ with self._init_lock:
+ transfer_id = self._id_count
+ self._transfer_states[transfer_id] = TransferState()
+ self._id_count += 1
+ return transfer_id
+
+ def is_done(self, transfer_id):
+ """Determine a particular transfer is complete
+
+ :param transfer_id: Unique identifier for the transfer
+ :return: True, if done. False, otherwise.
+ """
+ return self._transfer_states[transfer_id].done
+
+ def notify_done(self, transfer_id):
+ """Notify a particular transfer is complete
+
+ :param transfer_id: Unique identifier for the transfer
+ """
+ self._transfer_states[transfer_id].set_done()
+
+ def poll_for_result(self, transfer_id):
+ """Poll for the result of a transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :return: If the transfer succeeded, it will return the result. If the
+ transfer failed, it will raise the exception associated to the
+ failure.
+ """
+ self._transfer_states[transfer_id].wait_till_done()
+ exception = self._transfer_states[transfer_id].exception
+ if exception:
+ raise exception
+ return None
+
+ def notify_exception(self, transfer_id, exception):
+ """Notify an exception was encountered for a transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :param exception: The exception encountered for that transfer
+ """
+ # TODO: Not all exceptions are pickleable so if we are running
+ # this in a multiprocessing.BaseManager we will want to
+ # make sure to update this signature to ensure pickleability of the
+ # arguments or have the ProxyObject do the serialization.
+ self._transfer_states[transfer_id].exception = exception
+
+ def notify_cancel_all_in_progress(self):
+ for transfer_state in self._transfer_states.values():
+ if not transfer_state.done:
+ transfer_state.exception = CancelledError()
+
+ def get_exception(self, transfer_id):
+ """Retrieve the exception encountered for the transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :return: The exception encountered for that transfer, or None if
+ no exception was encountered.
+ """
+ return self._transfer_states[transfer_id].exception
+
+ def notify_expected_jobs_to_complete(self, transfer_id, num_jobs):
+ """Notify the amount of jobs expected for a transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :param num_jobs: The number of jobs to complete the transfer
+ """
+ self._transfer_states[transfer_id].jobs_to_complete = num_jobs
+
+ def notify_job_complete(self, transfer_id):
+ """Notify that a single job is completed for a transfer
+
+ :param transfer_id: Unique identifier for the transfer
+ :return: The number of jobs remaining to complete the transfer
+ """
+ return self._transfer_states[transfer_id].decrement_jobs_to_complete()
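Taken together, these methods define the bookkeeping protocol between the submitter, the workers, and the future. A sketch of the call sequence for a single transfer split into two ranged jobs, using direct calls on a local TransferMonitor with no processes involved:

    monitor = TransferMonitor()
    tid = monitor.notify_new_transfer()
    monitor.notify_expected_jobs_to_complete(tid, 2)
    assert monitor.notify_job_complete(tid) == 1   # first GetObjectJob done
    assert monitor.notify_job_complete(tid) == 0   # last GetObjectJob done
    monitor.notify_done(tid)
    assert monitor.poll_for_result(tid) is None    # no exception: success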
+
+
+class TransferState:
+ """Represents the current state of an individual transfer"""
+
+ # NOTE: Ideally the TransferState object would be used directly by the
+ # various different abstractions in the ProcessPoolDownloader and remove
+ # the need for the TransferMonitor. However, it would then impose the
+ # constraint that two hops are required to make or get any changes in the
+ # state of a transfer across processes: one hop to get a proxy object for
+ # the TransferState and then a second hop to call the specific
+ # TransferState method through that proxy.
+ def __init__(self):
+ self._exception = None
+ self._done_event = threading.Event()
+ self._job_lock = threading.Lock()
+ self._jobs_to_complete = 0
+
+ @property
+ def done(self):
+ return self._done_event.is_set()
+
+ def set_done(self):
+ self._done_event.set()
+
+ def wait_till_done(self):
+ self._done_event.wait(MAXINT)
+
+ @property
+ def exception(self):
+ return self._exception
+
+ @exception.setter
+ def exception(self, val):
+ self._exception = val
+
+ @property
+ def jobs_to_complete(self):
+ return self._jobs_to_complete
+
+ @jobs_to_complete.setter
+ def jobs_to_complete(self, val):
+ self._jobs_to_complete = val
+
+ def decrement_jobs_to_complete(self):
+ with self._job_lock:
+ self._jobs_to_complete -= 1
+ return self._jobs_to_complete
+
+
+class TransferMonitorManager(BaseManager):
+ pass
+
+
+TransferMonitorManager.register('TransferMonitor', TransferMonitor)
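Registering the class on the manager means that every TransferMonitorManager.TransferMonitor() call returns a proxy whose method calls are forwarded to a single TransferMonitor living in the manager's process, which is what lets the submitter, workers, and futures share one monitor. A short sketch of obtaining and using such a proxy, mirroring _start_transfer_monitor_manager above:

    if __name__ == '__main__':
        manager = TransferMonitorManager()
        manager.start()
        monitor = manager.TransferMonitor()   # proxy; safe to hand to workers
        tid = monitor.notify_new_transfer()   # executed in the manager process
        manager.shutdown()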
+
+
+class BaseS3TransferProcess(multiprocessing.Process):
+ def __init__(self, client_factory):
+ super().__init__()
+ self._client_factory = client_factory
+ self._client = None
+
+ def run(self):
+ # Clients are not pickleable so their instantiation cannot happen
+ # in the __init__ for processes that are created under the
+ # spawn method.
+ self._client = self._client_factory.create_client()
+ with ignore_ctrl_c():
+ # By default these processes are run as child processes of the
+ # main process. Any Ctrl-C encountered in the main process is
+ # propagated to the child process and can interrupt it at any time.
+ # To avoid any potentially bad states caused by an interrupt
+ # (i.e. a transfer failing to notify that it is done or the
+ # communication protocol becoming out of sync with the
+ # TransferMonitor), we ignore all Ctrl-C's and allow the main
+ # process to notify these child processes when to stop processing
+ # jobs.
+ self._do_run()
+
+ def _do_run(self):
+ raise NotImplementedError('_do_run()')
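ignore_ctrl_c() is imported from elsewhere in the package; the sketch below shows the standard SIGINT-ignoring idiom that such a context manager is presumed to wrap, and is not the package's actual implementation:

    import signal
    from contextlib import contextmanager

    @contextmanager
    def ignore_sigint():
        # Swap in the "ignore" handler and restore the original on exit.
        original = signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            yield
        finally:
            signal.signal(signal.SIGINT, original)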
+
+
+class GetObjectSubmitter(BaseS3TransferProcess):
+ def __init__(
+ self,
+ transfer_config,
+ client_factory,
+ transfer_monitor,
+ osutil,
+ download_request_queue,
+ worker_queue,
+ ):
+ """Submit GetObjectJobs to fulfill a download file request
+
+ :param transfer_config: Configuration for transfers.
+ :param client_factory: ClientFactory for creating S3 clients.
+ :param transfer_monitor: Monitor for notifying and retrieving state
+ of transfer.
+ :param osutil: OSUtils object to use for os-related behavior when
+ performing the transfer.
+ :param download_request_queue: Queue to retrieve download file
+ requests.
+ :param worker_queue: Queue to submit GetObjectJobs for workers
+ to perform.
+ """
+ super().__init__(client_factory)
+ self._transfer_config = transfer_config
+ self._transfer_monitor = transfer_monitor
+ self._osutil = osutil
+ self._download_request_queue = download_request_queue
+ self._worker_queue = worker_queue
+
+ def _do_run(self):
+ while True:
+ download_file_request = self._download_request_queue.get()
+ if download_file_request == SHUTDOWN_SIGNAL:
+ logger.debug('Submitter shutdown signal received.')
+ return
+ try:
+ self._submit_get_object_jobs(download_file_request)
+ except Exception as e:
+ logger.debug(
+ 'Exception caught when submitting jobs for '
+ 'download file request %s: %s',
+ download_file_request,
+ e,
+ exc_info=True,
+ )
+ self._transfer_monitor.notify_exception(
+ download_file_request.transfer_id, e
+ )
+ self._transfer_monitor.notify_done(
+ download_file_request.transfer_id
+ )
+
+ def _submit_get_object_jobs(self, download_file_request):
+ size = self._get_size(download_file_request)
+ temp_filename = self._allocate_temp_file(download_file_request, size)
+ if size < self._transfer_config.multipart_threshold:
+ self._submit_single_get_object_job(
+ download_file_request, temp_filename
+ )
+ else:
+ self._submit_ranged_get_object_jobs(
+ download_file_request, temp_filename, size
+ )
+
+ def _get_size(self, download_file_request):
+ expected_size = download_file_request.expected_size
+ if expected_size is None:
+ expected_size = self._client.head_object(
+ Bucket=download_file_request.bucket,
+ Key=download_file_request.key,
+ **download_file_request.extra_args,
+ )['ContentLength']
+ return expected_size
+
+ def _allocate_temp_file(self, download_file_request, size):
+ temp_filename = self._osutil.get_temp_filename(
+ download_file_request.filename
+ )
+ self._osutil.allocate(temp_filename, size)
+ return temp_filename
+
+ def _submit_single_get_object_job(
+ self, download_file_request, temp_filename
+ ):
+ self._notify_jobs_to_complete(download_file_request.transfer_id, 1)
+ self._submit_get_object_job(
+ transfer_id=download_file_request.transfer_id,
+ bucket=download_file_request.bucket,
+ key=download_file_request.key,
+ temp_filename=temp_filename,
+ offset=0,
+ extra_args=download_file_request.extra_args,
+ filename=download_file_request.filename,
+ )
+
+ def _submit_ranged_get_object_jobs(
+ self, download_file_request, temp_filename, size
+ ):
+ part_size = self._transfer_config.multipart_chunksize
+ num_parts = calculate_num_parts(size, part_size)
+ self._notify_jobs_to_complete(
+ download_file_request.transfer_id, num_parts
+ )
+ for i in range(num_parts):
+ offset = i * part_size
+ range_parameter = calculate_range_parameter(
+ part_size, i, num_parts
+ )
+ get_object_kwargs = {'Range': range_parameter}
+ get_object_kwargs.update(download_file_request.extra_args)
+ self._submit_get_object_job(
+ transfer_id=download_file_request.transfer_id,
+ bucket=download_file_request.bucket,
+ key=download_file_request.key,
+ temp_filename=temp_filename,
+ offset=offset,
+ extra_args=get_object_kwargs,
+ filename=download_file_request.filename,
+ )
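A worked example of the job math above, assuming a hypothetical 18 MiB object and using the helpers from s3transfer.utils shown later in this diff; it also assumes the upstream behavior that omitting total_size leaves the last part's range open-ended:

    from s3transfer.utils import calculate_num_parts, calculate_range_parameter

    size, part_size = 18 * 1024 * 1024, 8 * 1024 * 1024    # 8 MiB chunksize
    num_parts = calculate_num_parts(size, part_size)        # -> 3
    ranges = [calculate_range_parameter(part_size, i, num_parts)
              for i in range(num_parts)]
    # -> ['bytes=0-8388607', 'bytes=8388608-16777215', 'bytes=16777216-']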
+
+ def _submit_get_object_job(self, **get_object_job_kwargs):
+ self._worker_queue.put(GetObjectJob(**get_object_job_kwargs))
+
+ def _notify_jobs_to_complete(self, transfer_id, jobs_to_complete):
+ logger.debug(
+ 'Notifying %s job(s) to complete for transfer_id %s.',
+ jobs_to_complete,
+ transfer_id,
+ )
+ self._transfer_monitor.notify_expected_jobs_to_complete(
+ transfer_id, jobs_to_complete
+ )
+
+
+class GetObjectWorker(BaseS3TransferProcess):
+ # TODO: It may make sense to expose these class variables as configuration
+ # options if users want to tweak them.
+ _MAX_ATTEMPTS = 5
+ _IO_CHUNKSIZE = 2 * MB
+
+ def __init__(self, queue, client_factory, transfer_monitor, osutil):
+ """Fulfills GetObjectJobs
+
+ Downloads the S3 object, writes it to the specified file, and
+ renames the file to its final location if it completes the final
+ job for a particular transfer.
+
+ :param queue: Queue for retrieving GetObjectJobs
+ :param client_factory: ClientFactory for creating S3 clients
+ :param transfer_monitor: Monitor for notifying and retrieving the
+ state of the transfer
+ :param osutil: OSUtils object to use for os-related behavior when
+ performing the transfer.
+ """
+ super().__init__(client_factory)
+ self._queue = queue
+ self._client_factory = client_factory
+ self._transfer_monitor = transfer_monitor
+ self._osutil = osutil
+
+ def _do_run(self):
+ while True:
+ job = self._queue.get()
+ if job == SHUTDOWN_SIGNAL:
+ logger.debug('Worker shutdown signal received.')
+ return
+ if not self._transfer_monitor.get_exception(job.transfer_id):
+ self._run_get_object_job(job)
+ else:
+ logger.debug(
+ 'Skipping get object job %s because there was a previous '
+ 'exception.',
+ job,
+ )
+ remaining = self._transfer_monitor.notify_job_complete(
+ job.transfer_id
+ )
+ logger.debug(
+ '%s jobs remaining for transfer_id %s.',
+ remaining,
+ job.transfer_id,
+ )
+ if not remaining:
+ self._finalize_download(
+ job.transfer_id, job.temp_filename, job.filename
+ )
+
+ def _run_get_object_job(self, job):
+ try:
+ self._do_get_object(
+ bucket=job.bucket,
+ key=job.key,
+ temp_filename=job.temp_filename,
+ extra_args=job.extra_args,
+ offset=job.offset,
+ )
+ except Exception as e:
+ logger.debug(
+ 'Exception caught when downloading object for '
+ 'get object job %s: %s',
+ job,
+ e,
+ exc_info=True,
+ )
+ self._transfer_monitor.notify_exception(job.transfer_id, e)
+
+ def _do_get_object(self, bucket, key, extra_args, temp_filename, offset):
+ last_exception = None
+ for i in range(self._MAX_ATTEMPTS):
+ try:
+ response = self._client.get_object(
+ Bucket=bucket, Key=key, **extra_args
+ )
+ self._write_to_file(temp_filename, offset, response['Body'])
+ return
+ except S3_RETRYABLE_DOWNLOAD_ERRORS as e:
+ logger.debug(
+ 'Retrying exception caught (%s), '
+ 'retrying request, (attempt %s / %s)',
+ e,
+ i + 1,
+ self._MAX_ATTEMPTS,
+ exc_info=True,
+ )
+ last_exception = e
+ raise RetriesExceededError(last_exception)
+
+ def _write_to_file(self, filename, offset, body):
+ with open(filename, 'rb+') as f:
+ f.seek(offset)
+ chunks = iter(lambda: body.read(self._IO_CHUNKSIZE), b'')
+ for chunk in chunks:
+ f.write(chunk)
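Each ranged response is written at its part's byte offset into the preallocated temporary file, so workers can write parts in any order. A tiny self-contained sketch of the same seek-then-write pattern using an in-memory body; the file, offset, and contents are illustrative:

    import io
    import tempfile

    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.truncate(11)                  # preallocate, like osutil.allocate()
        path = f.name
    body = io.BytesIO(b'world')         # stand-in for response['Body']
    with open(path, 'rb+') as out:
        out.seek(6)                     # this part's offset
        for chunk in iter(lambda: body.read(2), b''):
            out.write(chunk)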
+
+ def _finalize_download(self, transfer_id, temp_filename, filename):
+ if self._transfer_monitor.get_exception(transfer_id):
+ self._osutil.remove_file(temp_filename)
+ else:
+ self._do_file_rename(transfer_id, temp_filename, filename)
+ self._transfer_monitor.notify_done(transfer_id)
+
+ def _do_file_rename(self, transfer_id, temp_filename, filename):
+ try:
+ self._osutil.rename_file(temp_filename, filename)
+ except Exception as e:
+ self._transfer_monitor.notify_exception(transfer_id, e)
+ self._osutil.remove_file(temp_filename)
diff --git a/contrib/python/s3transfer/py3/s3transfer/subscribers.py b/contrib/python/s3transfer/py3/s3transfer/subscribers.py
index 27e8fe0c4b..cf0dbaa0d7 100644
--- a/contrib/python/s3transfer/py3/s3transfer/subscribers.py
+++ b/contrib/python/s3transfer/py3/s3transfer/subscribers.py
@@ -14,34 +14,34 @@ from s3transfer.compat import accepts_kwargs
from s3transfer.exceptions import InvalidSubscriberMethodError
-class BaseSubscriber:
+class BaseSubscriber:
"""The base subscriber class
It is recommended that all subscriber implementations subclass and then
override the subscription methods (i.e. on_{subscribe_type}() methods).
"""
- VALID_SUBSCRIBER_TYPES = ['queued', 'progress', 'done']
-
+ VALID_SUBSCRIBER_TYPES = ['queued', 'progress', 'done']
+
def __new__(cls, *args, **kwargs):
cls._validate_subscriber_methods()
- return super().__new__(cls)
+ return super().__new__(cls)
@classmethod
def _validate_subscriber_methods(cls):
for subscriber_type in cls.VALID_SUBSCRIBER_TYPES:
subscriber_method = getattr(cls, 'on_' + subscriber_type)
- if not callable(subscriber_method):
+ if not callable(subscriber_method):
raise InvalidSubscriberMethodError(
- 'Subscriber method %s must be callable.'
- % subscriber_method
- )
+ 'Subscriber method %s must be callable.'
+ % subscriber_method
+ )
if not accepts_kwargs(subscriber_method):
raise InvalidSubscriberMethodError(
'Subscriber method %s must accept keyword '
- 'arguments (**kwargs)' % subscriber_method
- )
+ 'arguments (**kwargs)' % subscriber_method
+ )
def on_queued(self, future, **kwargs):
"""Callback to be invoked when transfer request gets queued
diff --git a/contrib/python/s3transfer/py3/s3transfer/tasks.py b/contrib/python/s3transfer/py3/s3transfer/tasks.py
index 7153ac547c..1bad981264 100644
--- a/contrib/python/s3transfer/py3/s3transfer/tasks.py
+++ b/contrib/python/s3transfer/py3/s3transfer/tasks.py
@@ -18,21 +18,21 @@ from s3transfer.utils import get_callbacks
logger = logging.getLogger(__name__)
-class Task:
+class Task:
"""A task associated to a TransferFuture request
This is a base class for other classes to subclass from. All subclassed
classes must implement the main() method.
"""
-
- def __init__(
- self,
- transfer_coordinator,
- main_kwargs=None,
- pending_main_kwargs=None,
- done_callbacks=None,
- is_final=False,
- ):
+
+ def __init__(
+ self,
+ transfer_coordinator,
+ main_kwargs=None,
+ pending_main_kwargs=None,
+ done_callbacks=None,
+ is_final=False,
+ ):
"""
:type transfer_coordinator: s3transfer.futures.TransferCoordinator
:param transfer_coordinator: The context associated to the
@@ -85,22 +85,22 @@ class Task:
# These are the general main_kwarg parameters that we want to
# display in the repr.
params_to_display = [
- 'bucket',
- 'key',
- 'part_number',
- 'final_filename',
- 'transfer_future',
- 'offset',
- 'extra_args',
+ 'bucket',
+ 'key',
+ 'part_number',
+ 'final_filename',
+ 'transfer_future',
+ 'offset',
+ 'extra_args',
]
main_kwargs_to_display = self._get_kwargs_with_params_to_include(
- self._main_kwargs, params_to_display
- )
- return '{}(transfer_id={}, {})'.format(
- self.__class__.__name__,
- self._transfer_coordinator.transfer_id,
- main_kwargs_to_display,
- )
+ self._main_kwargs, params_to_display
+ )
+ return '{}(transfer_id={}, {})'.format(
+ self.__class__.__name__,
+ self._transfer_coordinator.transfer_id,
+ main_kwargs_to_display,
+ )
@property
def transfer_id(self):
@@ -130,7 +130,7 @@ class Task:
# Gather up all of the main keyword arguments for main().
# This includes the immediately provided main_kwargs and
# the values for pending_main_kwargs that source from the return
- # values from the task's dependent futures.
+ # values from the task's dependent futures.
kwargs = self._get_all_main_kwargs()
# If the task is not done (really only if some other related
# task to the TransferFuture had failed) then execute the task's
@@ -154,10 +154,10 @@ class Task:
# if they are going to make the logs hard to follow.
params_to_exclude = ['data']
kwargs_to_display = self._get_kwargs_with_params_to_exclude(
- kwargs, params_to_exclude
- )
+ kwargs, params_to_exclude
+ )
# Log what is about to be executed.
- logger.debug(f"Executing task {self} with kwargs {kwargs_to_display}")
+ logger.debug(f"Executing task {self} with kwargs {kwargs_to_display}")
return_value = self._main(**kwargs)
# If the task is the final task, then set the TransferFuture's
@@ -187,7 +187,7 @@ class Task:
# If the pending main keyword arg is a list then extend the list.
if isinstance(future, list):
futures_to_wait_on.extend(future)
- # If the pending main keyword arg is a future append it to the list.
+ # If the pending main keyword arg is a future append it to the list.
else:
futures_to_wait_on.append(future)
# Now wait for all of the futures to complete.
@@ -198,20 +198,20 @@ class Task:
#
# concurrent.futures.wait() is not used instead because of this
# reported issue: https://bugs.python.org/issue20319.
- # The issue would occasionally cause multipart uploads to hang
+ # The issue would occasionally cause multipart uploads to hang
# when wait() was called. With this approach, it avoids the
# concurrency bug by removing any association with concurrent.futures
# implementation of waiters.
logger.debug(
- '%s about to wait for the following futures %s', self, futures
- )
+ '%s about to wait for the following futures %s', self, futures
+ )
for future in futures:
try:
logger.debug('%s about to wait for %s', self, future)
future.result()
except Exception:
# result() can also produce exceptions. We want to ignore
- # these to be deferred to error handling down the road.
+ # these to be deferred to error handling down the road.
pass
logger.debug('%s done waiting for dependent futures', self)
@@ -243,7 +243,7 @@ class SubmissionTask(Task):
Submission tasks are the top-level task used to submit a series of tasks
to execute a particular transfer.
"""
-
+
def _main(self, transfer_future, **kwargs):
"""
:type transfer_future: s3transfer.futures.TransferFuture
@@ -274,7 +274,7 @@ class SubmissionTask(Task):
# the first place so we need to account accordingly.
#
# Note that BaseException is caught, instead of Exception, because
- # for some implementations of executors, specifically the serial
+ # for some implementations of executors, specifically the serial
# implementation, the SubmissionTask is directly exposed to
# KeyboardInterupts and so needs to cleanup and signal done
# for those as well.
@@ -283,7 +283,7 @@ class SubmissionTask(Task):
self._log_and_set_exception(e)
# Wait for all possibly associated futures that may have spawned
- # from this submission task have finished before we announce the
+ # from this submission task have finished before we announce the
# transfer done.
self._wait_for_all_submitted_futures_to_complete()
@@ -292,7 +292,7 @@ class SubmissionTask(Task):
self._transfer_coordinator.announce_done()
def _submit(self, transfer_future, **kwargs):
- """The submission method to be implemented
+ """The submission method to be implemented
:type transfer_future: s3transfer.futures.TransferFuture
:param transfer_future: The transfer future associated with the
@@ -317,9 +317,9 @@ class SubmissionTask(Task):
self._wait_until_all_complete(submitted_futures)
# However, more futures may have been submitted as we waited so
# we need to check again for any more associated futures.
- possibly_more_submitted_futures = (
+ possibly_more_submitted_futures = (
self._transfer_coordinator.associated_futures
- )
+ )
# If the current list of submitted futures is equal to the
# the list of associated futures for when after the wait completes,
# we can ensure no more futures were submitted in waiting on
@@ -333,36 +333,36 @@ class SubmissionTask(Task):
class CreateMultipartUploadTask(Task):
"""Task to initiate a multipart upload"""
-
+
def _main(self, client, bucket, key, extra_args):
"""
:param client: The client to use when calling CreateMultipartUpload
:param bucket: The name of the bucket to upload to
:param key: The name of the key to upload to
:param extra_args: A dictionary of any extra arguments that may be
- used in the initialization.
+ used in the initialization.
:returns: The upload id of the multipart upload
"""
# Create the multipart upload.
response = client.create_multipart_upload(
- Bucket=bucket, Key=key, **extra_args
- )
+ Bucket=bucket, Key=key, **extra_args
+ )
upload_id = response['UploadId']
# Add a cleanup if the multipart upload fails at any point.
self._transfer_coordinator.add_failure_cleanup(
- client.abort_multipart_upload,
- Bucket=bucket,
- Key=key,
- UploadId=upload_id,
+ client.abort_multipart_upload,
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
)
return upload_id
class CompleteMultipartUploadTask(Task):
"""Task to complete a multipart upload"""
-
+
def _main(self, client, bucket, key, upload_id, parts, extra_args):
"""
:param client: The client to use when calling CompleteMultipartUpload
@@ -379,9 +379,9 @@ class CompleteMultipartUploadTask(Task):
used in completing the multipart transfer.
"""
client.complete_multipart_upload(
- Bucket=bucket,
- Key=key,
- UploadId=upload_id,
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
MultipartUpload={'Parts': parts},
- **extra_args,
- )
+ **extra_args,
+ )
diff --git a/contrib/python/s3transfer/py3/s3transfer/upload.py b/contrib/python/s3transfer/py3/s3transfer/upload.py
index 1407d14361..31ade051d7 100644
--- a/contrib/python/s3transfer/py3/s3transfer/upload.py
+++ b/contrib/python/s3transfer/py3/s3transfer/upload.py
@@ -11,25 +11,25 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import math
-from io import BytesIO
+from io import BytesIO
-from s3transfer.compat import readable, seekable
+from s3transfer.compat import readable, seekable
from s3transfer.futures import IN_MEMORY_UPLOAD_TAG
-from s3transfer.tasks import (
- CompleteMultipartUploadTask,
- CreateMultipartUploadTask,
- SubmissionTask,
- Task,
-)
-from s3transfer.utils import (
- ChunksizeAdjuster,
- DeferredOpenFile,
- get_callbacks,
- get_filtered_dict,
-)
-
-
-class AggregatedProgressCallback:
+from s3transfer.tasks import (
+ CompleteMultipartUploadTask,
+ CreateMultipartUploadTask,
+ SubmissionTask,
+ Task,
+)
+from s3transfer.utils import (
+ ChunksizeAdjuster,
+ DeferredOpenFile,
+ get_callbacks,
+ get_filtered_dict,
+)
+
+
+class AggregatedProgressCallback:
def __init__(self, callbacks, threshold=1024 * 256):
"""Aggregates progress updates for every provided progress callback
@@ -62,10 +62,10 @@ class AggregatedProgressCallback:
self._bytes_seen = 0
-class InterruptReader:
+class InterruptReader:
"""Wrapper that can interrupt reading using an error
- It uses a transfer coordinator to propagate an error if it notices
+ It uses a transfer coordinator to propagate an error if it notices
that a read is being made while the file is being read from.
:type fileobj: file-like obj
@@ -75,7 +75,7 @@ class InterruptReader:
:param transfer_coordinator: The transfer coordinator to use if the
reader needs to be interrupted.
"""
-
+
def __init__(self, fileobj, transfer_coordinator):
self._fileobj = fileobj
self._transfer_coordinator = transfer_coordinator
@@ -90,8 +90,8 @@ class InterruptReader:
raise self._transfer_coordinator.exception
return self._fileobj.read(amount)
- def seek(self, where, whence=0):
- self._fileobj.seek(where, whence)
+ def seek(self, where, whence=0):
+ self._fileobj.seek(where, whence)
def tell(self):
return self._fileobj.tell()
@@ -106,7 +106,7 @@ class InterruptReader:
self.close()
-class UploadInputManager:
+class UploadInputManager:
"""Base manager class for handling various types of files for uploads
This class is typically used for the UploadSubmissionTask class to help
@@ -121,7 +121,7 @@ class UploadInputManager:
that may be accepted. All implementations must subclass and override
public methods from this class.
"""
-
+
def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None):
self._osutil = osutil
self._transfer_coordinator = transfer_coordinator
@@ -152,7 +152,7 @@ class UploadInputManager:
memory. False if the manager will not directly store the body in
memory.
"""
- raise NotImplementedError('must implement store_body_in_memory()')
+ raise NotImplementedError('must implement store_body_in_memory()')
def provide_transfer_size(self, transfer_future):
"""Provides the transfer size of an upload
@@ -173,7 +173,7 @@ class UploadInputManager:
:rtype: boolean
:returns: True, if the upload should be multipart based on
- configuration and size. False, otherwise.
+ configuration and size. False, otherwise.
"""
raise NotImplementedError('must implement requires_multipart_upload()')
@@ -212,8 +212,8 @@ class UploadInputManager:
fileobj = InterruptReader(fileobj, self._transfer_coordinator)
if self._bandwidth_limiter:
fileobj = self._bandwidth_limiter.get_bandwith_limited_stream(
- fileobj, self._transfer_coordinator, enabled=False
- )
+ fileobj, self._transfer_coordinator, enabled=False
+ )
return fileobj
def _get_progress_callbacks(self, transfer_future):
@@ -231,18 +231,18 @@ class UploadInputManager:
class UploadFilenameInputManager(UploadInputManager):
"""Upload utility for filenames"""
-
+
@classmethod
def is_compatible(cls, upload_source):
- return isinstance(upload_source, str)
+ return isinstance(upload_source, str)
def stores_body_in_memory(self, operation_name):
return False
def provide_transfer_size(self, transfer_future):
transfer_future.meta.provide_transfer_size(
- self._osutil.get_file_size(transfer_future.meta.call_args.fileobj)
- )
+ self._osutil.get_file_size(transfer_future.meta.call_args.fileobj)
+ )
def requires_multipart_upload(self, transfer_future, config):
return transfer_future.meta.size >= config.multipart_threshold
@@ -250,8 +250,8 @@ class UploadFilenameInputManager(UploadInputManager):
def get_put_object_body(self, transfer_future):
# Get a file-like object for the given input
fileobj, full_size = self._get_put_object_fileobj_with_full_size(
- transfer_future
- )
+ transfer_future
+ )
# Wrap fileobj with interrupt reader that will quickly cancel
# uploads if needed instead of having to wait for the socket
@@ -264,12 +264,12 @@ class UploadFilenameInputManager(UploadInputManager):
# Return the file-like object wrapped into a ReadFileChunk to get
# progress.
return self._osutil.open_file_chunk_reader_from_fileobj(
- fileobj=fileobj,
- chunk_size=size,
- full_file_size=full_size,
- callbacks=callbacks,
- close_callbacks=close_callbacks,
- )
+ fileobj=fileobj,
+ chunk_size=size,
+ full_file_size=full_size,
+ callbacks=callbacks,
+ close_callbacks=close_callbacks,
+ )
def yield_upload_part_bodies(self, transfer_future, chunksize):
full_file_size = transfer_future.meta.size
@@ -281,11 +281,11 @@ class UploadFilenameInputManager(UploadInputManager):
# Get a file-like object for that part and the size of the full
# file size for the associated file-like object for that part.
fileobj, full_size = self._get_upload_part_fileobj_with_full_size(
- transfer_future.meta.call_args.fileobj,
- start_byte=start_byte,
- part_size=chunksize,
- full_file_size=full_file_size,
- )
+ transfer_future.meta.call_args.fileobj,
+ start_byte=start_byte,
+ part_size=chunksize,
+ full_file_size=full_file_size,
+ )
# Wrap fileobj with interrupt reader that will quickly cancel
# uploads if needed instead of having to wait for the socket
@@ -294,18 +294,18 @@ class UploadFilenameInputManager(UploadInputManager):
# Wrap the file-like object into a ReadFileChunk to get progress.
read_file_chunk = self._osutil.open_file_chunk_reader_from_fileobj(
- fileobj=fileobj,
- chunk_size=chunksize,
- full_file_size=full_size,
- callbacks=callbacks,
- close_callbacks=close_callbacks,
- )
+ fileobj=fileobj,
+ chunk_size=chunksize,
+ full_file_size=full_size,
+ callbacks=callbacks,
+ close_callbacks=close_callbacks,
+ )
yield part_number, read_file_chunk
def _get_deferred_open_file(self, fileobj, start_byte):
fileobj = DeferredOpenFile(
- fileobj, start_byte, open_function=self._osutil.open
- )
+ fileobj, start_byte, open_function=self._osutil.open
+ )
return fileobj
def _get_put_object_fileobj_with_full_size(self, transfer_future):
@@ -319,12 +319,12 @@ class UploadFilenameInputManager(UploadInputManager):
return self._get_deferred_open_file(fileobj, start_byte), full_size
def _get_num_parts(self, transfer_future, part_size):
- return int(math.ceil(transfer_future.meta.size / float(part_size)))
+ return int(math.ceil(transfer_future.meta.size / float(part_size)))
class UploadSeekableInputManager(UploadFilenameInputManager):
"""Upload utility for an open file object"""
-
+
@classmethod
def is_compatible(cls, upload_source):
return readable(upload_source) and seekable(upload_source)
@@ -345,8 +345,8 @@ class UploadSeekableInputManager(UploadFilenameInputManager):
end_position = fileobj.tell()
fileobj.seek(start_position)
transfer_future.meta.provide_transfer_size(
- end_position - start_position
- )
+ end_position - start_position
+ )
def _get_upload_part_fileobj_with_full_size(self, fileobj, **kwargs):
# Note: It is unfortunate that in order to do a multithreaded
@@ -354,14 +354,14 @@ class UploadSeekableInputManager(UploadFilenameInputManager):
# since there is not really a mechanism in python (i.e. os.dup
# points to the same OS filehandle which causes concurrency
# issues). So instead we need to read from the fileobj and
- # chunk the data out to separate file-like objects in memory.
+ # chunk the data out to separate file-like objects in memory.
data = fileobj.read(kwargs['part_size'])
# We return the length of the data instead of the full_file_size
- # because we partitioned the data into separate BytesIO objects
+ # because we partitioned the data into separate BytesIO objects
# meaning the BytesIO object has no knowledge of its start position
# relative the input source nor access to the rest of the input
# source. So we must treat it as its own standalone file.
- return BytesIO(data), len(data)
+ return BytesIO(data), len(data)
def _get_put_object_fileobj_with_full_size(self, transfer_future):
fileobj = transfer_future.meta.call_args.fileobj
@@ -373,9 +373,9 @@ class UploadSeekableInputManager(UploadFilenameInputManager):
class UploadNonSeekableInputManager(UploadInputManager):
"""Upload utility for a file-like object that cannot seek."""
-
+
def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None):
- super().__init__(osutil, transfer_coordinator, bandwidth_limiter)
+ super().__init__(osutil, transfer_coordinator, bandwidth_limiter)
self._initial_data = b''
@classmethod
@@ -413,8 +413,8 @@ class UploadNonSeekableInputManager(UploadInputManager):
fileobj = transfer_future.meta.call_args.fileobj
body = self._wrap_data(
- self._initial_data + fileobj.read(), callbacks, close_callbacks
- )
+ self._initial_data + fileobj.read(), callbacks, close_callbacks
+ )
# Zero out the stored data so we don't have additional copies
# hanging around in memory.
@@ -434,8 +434,8 @@ class UploadNonSeekableInputManager(UploadInputManager):
if not part_content:
break
part_object = self._wrap_data(
- part_content, callbacks, close_callbacks
- )
+ part_content, callbacks, close_callbacks
+ )
# Zero out part_content to avoid hanging on to additional data.
part_content = None
@@ -462,7 +462,7 @@ class UploadNonSeekableInputManager(UploadInputManager):
if len(self._initial_data) == 0:
return fileobj.read(amount)
- # If the requested number of bytes is less than the amount of
+ # If the requested number of bytes is less than the amount of
# initial data, pull entirely from initial data.
if amount <= len(self._initial_data):
data = self._initial_data[:amount]
@@ -499,14 +499,14 @@ class UploadNonSeekableInputManager(UploadInputManager):
:return: Fully wrapped data.
"""
- fileobj = self._wrap_fileobj(BytesIO(data))
+ fileobj = self._wrap_fileobj(BytesIO(data))
return self._osutil.open_file_chunk_reader_from_fileobj(
- fileobj=fileobj,
- chunk_size=len(data),
- full_file_size=len(data),
- callbacks=callbacks,
- close_callbacks=close_callbacks,
- )
+ fileobj=fileobj,
+ chunk_size=len(data),
+ full_file_size=len(data),
+ callbacks=callbacks,
+ close_callbacks=close_callbacks,
+ )
class UploadSubmissionTask(SubmissionTask):
@@ -517,13 +517,13 @@ class UploadSubmissionTask(SubmissionTask):
'SSECustomerAlgorithm',
'SSECustomerKeyMD5',
'RequestPayer',
- 'ExpectedBucketOwner',
+ 'ExpectedBucketOwner',
]
- COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']
+ COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']
def _get_upload_input_manager_cls(self, transfer_future):
- """Retrieves a class for managing input for an upload based on file type
+ """Retrieves a class for managing input for an upload based on file type
:type transfer_future: s3transfer.futures.TransferFuture
:param transfer_future: The transfer future for the request
@@ -535,7 +535,7 @@ class UploadSubmissionTask(SubmissionTask):
upload_manager_resolver_chain = [
UploadFilenameInputManager,
UploadSeekableInputManager,
- UploadNonSeekableInputManager,
+ UploadNonSeekableInputManager,
]
fileobj = transfer_future.meta.call_args.fileobj
@@ -543,20 +543,20 @@ class UploadSubmissionTask(SubmissionTask):
if upload_manager_cls.is_compatible(fileobj):
return upload_manager_cls
raise RuntimeError(
- 'Input {} of type: {} is not supported.'.format(
- fileobj, type(fileobj)
- )
- )
-
- def _submit(
- self,
- client,
- config,
- osutil,
- request_executor,
- transfer_future,
- bandwidth_limiter=None,
- ):
+ 'Input {} of type: {} is not supported.'.format(
+ fileobj, type(fileobj)
+ )
+ )
+
+ def _submit(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ bandwidth_limiter=None,
+ ):
"""
:param client: The client associated with the transfer manager
@@ -576,8 +576,8 @@ class UploadSubmissionTask(SubmissionTask):
transfer request that tasks are being submitted for
"""
upload_input_manager = self._get_upload_input_manager_cls(
- transfer_future
- )(osutil, self._transfer_coordinator, bandwidth_limiter)
+ transfer_future
+ )(osutil, self._transfer_coordinator, bandwidth_limiter)
# Determine the size if it was not provided
if transfer_future.meta.size is None:
@@ -585,41 +585,41 @@ class UploadSubmissionTask(SubmissionTask):
# Do a multipart upload if needed, otherwise do a regular put object.
if not upload_input_manager.requires_multipart_upload(
- transfer_future, config
- ):
+ transfer_future, config
+ ):
self._submit_upload_request(
- client,
- config,
- osutil,
- request_executor,
- transfer_future,
- upload_input_manager,
- )
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ upload_input_manager,
+ )
else:
self._submit_multipart_request(
- client,
- config,
- osutil,
- request_executor,
- transfer_future,
- upload_input_manager,
- )
-
- def _submit_upload_request(
- self,
- client,
- config,
- osutil,
- request_executor,
- transfer_future,
- upload_input_manager,
- ):
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ upload_input_manager,
+ )
+
+ def _submit_upload_request(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ upload_input_manager,
+ ):
call_args = transfer_future.meta.call_args
# Get any tags that need to be associated to the put object task
put_object_tag = self._get_upload_task_tag(
- upload_input_manager, 'put_object'
- )
+ upload_input_manager, 'put_object'
+ )
# Submit the request of a single upload.
self._transfer_coordinator.submit(
@@ -629,26 +629,26 @@ class UploadSubmissionTask(SubmissionTask):
main_kwargs={
'client': client,
'fileobj': upload_input_manager.get_put_object_body(
- transfer_future
- ),
+ transfer_future
+ ),
'bucket': call_args.bucket,
'key': call_args.key,
- 'extra_args': call_args.extra_args,
+ 'extra_args': call_args.extra_args,
},
- is_final=True,
+ is_final=True,
),
- tag=put_object_tag,
+ tag=put_object_tag,
)
- def _submit_multipart_request(
- self,
- client,
- config,
- osutil,
- request_executor,
- transfer_future,
- upload_input_manager,
- ):
+ def _submit_multipart_request(
+ self,
+ client,
+ config,
+ osutil,
+ request_executor,
+ transfer_future,
+ upload_input_manager,
+ ):
call_args = transfer_future.meta.call_args
# Submit the request to create a multipart upload.
@@ -661,8 +661,8 @@ class UploadSubmissionTask(SubmissionTask):
'bucket': call_args.bucket,
'key': call_args.key,
'extra_args': call_args.extra_args,
- },
- ),
+ },
+ ),
)
# Submit requests to upload the parts of the file.
@@ -672,15 +672,15 @@ class UploadSubmissionTask(SubmissionTask):
# Get any tags that need to be associated to the submitted task
# for upload the data
upload_part_tag = self._get_upload_task_tag(
- upload_input_manager, 'upload_part'
- )
+ upload_input_manager, 'upload_part'
+ )
size = transfer_future.meta.size
adjuster = ChunksizeAdjuster()
chunksize = adjuster.adjust_chunksize(config.multipart_chunksize, size)
part_iterator = upload_input_manager.yield_upload_part_bodies(
- transfer_future, chunksize
- )
+ transfer_future, chunksize
+ )
for part_number, fileobj in part_iterator:
part_futures.append(
@@ -694,19 +694,19 @@ class UploadSubmissionTask(SubmissionTask):
'bucket': call_args.bucket,
'key': call_args.key,
'part_number': part_number,
- 'extra_args': extra_part_args,
+ 'extra_args': extra_part_args,
},
pending_main_kwargs={
'upload_id': create_multipart_future
- },
+ },
),
- tag=upload_part_tag,
+ tag=upload_part_tag,
)
)
complete_multipart_extra_args = self._extra_complete_multipart_args(
- call_args.extra_args
- )
+ call_args.extra_args
+ )
# Submit the request to complete the multipart upload.
self._transfer_coordinator.submit(
request_executor,
@@ -720,10 +720,10 @@ class UploadSubmissionTask(SubmissionTask):
},
pending_main_kwargs={
'upload_id': create_multipart_future,
- 'parts': part_futures,
+ 'parts': part_futures,
},
- is_final=True,
- ),
+ is_final=True,
+ ),
)
def _extra_upload_part_args(self, extra_args):
@@ -743,7 +743,7 @@ class UploadSubmissionTask(SubmissionTask):
class PutObjectTask(Task):
"""Task to do a nonmultipart upload"""
-
+
def _main(self, client, fileobj, bucket, key, extra_args):
"""
:param client: The client to use when calling PutObject
@@ -759,10 +759,10 @@ class PutObjectTask(Task):
class UploadPartTask(Task):
"""Task to upload a part in a multipart upload"""
-
- def _main(
- self, client, fileobj, bucket, key, upload_id, part_number, extra_args
- ):
+
+ def _main(
+ self, client, fileobj, bucket, key, upload_id, part_number, extra_args
+ ):
"""
:param client: The client to use when calling PutObject
:param fileobj: The file to upload.
@@ -784,12 +784,12 @@ class UploadPartTask(Task):
"""
with fileobj as body:
response = client.upload_part(
- Bucket=bucket,
- Key=key,
- UploadId=upload_id,
- PartNumber=part_number,
- Body=body,
- **extra_args
- )
+ Bucket=bucket,
+ Key=key,
+ UploadId=upload_id,
+ PartNumber=part_number,
+ Body=body,
+ **extra_args
+ )
etag = response['ETag']
return {'ETag': etag, 'PartNumber': part_number}
diff --git a/contrib/python/s3transfer/py3/s3transfer/utils.py b/contrib/python/s3transfer/py3/s3transfer/utils.py
index 31ada34c65..ba881c67dd 100644
--- a/contrib/python/s3transfer/py3/s3transfer/utils.py
+++ b/contrib/python/s3transfer/py3/s3transfer/utils.py
@@ -11,19 +11,19 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import functools
-import logging
+import logging
import math
import os
-import random
-import socket
+import random
+import socket
import stat
import string
import threading
from collections import defaultdict
-from botocore.exceptions import IncompleteReadError, ReadTimeoutError
-
-from s3transfer.compat import SOCKET_ERROR, fallocate, rename_file
+from botocore.exceptions import IncompleteReadError, ReadTimeoutError
+
+from s3transfer.compat import SOCKET_ERROR, fallocate, rename_file
MAX_PARTS = 10000
# The maximum file size you can upload via S3 per request.
@@ -34,39 +34,39 @@ MIN_UPLOAD_CHUNKSIZE = 5 * (1024 ** 2)
logger = logging.getLogger(__name__)
-S3_RETRYABLE_DOWNLOAD_ERRORS = (
- socket.timeout,
- SOCKET_ERROR,
- ReadTimeoutError,
- IncompleteReadError,
-)
-
-
+S3_RETRYABLE_DOWNLOAD_ERRORS = (
+ socket.timeout,
+ SOCKET_ERROR,
+ ReadTimeoutError,
+ IncompleteReadError,
+)
+
+
def random_file_extension(num_digits=8):
return ''.join(random.choice(string.hexdigits) for _ in range(num_digits))
def signal_not_transferring(request, operation_name, **kwargs):
- if operation_name in ['PutObject', 'UploadPart'] and hasattr(
- request.body, 'signal_not_transferring'
- ):
+ if operation_name in ['PutObject', 'UploadPart'] and hasattr(
+ request.body, 'signal_not_transferring'
+ ):
request.body.signal_not_transferring()
def signal_transferring(request, operation_name, **kwargs):
- if operation_name in ['PutObject', 'UploadPart'] and hasattr(
- request.body, 'signal_transferring'
- ):
+ if operation_name in ['PutObject', 'UploadPart'] and hasattr(
+ request.body, 'signal_transferring'
+ ):
request.body.signal_transferring()
-def calculate_num_parts(size, part_size):
- return int(math.ceil(size / float(part_size)))
-
-
-def calculate_range_parameter(
- part_size, part_index, num_parts, total_size=None
-):
+def calculate_num_parts(size, part_size):
+ return int(math.ceil(size / float(part_size)))
+
+
+def calculate_range_parameter(
+ part_size, part_index, num_parts, total_size=None
+):
"""Calculate the range parameter for multipart downloads/copies
:type part_size: int
@@ -90,7 +90,7 @@ def calculate_range_parameter(
end_range = str(total_size - 1)
else:
end_range = start_range + part_size - 1
- range_param = f'bytes={start_range}-{end_range}'
+ range_param = f'bytes={start_range}-{end_range}'
return range_param
@@ -117,7 +117,7 @@ def get_callbacks(transfer_future, callback_type):
if hasattr(subscriber, callback_name):
callbacks.append(
functools.partial(
- getattr(subscriber, callback_name), future=transfer_future
+ getattr(subscriber, callback_name), future=transfer_future
)
)
return callbacks
@@ -157,7 +157,7 @@ def get_filtered_dict(original_dict, whitelisted_keys):
return filtered_dict
-class CallArgs:
+class CallArgs:
def __init__(self, **kwargs):
"""A class that records call arguments
@@ -169,33 +169,33 @@ class CallArgs:
setattr(self, arg, value)
-class FunctionContainer:
+class FunctionContainer:
"""An object that contains a function and any args or kwargs to call it
When called the provided function will be called with provided args
and kwargs.
"""
-
+
def __init__(self, func, *args, **kwargs):
self._func = func
self._args = args
self._kwargs = kwargs
def __repr__(self):
- return 'Function: {} with args {} and kwargs {}'.format(
- self._func, self._args, self._kwargs
- )
+ return 'Function: {} with args {} and kwargs {}'.format(
+ self._func, self._args, self._kwargs
+ )
def __call__(self):
return self._func(*self._args, **self._kwargs)
-class CountCallbackInvoker:
+class CountCallbackInvoker:
"""An abstraction to invoke a callback when a shared count reaches zero
:param callback: Callback invoke when finalized count reaches zero
"""
-
+
def __init__(self, callback):
self._lock = threading.Lock()
self._callback = callback
@@ -222,8 +222,8 @@ class CountCallbackInvoker:
with self._lock:
if self._count == 0:
raise RuntimeError(
- 'Counter is at zero. It cannot dip below zero'
- )
+ 'Counter is at zero. It cannot dip below zero'
+ )
self._count -= 1
if self._is_finalized and self._count == 0:
self._callback()
@@ -240,33 +240,33 @@ class CountCallbackInvoker:
self._callback()
-class OSUtils:
- _MAX_FILENAME_LEN = 255
-
+class OSUtils:
+ _MAX_FILENAME_LEN = 255
+
def get_file_size(self, filename):
return os.path.getsize(filename)
def open_file_chunk_reader(self, filename, start_byte, size, callbacks):
- return ReadFileChunk.from_filename(
- filename, start_byte, size, callbacks, enable_callbacks=False
- )
-
- def open_file_chunk_reader_from_fileobj(
- self,
- fileobj,
- chunk_size,
- full_file_size,
- callbacks,
- close_callbacks=None,
- ):
+ return ReadFileChunk.from_filename(
+ filename, start_byte, size, callbacks, enable_callbacks=False
+ )
+
+ def open_file_chunk_reader_from_fileobj(
+ self,
+ fileobj,
+ chunk_size,
+ full_file_size,
+ callbacks,
+ close_callbacks=None,
+ ):
return ReadFileChunk(
- fileobj,
- chunk_size,
- full_file_size,
- callbacks=callbacks,
- enable_callbacks=False,
- close_callbacks=close_callbacks,
- )
+ fileobj,
+ chunk_size,
+ full_file_size,
+ callbacks=callbacks,
+ enable_callbacks=False,
+ close_callbacks=close_callbacks,
+ )
def open(self, filename, mode):
return open(filename, mode)
@@ -312,23 +312,23 @@ class OSUtils:
return True
return False
- def get_temp_filename(self, filename):
- suffix = os.extsep + random_file_extension()
- path = os.path.dirname(filename)
- name = os.path.basename(filename)
- temp_filename = name[: self._MAX_FILENAME_LEN - len(suffix)] + suffix
- return os.path.join(path, temp_filename)
-
- def allocate(self, filename, size):
- try:
- with self.open(filename, 'wb') as f:
- fallocate(f, size)
- except OSError:
- self.remove_file(filename)
- raise
-
-
-class DeferredOpenFile:
+ def get_temp_filename(self, filename):
+ suffix = os.extsep + random_file_extension()
+ path = os.path.dirname(filename)
+ name = os.path.basename(filename)
+ temp_filename = name[: self._MAX_FILENAME_LEN - len(suffix)] + suffix
+ return os.path.join(path, temp_filename)
+
+ def allocate(self, filename, size):
+ try:
+ with self.open(filename, 'wb') as f:
+ fallocate(f, size)
+ except OSError:
+ self.remove_file(filename)
+ raise
+
+
+class DeferredOpenFile:
def __init__(self, filename, start_byte=0, mode='rb', open_function=open):
"""A class that defers the opening of a file till needed
@@ -374,9 +374,9 @@ class DeferredOpenFile:
self._open_if_needed()
self._fileobj.write(data)
- def seek(self, where, whence=0):
+ def seek(self, where, whence=0):
self._open_if_needed()
- self._fileobj.seek(where, whence)
+ self._fileobj.seek(where, whence)
def tell(self):
if self._fileobj is None:
@@ -395,16 +395,16 @@ class DeferredOpenFile:
self.close()
-class ReadFileChunk:
- def __init__(
- self,
- fileobj,
- chunk_size,
- full_file_size,
- callbacks=None,
- enable_callbacks=True,
- close_callbacks=None,
- ):
+class ReadFileChunk:
+ def __init__(
+ self,
+ fileobj,
+ chunk_size,
+ full_file_size,
+ callbacks=None,
+ enable_callbacks=True,
+ close_callbacks=None,
+ ):
"""
Given a file object shown below::
@@ -441,13 +441,13 @@ class ReadFileChunk:
self._fileobj = fileobj
self._start_byte = self._fileobj.tell()
self._size = self._calculate_file_size(
- self._fileobj,
- requested_size=chunk_size,
- start_byte=self._start_byte,
- actual_file_size=full_file_size,
- )
- # _amount_read represents the position in the chunk and may exceed
- # the chunk size, but won't allow reads out of bounds.
+ self._fileobj,
+ requested_size=chunk_size,
+ start_byte=self._start_byte,
+ actual_file_size=full_file_size,
+ )
+ # _amount_read represents the position in the chunk and may exceed
+ # the chunk size, but won't allow reads out of bounds.
self._amount_read = 0
self._callbacks = callbacks
if callbacks is None:
@@ -458,14 +458,14 @@ class ReadFileChunk:
self._close_callbacks = close_callbacks
@classmethod
- def from_filename(
- cls,
- filename,
- start_byte,
- chunk_size,
- callbacks=None,
- enable_callbacks=True,
- ):
+ def from_filename(
+ cls,
+ filename,
+ start_byte,
+ chunk_size,
+ callbacks=None,
+ enable_callbacks=True,
+ ):
"""Convenience factory function to create from a filename.
:type start_byte: int
@@ -496,18 +496,18 @@ class ReadFileChunk:
file_size = os.fstat(f.fileno()).st_size
return cls(f, chunk_size, file_size, callbacks, enable_callbacks)
- def _calculate_file_size(
- self, fileobj, requested_size, start_byte, actual_file_size
- ):
+ def _calculate_file_size(
+ self, fileobj, requested_size, start_byte, actual_file_size
+ ):
max_chunk_size = actual_file_size - start_byte
return min(max_chunk_size, requested_size)
def read(self, amount=None):
- amount_left = max(self._size - self._amount_read, 0)
+ amount_left = max(self._size - self._amount_read, 0)
if amount is None:
- amount_to_read = amount_left
+ amount_to_read = amount_left
else:
- amount_to_read = min(amount_left, amount)
+ amount_to_read = min(amount_left, amount)
data = self._fileobj.read(amount_to_read)
self._amount_read += len(data)
if self._callbacks is not None and self._callbacks_enabled:
@@ -530,29 +530,29 @@ class ReadFileChunk:
def disable_callback(self):
self._callbacks_enabled = False
- def seek(self, where, whence=0):
- if whence not in (0, 1, 2):
- # Mimic io's error for invalid whence values
- raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)")
-
- # Recalculate where based on chunk attributes so seek from file
- # start (whence=0) is always used
- where += self._start_byte
- if whence == 1:
- where += self._amount_read
- elif whence == 2:
- where += self._size
-
- self._fileobj.seek(max(where, self._start_byte))
+ def seek(self, where, whence=0):
+ if whence not in (0, 1, 2):
+ # Mimic io's error for invalid whence values
+ raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)")
+
+ # Recalculate where based on chunk attributes so seek from file
+ # start (whence=0) is always used
+ where += self._start_byte
+ if whence == 1:
+ where += self._amount_read
+ elif whence == 2:
+ where += self._size
+
+ self._fileobj.seek(max(where, self._start_byte))
if self._callbacks is not None and self._callbacks_enabled:
# To also rewind the callback() for an accurate progress report
- bounded_where = max(min(where - self._start_byte, self._size), 0)
- bounded_amount_read = min(self._amount_read, self._size)
- amount = bounded_where - bounded_amount_read
+ bounded_where = max(min(where - self._start_byte, self._size), 0)
+ bounded_amount_read = min(self._amount_read, self._size)
+ amount = bounded_where - bounded_amount_read
invoke_progress_callbacks(
- self._callbacks, bytes_transferred=amount
- )
- self._amount_read = max(where - self._start_byte, 0)
+ self._callbacks, bytes_transferred=amount
+ )
+ self._amount_read = max(where - self._start_byte, 0)
def close(self):
if self._close_callbacks is not None and self._callbacks_enabled:
@@ -586,9 +586,9 @@ class ReadFileChunk:
return iter([])
-class StreamReaderProgress:
+class StreamReaderProgress:
"""Wrapper for a read only stream that adds progress callbacks."""
-
+
def __init__(self, stream, callbacks=None):
self._stream = stream
self._callbacks = callbacks
@@ -605,7 +605,7 @@ class NoResourcesAvailable(Exception):
pass
-class TaskSemaphore:
+class TaskSemaphore:
def __init__(self, count):
"""A semaphore for the purpose of limiting the number of tasks
@@ -621,7 +621,7 @@ class TaskSemaphore:
needed for API compatibility with the SlidingWindowSemaphore
implementation.
:param block: If True, block until it can be acquired. If False,
- do not block and raise an exception if cannot be acquired.
+ do not block and raise an exception if cannot be acquired.
:returns: A token (can be None) to use when releasing the semaphore
"""
@@ -638,7 +638,7 @@ class TaskSemaphore:
class but is needed for API compatibility with the
SlidingWindowSemaphore implementation.
"""
- logger.debug(f"Releasing acquire {tag}/{acquire_token}")
+ logger.debug(f"Releasing acquire {tag}/{acquire_token}")
self._semaphore.release()
@@ -664,7 +664,7 @@ class SlidingWindowSemaphore(TaskSemaphore):
when the minimum sequence number for a tag is released.
"""
-
+
def __init__(self, count):
self._count = count
# Dict[tag, next_sequence_number].
@@ -727,26 +727,26 @@ class SlidingWindowSemaphore(TaskSemaphore):
# We can't do anything right now because we're still waiting
# for the min sequence for the tag to be released. We have
# to queue this for pending release.
- self._pending_release.setdefault(tag, []).append(
- sequence_number
- )
+ self._pending_release.setdefault(tag, []).append(
+ sequence_number
+ )
self._pending_release[tag].sort(reverse=True)
else:
raise ValueError(
"Attempted to release unknown sequence number "
- "%s for tag: %s" % (sequence_number, tag)
- )
+ "%s for tag: %s" % (sequence_number, tag)
+ )
finally:
self._condition.release()
-class ChunksizeAdjuster:
- def __init__(
- self,
- max_size=MAX_SINGLE_UPLOAD_SIZE,
- min_size=MIN_UPLOAD_CHUNKSIZE,
- max_parts=MAX_PARTS,
- ):
+class ChunksizeAdjuster:
+ def __init__(
+ self,
+ max_size=MAX_SINGLE_UPLOAD_SIZE,
+ min_size=MIN_UPLOAD_CHUNKSIZE,
+ max_parts=MAX_PARTS,
+ ):
self.max_size = max_size
self.min_size = min_size
self.max_parts = max_parts
@@ -772,14 +772,14 @@ class ChunksizeAdjuster:
if current_chunksize > self.max_size:
logger.debug(
"Chunksize greater than maximum chunksize. "
- "Setting to %s from %s." % (self.max_size, current_chunksize)
- )
+ "Setting to %s from %s." % (self.max_size, current_chunksize)
+ )
return self.max_size
elif current_chunksize < self.min_size:
logger.debug(
"Chunksize less than minimum chunksize. "
- "Setting to %s from %s." % (self.min_size, current_chunksize)
- )
+ "Setting to %s from %s." % (self.min_size, current_chunksize)
+ )
return self.min_size
else:
return current_chunksize
@@ -795,8 +795,8 @@ class ChunksizeAdjuster:
if chunksize != current_chunksize:
logger.debug(
"Chunksize would result in the number of parts exceeding the "
- "maximum. Setting to %s from %s."
- % (chunksize, current_chunksize)
- )
+ "maximum. Setting to %s from %s."
+ % (chunksize, current_chunksize)
+ )
return chunksize
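The ChunksizeAdjuster logic above clamps a proposed chunksize into the valid range and grows it if the resulting part count would exceed MAX_PARTS. A hedged usage sketch (assuming the public adjust_chunksize entry point; the sizes are illustrative)::

    from s3transfer.utils import ChunksizeAdjuster

    adjuster = ChunksizeAdjuster()
    # 1 MiB is below the minimum part size, and 50 GiB at 1 MiB per part
    # would also exceed the part limit, so the result is adjusted upward.
    chunksize = adjuster.adjust_chunksize(1024 * 1024, file_size=50 * 1024**3)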
diff --git a/contrib/python/s3transfer/py3/tests/__init__.py b/contrib/python/s3transfer/py3/tests/__init__.py
index 229371009a..e36c4936bf 100644
--- a/contrib/python/s3transfer/py3/tests/__init__.py
+++ b/contrib/python/s3transfer/py3/tests/__init__.py
@@ -1,531 +1,531 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License'). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the 'license' file accompanying this file. This file is
-# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import hashlib
-import io
-import math
-import os
-import platform
-import shutil
-import string
-import tempfile
-import unittest
-from unittest import mock # noqa: F401
-
-import botocore.session
-from botocore.stub import Stubber
-
-from s3transfer.futures import (
- IN_MEMORY_DOWNLOAD_TAG,
- IN_MEMORY_UPLOAD_TAG,
- BoundedExecutor,
- NonThreadedExecutor,
- TransferCoordinator,
- TransferFuture,
- TransferMeta,
-)
-from s3transfer.manager import TransferConfig
-from s3transfer.subscribers import BaseSubscriber
-from s3transfer.utils import (
- CallArgs,
- OSUtils,
- SlidingWindowSemaphore,
- TaskSemaphore,
-)
-
-ORIGINAL_EXECUTOR_CLS = BoundedExecutor.EXECUTOR_CLS
-# Detect if CRT is available for use
-try:
- import awscrt.s3 # noqa: F401
-
- HAS_CRT = True
-except ImportError:
- HAS_CRT = False
-
-
-def setup_package():
- if is_serial_implementation():
- BoundedExecutor.EXECUTOR_CLS = NonThreadedExecutor
-
-
-def teardown_package():
- BoundedExecutor.EXECUTOR_CLS = ORIGINAL_EXECUTOR_CLS
-
-
-def is_serial_implementation():
- return os.environ.get('USE_SERIAL_EXECUTOR', False)
-
-
-def assert_files_equal(first, second):
- if os.path.getsize(first) != os.path.getsize(second):
- raise AssertionError(f"Files are not equal: {first}, {second}")
- first_md5 = md5_checksum(first)
- second_md5 = md5_checksum(second)
- if first_md5 != second_md5:
- raise AssertionError(
- "Files are not equal: {}(md5={}) != {}(md5={})".format(
- first, first_md5, second, second_md5
- )
- )
-
-
-def md5_checksum(filename):
- checksum = hashlib.md5()
- with open(filename, 'rb') as f:
- for chunk in iter(lambda: f.read(8192), b''):
- checksum.update(chunk)
- return checksum.hexdigest()
-
-
-def random_bucket_name(prefix='s3transfer', num_chars=10):
- base = string.ascii_lowercase + string.digits
- random_bytes = bytearray(os.urandom(num_chars))
- return prefix + ''.join([base[b % len(base)] for b in random_bytes])
-
-
-def skip_if_windows(reason):
- """Decorator to skip tests that should not be run on windows.
-
- Example usage:
-
- @skip_if_windows("Not valid")
- def test_some_non_windows_stuff(self):
- self.assertEqual(...)
-
- """
-
- def decorator(func):
- return unittest.skipIf(
- platform.system() not in ['Darwin', 'Linux'], reason
- )(func)
-
- return decorator
-
-
-def skip_if_using_serial_implementation(reason):
- """Decorator to skip tests when running as the serial implementation"""
-
- def decorator(func):
- return unittest.skipIf(is_serial_implementation(), reason)(func)
-
- return decorator
-
-
-def requires_crt(cls, reason=None):
- if reason is None:
- reason = "Test requires awscrt to be installed."
- return unittest.skipIf(not HAS_CRT, reason)(cls)
-
-
-class StreamWithError:
- """A wrapper to simulate errors while reading from a stream
-
- :param stream: The underlying stream to read from
- :param exception_type: The exception type to throw
- :param num_reads: The number of times to allow a read before raising
- the exception. A value of zero indicates to raise the error on the
- first read.
- """
-
- def __init__(self, stream, exception_type, num_reads=0):
- self._stream = stream
- self._exception_type = exception_type
- self._num_reads = num_reads
- self._count = 0
-
- def read(self, n=-1):
- if self._count == self._num_reads:
- raise self._exception_type
- self._count += 1
- return self._stream.read(n)
-
-
-class FileSizeProvider:
- def __init__(self, file_size):
- self.file_size = file_size
-
- def on_queued(self, future, **kwargs):
- future.meta.provide_transfer_size(self.file_size)
-
-
-class FileCreator:
- def __init__(self):
- self.rootdir = tempfile.mkdtemp()
-
- def remove_all(self):
- shutil.rmtree(self.rootdir)
-
- def create_file(self, filename, contents, mode='w'):
- """Creates a file in a tmpdir
- ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
- It will be translated into a full path in a tmp dir.
- ``mode`` is the mode the file should be opened with, either ``w`` or
- ``wb``.
- Returns the full path to the file.
- """
- full_path = os.path.join(self.rootdir, filename)
- if not os.path.isdir(os.path.dirname(full_path)):
- os.makedirs(os.path.dirname(full_path))
- with open(full_path, mode) as f:
- f.write(contents)
- return full_path
-
- def create_file_with_size(self, filename, filesize):
- filename = self.create_file(filename, contents='')
- chunksize = 8192
- with open(filename, 'wb') as f:
- for i in range(int(math.ceil(filesize / float(chunksize)))):
- f.write(b'a' * chunksize)
- return filename
-
- def append_file(self, filename, contents):
- """Append contents to a file
- ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
- It will be translated into a full path in a tmp dir.
- Returns the full path to the file.
- """
- full_path = os.path.join(self.rootdir, filename)
- if not os.path.isdir(os.path.dirname(full_path)):
- os.makedirs(os.path.dirname(full_path))
- with open(full_path, 'a') as f:
- f.write(contents)
- return full_path
-
- def full_path(self, filename):
- """Translate relative path to full path in temp dir.
- f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt
- """
- return os.path.join(self.rootdir, filename)
-
-
-class RecordingOSUtils(OSUtils):
- """An OSUtil abstraction that records openings and renamings"""
-
- def __init__(self):
- super().__init__()
- self.open_records = []
- self.rename_records = []
-
- def open(self, filename, mode):
- self.open_records.append((filename, mode))
- return super().open(filename, mode)
-
- def rename_file(self, current_filename, new_filename):
- self.rename_records.append((current_filename, new_filename))
- super().rename_file(current_filename, new_filename)
-
-
-class RecordingSubscriber(BaseSubscriber):
- def __init__(self):
- self.on_queued_calls = []
- self.on_progress_calls = []
- self.on_done_calls = []
-
- def on_queued(self, **kwargs):
- self.on_queued_calls.append(kwargs)
-
- def on_progress(self, **kwargs):
- self.on_progress_calls.append(kwargs)
-
- def on_done(self, **kwargs):
- self.on_done_calls.append(kwargs)
-
- def calculate_bytes_seen(self, **kwargs):
- amount_seen = 0
- for call in self.on_progress_calls:
- amount_seen += call['bytes_transferred']
- return amount_seen
-
-
-class TransferCoordinatorWithInterrupt(TransferCoordinator):
- """Used to inject keyboard interrupts"""
-
- def result(self):
- raise KeyboardInterrupt()
-
-
-class RecordingExecutor:
- """A wrapper on an executor to record calls made to submit()
-
- You can access the submissions property to receive a list of dictionaries
- that represent all submissions, where each dictionary is formatted::
-
- {
- 'task': the submitted task
- 'tag': the task's tag (may be None)
- 'block': whether the submit call blocked
- }
- """
-
- def __init__(self, executor):
- self._executor = executor
- self.submissions = []
-
- def submit(self, task, tag=None, block=True):
- future = self._executor.submit(task, tag, block)
- self.submissions.append({'task': task, 'tag': tag, 'block': block})
- return future
-
- def shutdown(self):
- self._executor.shutdown()
-
-
-class StubbedClientTest(unittest.TestCase):
- def setUp(self):
- self.session = botocore.session.get_session()
- self.region = 'us-west-2'
- self.client = self.session.create_client(
- 's3',
- self.region,
- aws_access_key_id='foo',
- aws_secret_access_key='bar',
- )
- self.stubber = Stubber(self.client)
- self.stubber.activate()
-
- def tearDown(self):
- self.stubber.deactivate()
-
- def reset_stubber_with_new_client(self, override_client_kwargs):
- client_kwargs = {
- 'service_name': 's3',
- 'region_name': self.region,
- 'aws_access_key_id': 'foo',
- 'aws_secret_access_key': 'bar',
- }
- client_kwargs.update(override_client_kwargs)
- self.client = self.session.create_client(**client_kwargs)
- self.stubber = Stubber(self.client)
- self.stubber.activate()
-
-
-class BaseTaskTest(StubbedClientTest):
- def setUp(self):
- super().setUp()
- self.transfer_coordinator = TransferCoordinator()
-
- def get_task(self, task_cls, **kwargs):
- if 'transfer_coordinator' not in kwargs:
- kwargs['transfer_coordinator'] = self.transfer_coordinator
- return task_cls(**kwargs)
-
- def get_transfer_future(self, call_args=None):
- return TransferFuture(
- meta=TransferMeta(call_args), coordinator=self.transfer_coordinator
- )
-
-
-class BaseSubmissionTaskTest(BaseTaskTest):
- def setUp(self):
- super().setUp()
- self.config = TransferConfig()
- self.osutil = OSUtils()
- self.executor = BoundedExecutor(
- 1000,
- 1,
- {
- IN_MEMORY_UPLOAD_TAG: TaskSemaphore(10),
- IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore(10),
- },
- )
-
- def tearDown(self):
- super().tearDown()
- self.executor.shutdown()
-
-
-class BaseGeneralInterfaceTest(StubbedClientTest):
- """A general test class to ensure consistency across TransferManger methods
-
- This test should never be called and should be subclassed from to pick up
- the various tests that all TransferManager method must pass from a
- functionality standpoint.
- """
-
- __test__ = False
-
- def manager(self):
- """The transfer manager to use"""
- raise NotImplementedError('method is not implemented')
-
- @property
- def method(self):
- """The transfer manager method to invoke i.e. upload()"""
- raise NotImplementedError('method is not implemented')
-
- def create_call_kwargs(self):
- """The kwargs to be passed to the transfer manager method"""
- raise NotImplementedError('create_call_kwargs is not implemented')
-
- def create_invalid_extra_args(self):
- """A value for extra_args that will cause validation errors"""
- raise NotImplementedError(
- 'create_invalid_extra_args is not implemented'
- )
-
- def create_stubbed_responses(self):
- """A list of stubbed responses that will cause the request to succeed
-
- Each element of this list is a dictionary that will be used as keyword
- arguments to botocore.Stubber.add_response(). For example::
-
- [{'method': 'put_object', 'service_response': {}}]
- """
- raise NotImplementedError(
- 'create_stubbed_responses is not implemented'
- )
-
- def create_expected_progress_callback_info(self):
- """A list of kwargs expected to be passed to each progress callback
-
- Note that the future kwarg does not need to be added to each
- dictionary provided in the list. This is injected for you. An example
- is::
-
- [
- {'bytes_transferred': 4},
- {'bytes_transferred': 4},
- {'bytes_transferred': 2}
- ]
-
- This indicates that the progress callback will be called three
- times and pass along the specified keyword arguments and corresponding
- values.
- """
- raise NotImplementedError(
- 'create_expected_progress_callback_info is not implemented'
- )
-
- def _setup_default_stubbed_responses(self):
- for stubbed_response in self.create_stubbed_responses():
- self.stubber.add_response(**stubbed_response)
-
- def test_returns_future_with_meta(self):
- self._setup_default_stubbed_responses()
- future = self.method(**self.create_call_kwargs())
- # The result is called so we ensure that the entire process executes
- # before we try to clean up resources in the tearDown.
- future.result()
-
- # Assert the return value is a future with metadata associated to it.
- self.assertIsInstance(future, TransferFuture)
- self.assertIsInstance(future.meta, TransferMeta)
-
- def test_returns_correct_call_args(self):
- self._setup_default_stubbed_responses()
- call_kwargs = self.create_call_kwargs()
- future = self.method(**call_kwargs)
- # The result is called so we ensure that the entire process executes
- # before we try to clean up resources in the tearDown.
- future.result()
-
- # Assert that there are call args associated to the metadata
- self.assertIsInstance(future.meta.call_args, CallArgs)
- # Assert that all of the arguments passed to the method exist and
- # are of the correct value in call_args.
- for param, value in call_kwargs.items():
- self.assertEqual(value, getattr(future.meta.call_args, param))
-
- def test_has_transfer_id_associated_to_future(self):
- self._setup_default_stubbed_responses()
- call_kwargs = self.create_call_kwargs()
- future = self.method(**call_kwargs)
- # The result is called so we ensure that the entire process executes
- # before we try to clean up resources in the tearDown.
- future.result()
-
- # Assert that a transfer id was associated with the future.
- # Since only one transfer request has been made for this transfer
- # manager, the id will be zero because it is the first transfer
- # request made for that transfer manager.
- self.assertEqual(future.meta.transfer_id, 0)
-
- # If we make a second request, the transfer id should have incremented
- # by one for that new TransferFuture.
- self._setup_default_stubbed_responses()
- future = self.method(**call_kwargs)
- future.result()
- self.assertEqual(future.meta.transfer_id, 1)
-
- def test_invalid_extra_args(self):
- with self.assertRaisesRegex(ValueError, 'Invalid extra_args'):
- self.method(
- extra_args=self.create_invalid_extra_args(),
- **self.create_call_kwargs(),
- )
-
- def test_for_callback_kwargs_correctness(self):
- # Add the stubbed responses before invoking the method
- self._setup_default_stubbed_responses()
-
- subscriber = RecordingSubscriber()
- future = self.method(
- subscribers=[subscriber], **self.create_call_kwargs()
- )
- # We call shutdown instead of result on future because the future
- # could be finished but the done callback could still be going.
- # The manager's shutdown method ensures everything completes.
- self.manager.shutdown()
-
- # Assert the various subscribers were called with the
- # expected kwargs
- expected_progress_calls = self.create_expected_progress_callback_info()
- for expected_progress_call in expected_progress_calls:
- expected_progress_call['future'] = future
-
- self.assertEqual(subscriber.on_queued_calls, [{'future': future}])
- self.assertEqual(subscriber.on_progress_calls, expected_progress_calls)
- self.assertEqual(subscriber.on_done_calls, [{'future': future}])
-
-
-class NonSeekableReader(io.RawIOBase):
- def __init__(self, b=b''):
- super().__init__()
- self._data = io.BytesIO(b)
-
- def seekable(self):
- return False
-
- def writable(self):
- return False
-
- def readable(self):
- return True
-
- def write(self, b):
- # This is needed because python will not always return the correct
- # kind of error even though writeable returns False.
- raise io.UnsupportedOperation("write")
-
- def read(self, n=-1):
- return self._data.read(n)
-
-
-class NonSeekableWriter(io.RawIOBase):
- def __init__(self, fileobj):
- super().__init__()
- self._fileobj = fileobj
-
- def seekable(self):
- return False
-
- def writable(self):
- return True
-
- def readable(self):
- return False
-
- def write(self, b):
- self._fileobj.write(b)
-
- def read(self, n=-1):
- raise io.UnsupportedOperation("read")
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import hashlib
+import io
+import math
+import os
+import platform
+import shutil
+import string
+import tempfile
+import unittest
+from unittest import mock # noqa: F401
+
+import botocore.session
+from botocore.stub import Stubber
+
+from s3transfer.futures import (
+ IN_MEMORY_DOWNLOAD_TAG,
+ IN_MEMORY_UPLOAD_TAG,
+ BoundedExecutor,
+ NonThreadedExecutor,
+ TransferCoordinator,
+ TransferFuture,
+ TransferMeta,
+)
+from s3transfer.manager import TransferConfig
+from s3transfer.subscribers import BaseSubscriber
+from s3transfer.utils import (
+ CallArgs,
+ OSUtils,
+ SlidingWindowSemaphore,
+ TaskSemaphore,
+)
+
+ORIGINAL_EXECUTOR_CLS = BoundedExecutor.EXECUTOR_CLS
+# Detect if CRT is available for use
+try:
+ import awscrt.s3 # noqa: F401
+
+ HAS_CRT = True
+except ImportError:
+ HAS_CRT = False
+
+
+def setup_package():
+ if is_serial_implementation():
+ BoundedExecutor.EXECUTOR_CLS = NonThreadedExecutor
+
+
+def teardown_package():
+ BoundedExecutor.EXECUTOR_CLS = ORIGINAL_EXECUTOR_CLS
+
+
+def is_serial_implementation():
+ return os.environ.get('USE_SERIAL_EXECUTOR', False)
+
+
+def assert_files_equal(first, second):
+ if os.path.getsize(first) != os.path.getsize(second):
+ raise AssertionError(f"Files are not equal: {first}, {second}")
+ first_md5 = md5_checksum(first)
+ second_md5 = md5_checksum(second)
+ if first_md5 != second_md5:
+ raise AssertionError(
+ "Files are not equal: {}(md5={}) != {}(md5={})".format(
+ first, first_md5, second, second_md5
+ )
+ )
+
+
+def md5_checksum(filename):
+ checksum = hashlib.md5()
+ with open(filename, 'rb') as f:
+ for chunk in iter(lambda: f.read(8192), b''):
+ checksum.update(chunk)
+ return checksum.hexdigest()
+
+
+def random_bucket_name(prefix='s3transfer', num_chars=10):
+ base = string.ascii_lowercase + string.digits
+ random_bytes = bytearray(os.urandom(num_chars))
+ return prefix + ''.join([base[b % len(base)] for b in random_bytes])
+
+
+def skip_if_windows(reason):
+ """Decorator to skip tests that should not be run on windows.
+
+ Example usage:
+
+ @skip_if_windows("Not valid")
+ def test_some_non_windows_stuff(self):
+ self.assertEqual(...)
+
+ """
+
+ def decorator(func):
+ return unittest.skipIf(
+ platform.system() not in ['Darwin', 'Linux'], reason
+ )(func)
+
+ return decorator
+
+
+def skip_if_using_serial_implementation(reason):
+ """Decorator to skip tests when running as the serial implementation"""
+
+ def decorator(func):
+ return unittest.skipIf(is_serial_implementation(), reason)(func)
+
+ return decorator
+
+
+def requires_crt(cls, reason=None):
+ if reason is None:
+ reason = "Test requires awscrt to be installed."
+ return unittest.skipIf(not HAS_CRT, reason)(cls)
+
+
+class StreamWithError:
+ """A wrapper to simulate errors while reading from a stream
+
+ :param stream: The underlying stream to read from
+ :param exception_type: The exception type to throw
+ :param num_reads: The number of times to allow a read before raising
+ the exception. A value of zero indicates to raise the error on the
+ first read.
+ """
+
+ def __init__(self, stream, exception_type, num_reads=0):
+ self._stream = stream
+ self._exception_type = exception_type
+ self._num_reads = num_reads
+ self._count = 0
+
+ def read(self, n=-1):
+ if self._count == self._num_reads:
+ raise self._exception_type
+ self._count += 1
+ return self._stream.read(n)
+
+
+class FileSizeProvider:
+ def __init__(self, file_size):
+ self.file_size = file_size
+
+ def on_queued(self, future, **kwargs):
+ future.meta.provide_transfer_size(self.file_size)
+
+
+class FileCreator:
+ def __init__(self):
+ self.rootdir = tempfile.mkdtemp()
+
+ def remove_all(self):
+ shutil.rmtree(self.rootdir)
+
+ def create_file(self, filename, contents, mode='w'):
+ """Creates a file in a tmpdir
+ ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
+ It will be translated into a full path in a tmp dir.
+ ``mode`` is the mode the file should be opened with, either ``w`` or
+ ``wb``.
+ Returns the full path to the file.
+ """
+ full_path = os.path.join(self.rootdir, filename)
+ if not os.path.isdir(os.path.dirname(full_path)):
+ os.makedirs(os.path.dirname(full_path))
+ with open(full_path, mode) as f:
+ f.write(contents)
+ return full_path
+
+ def create_file_with_size(self, filename, filesize):
+ filename = self.create_file(filename, contents='')
+ chunksize = 8192
+ with open(filename, 'wb') as f:
+ for i in range(int(math.ceil(filesize / float(chunksize)))):
+ f.write(b'a' * chunksize)
+ return filename
+
+ def append_file(self, filename, contents):
+ """Append contents to a file
+ ``filename`` should be a relative path, e.g. "foo/bar/baz.txt"
+ It will be translated into a full path in a tmp dir.
+ Returns the full path to the file.
+ """
+ full_path = os.path.join(self.rootdir, filename)
+ if not os.path.isdir(os.path.dirname(full_path)):
+ os.makedirs(os.path.dirname(full_path))
+ with open(full_path, 'a') as f:
+ f.write(contents)
+ return full_path
+
+ def full_path(self, filename):
+ """Translate relative path to full path in temp dir.
+ f.full_path('foo/bar.txt') -> /tmp/asdfasd/foo/bar.txt
+ """
+ return os.path.join(self.rootdir, filename)
+
+
+class RecordingOSUtils(OSUtils):
+ """An OSUtil abstraction that records openings and renamings"""
+
+ def __init__(self):
+ super().__init__()
+ self.open_records = []
+ self.rename_records = []
+
+ def open(self, filename, mode):
+ self.open_records.append((filename, mode))
+ return super().open(filename, mode)
+
+ def rename_file(self, current_filename, new_filename):
+ self.rename_records.append((current_filename, new_filename))
+ super().rename_file(current_filename, new_filename)
+
+
+class RecordingSubscriber(BaseSubscriber):
+ def __init__(self):
+ self.on_queued_calls = []
+ self.on_progress_calls = []
+ self.on_done_calls = []
+
+ def on_queued(self, **kwargs):
+ self.on_queued_calls.append(kwargs)
+
+ def on_progress(self, **kwargs):
+ self.on_progress_calls.append(kwargs)
+
+ def on_done(self, **kwargs):
+ self.on_done_calls.append(kwargs)
+
+ def calculate_bytes_seen(self, **kwargs):
+ amount_seen = 0
+ for call in self.on_progress_calls:
+ amount_seen += call['bytes_transferred']
+ return amount_seen
+
+
+class TransferCoordinatorWithInterrupt(TransferCoordinator):
+ """Used to inject keyboard interrupts"""
+
+ def result(self):
+ raise KeyboardInterrupt()
+
+
+class RecordingExecutor:
+ """A wrapper on an executor to record calls made to submit()
+
+ You can access the submissions property to receive a list of dictionaries
+ that represent all submissions, where each dictionary is formatted::
+
+ {
+ 'task': the submitted task
+ 'tag': the task's tag (may be None)
+ 'block': whether the submit call blocked
+ }
+ """
+
+ def __init__(self, executor):
+ self._executor = executor
+ self.submissions = []
+
+ def submit(self, task, tag=None, block=True):
+ future = self._executor.submit(task, tag, block)
+ self.submissions.append({'task': task, 'tag': tag, 'block': block})
+ return future
+
+ def shutdown(self):
+ self._executor.shutdown()
+
+
+class StubbedClientTest(unittest.TestCase):
+ def setUp(self):
+ self.session = botocore.session.get_session()
+ self.region = 'us-west-2'
+ self.client = self.session.create_client(
+ 's3',
+ self.region,
+ aws_access_key_id='foo',
+ aws_secret_access_key='bar',
+ )
+ self.stubber = Stubber(self.client)
+ self.stubber.activate()
+
+ def tearDown(self):
+ self.stubber.deactivate()
+
+ def reset_stubber_with_new_client(self, override_client_kwargs):
+ client_kwargs = {
+ 'service_name': 's3',
+ 'region_name': self.region,
+ 'aws_access_key_id': 'foo',
+ 'aws_secret_access_key': 'bar',
+ }
+ client_kwargs.update(override_client_kwargs)
+ self.client = self.session.create_client(**client_kwargs)
+ self.stubber = Stubber(self.client)
+ self.stubber.activate()
+
+
+class BaseTaskTest(StubbedClientTest):
+ def setUp(self):
+ super().setUp()
+ self.transfer_coordinator = TransferCoordinator()
+
+ def get_task(self, task_cls, **kwargs):
+ if 'transfer_coordinator' not in kwargs:
+ kwargs['transfer_coordinator'] = self.transfer_coordinator
+ return task_cls(**kwargs)
+
+ def get_transfer_future(self, call_args=None):
+ return TransferFuture(
+ meta=TransferMeta(call_args), coordinator=self.transfer_coordinator
+ )
+
+
+class BaseSubmissionTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig()
+ self.osutil = OSUtils()
+ self.executor = BoundedExecutor(
+ 1000,
+ 1,
+ {
+ IN_MEMORY_UPLOAD_TAG: TaskSemaphore(10),
+ IN_MEMORY_DOWNLOAD_TAG: SlidingWindowSemaphore(10),
+ },
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ self.executor.shutdown()
+
+
+class BaseGeneralInterfaceTest(StubbedClientTest):
+ """A general test class to ensure consistency across TransferManger methods
+
+ This test should never be called and should be subclassed from to pick up
+ the various tests that all TransferManager method must pass from a
+ functionality standpoint.
+ """
+
+ __test__ = False
+
+ def manager(self):
+ """The transfer manager to use"""
+ raise NotImplementedError('method is not implemented')
+
+ @property
+ def method(self):
+ """The transfer manager method to invoke i.e. upload()"""
+ raise NotImplementedError('method is not implemented')
+
+ def create_call_kwargs(self):
+ """The kwargs to be passed to the transfer manager method"""
+ raise NotImplementedError('create_call_kwargs is not implemented')
+
+ def create_invalid_extra_args(self):
+ """A value for extra_args that will cause validation errors"""
+ raise NotImplementedError(
+ 'create_invalid_extra_args is not implemented'
+ )
+
+ def create_stubbed_responses(self):
+ """A list of stubbed responses that will cause the request to succeed
+
+ Each element of this list is a dictionary that will be used as keyword
+ arguments to botocore.Stubber.add_response(). For example::
+
+ [{'method': 'put_object', 'service_response': {}}]
+ """
+ raise NotImplementedError(
+ 'create_stubbed_responses is not implemented'
+ )
+
+ def create_expected_progress_callback_info(self):
+ """A list of kwargs expected to be passed to each progress callback
+
+ Note that the future kwarg does not need to be added to each
+ dictionary provided in the list. This is injected for you. An example
+ is::
+
+ [
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 2}
+ ]
+
+ This indicates that the progress callback will be called three
+ times and pass along the specified keyword arguments and corresponding
+ values.
+ """
+ raise NotImplementedError(
+ 'create_expected_progress_callback_info is not implemented'
+ )
+
+ def _setup_default_stubbed_responses(self):
+ for stubbed_response in self.create_stubbed_responses():
+ self.stubber.add_response(**stubbed_response)
+
+ def test_returns_future_with_meta(self):
+ self._setup_default_stubbed_responses()
+ future = self.method(**self.create_call_kwargs())
+ # The result is called so we ensure that the entire process executes
+ # before we try to clean up resources in the tearDown.
+ future.result()
+
+ # Assert the return value is a future with metadata associated to it.
+ self.assertIsInstance(future, TransferFuture)
+ self.assertIsInstance(future.meta, TransferMeta)
+
+ def test_returns_correct_call_args(self):
+ self._setup_default_stubbed_responses()
+ call_kwargs = self.create_call_kwargs()
+ future = self.method(**call_kwargs)
+ # The result is called so we ensure that the entire process executes
+ # before we try to clean up resources in the tearDown.
+ future.result()
+
+ # Assert that there are call args associated to the metadata
+ self.assertIsInstance(future.meta.call_args, CallArgs)
+ # Assert that all of the arguments passed to the method exist and
+ # are of the correct value in call_args.
+ for param, value in call_kwargs.items():
+ self.assertEqual(value, getattr(future.meta.call_args, param))
+
+ def test_has_transfer_id_associated_to_future(self):
+ self._setup_default_stubbed_responses()
+ call_kwargs = self.create_call_kwargs()
+ future = self.method(**call_kwargs)
+ # The result is called so we ensure that the entire process executes
+ # before we try to clean up resources in the tearDown.
+ future.result()
+
+ # Assert that a transfer id was associated with the future.
+ # Since only one transfer request has been made for this transfer
+ # manager, the id will be zero because it is the first transfer
+ # request made for that transfer manager.
+ self.assertEqual(future.meta.transfer_id, 0)
+
+ # If we make a second request, the transfer id should have incremented
+ # by one for that new TransferFuture.
+ self._setup_default_stubbed_responses()
+ future = self.method(**call_kwargs)
+ future.result()
+ self.assertEqual(future.meta.transfer_id, 1)
+
+ def test_invalid_extra_args(self):
+ with self.assertRaisesRegex(ValueError, 'Invalid extra_args'):
+ self.method(
+ extra_args=self.create_invalid_extra_args(),
+ **self.create_call_kwargs(),
+ )
+
+ def test_for_callback_kwargs_correctness(self):
+ # Add the stubbed responses before invoking the method
+ self._setup_default_stubbed_responses()
+
+ subscriber = RecordingSubscriber()
+ future = self.method(
+ subscribers=[subscriber], **self.create_call_kwargs()
+ )
+ # We call shutdown instead of result on future because the future
+ # could be finished but the done callback could still be going.
+ # The manager's shutdown method ensures everything completes.
+ self.manager.shutdown()
+
+ # Assert the various subscribers were called with the
+ # expected kwargs
+ expected_progress_calls = self.create_expected_progress_callback_info()
+ for expected_progress_call in expected_progress_calls:
+ expected_progress_call['future'] = future
+
+ self.assertEqual(subscriber.on_queued_calls, [{'future': future}])
+ self.assertEqual(subscriber.on_progress_calls, expected_progress_calls)
+ self.assertEqual(subscriber.on_done_calls, [{'future': future}])
+
+
+class NonSeekableReader(io.RawIOBase):
+ def __init__(self, b=b''):
+ super().__init__()
+ self._data = io.BytesIO(b)
+
+ def seekable(self):
+ return False
+
+ def writable(self):
+ return False
+
+ def readable(self):
+ return True
+
+ def write(self, b):
+ # This is needed because python will not always return the correct
+ # kind of error even though writeable returns False.
+ raise io.UnsupportedOperation("write")
+
+ def read(self, n=-1):
+ return self._data.read(n)
+
+
+class NonSeekableWriter(io.RawIOBase):
+ def __init__(self, fileobj):
+ super().__init__()
+ self._fileobj = fileobj
+
+ def seekable(self):
+ return False
+
+ def writable(self):
+ return True
+
+ def readable(self):
+ return False
+
+ def write(self, b):
+ self._fileobj.write(b)
+
+ def read(self, n=-1):
+ raise io.UnsupportedOperation("read")
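The helpers re-added above are consumed by the functional tests that follow; each test class fills in BaseGeneralInterfaceTest's hooks. A skeleton of what a concrete subclass looks like (the class name, manager construction, and kwargs below are illustrative, mirroring the copy tests in the next file)::

    from __tests__ import BaseGeneralInterfaceTest

    class TestSomeMethod(BaseGeneralInterfaceTest):
        __test__ = True

        @property
        def manager(self):
            return self._manager           # a TransferManager built in setUp()

        @property
        def method(self):
            return self.manager.upload     # the entry point under test

        def create_call_kwargs(self):
            return {'fileobj': 'somefile', 'bucket': 'bucket', 'key': 'key'}

        def create_invalid_extra_args(self):
            return {'Foo': 'bar'}

        def create_stubbed_responses(self):
            return [{'method': 'put_object', 'service_response': {}}]

        def create_expected_progress_callback_info(self):
            return [{'bytes_transferred': 4}]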
diff --git a/contrib/python/s3transfer/py3/tests/functional/__init__.py b/contrib/python/s3transfer/py3/tests/functional/__init__.py
index d69af8958d..fa58dbdb55 100644
--- a/contrib/python/s3transfer/py3/tests/functional/__init__.py
+++ b/contrib/python/s3transfer/py3/tests/functional/__init__.py
@@ -1,12 +1,12 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
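The copy tests below drive TransferManager.copy() entirely against stubbed botocore responses. For readers new to that pattern, a standalone sketch of the stubbing cycle they rely on (bucket, key, and size values are illustrative)::

    import botocore.session
    from botocore.stub import Stubber

    client = botocore.session.get_session().create_client(
        's3',
        'us-west-2',
        aws_access_key_id='foo',
        aws_secret_access_key='bar',
    )
    stubber = Stubber(client)
    stubber.add_response(
        'head_object',
        service_response={'ContentLength': 10},
        expected_params={'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'},
    )
    stubber.activate()
    client.head_object(Bucket='mysourcebucket', Key='mysourcekey')
    stubber.assert_no_pending_responses()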
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_copy.py b/contrib/python/s3transfer/py3/tests/functional/test_copy.py
index 5f6aeabf83..801c9003bb 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_copy.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_copy.py
@@ -1,554 +1,554 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-from botocore.exceptions import ClientError
-from botocore.stub import Stubber
-
-from s3transfer.manager import TransferConfig, TransferManager
-from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE
-from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider
-
-
-class BaseCopyTest(BaseGeneralInterfaceTest):
- def setUp(self):
- super().setUp()
- self.config = TransferConfig(
- max_request_concurrency=1,
- multipart_chunksize=MIN_UPLOAD_CHUNKSIZE,
- multipart_threshold=MIN_UPLOAD_CHUNKSIZE * 4,
- )
- self._manager = TransferManager(self.client, self.config)
-
- # Initialize some default arguments
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
- self.extra_args = {}
- self.subscribers = []
-
- self.half_chunksize = int(MIN_UPLOAD_CHUNKSIZE / 2)
- self.content = b'0' * (2 * MIN_UPLOAD_CHUNKSIZE + self.half_chunksize)
-
- @property
- def manager(self):
- return self._manager
-
- @property
- def method(self):
- return self.manager.copy
-
- def create_call_kwargs(self):
- return {
- 'copy_source': self.copy_source,
- 'bucket': self.bucket,
- 'key': self.key,
- }
-
- def create_invalid_extra_args(self):
- return {'Foo': 'bar'}
-
- def create_stubbed_responses(self):
- return [
- {
- 'method': 'head_object',
- 'service_response': {'ContentLength': len(self.content)},
- },
- {'method': 'copy_object', 'service_response': {}},
- ]
-
- def create_expected_progress_callback_info(self):
- return [
- {'bytes_transferred': len(self.content)},
- ]
-
- def add_head_object_response(self, expected_params=None, stubber=None):
- if not stubber:
- stubber = self.stubber
- head_response = self.create_stubbed_responses()[0]
- if expected_params:
- head_response['expected_params'] = expected_params
- stubber.add_response(**head_response)
-
- def add_successful_copy_responses(
- self,
- expected_copy_params=None,
- expected_create_mpu_params=None,
- expected_complete_mpu_params=None,
- ):
-
- # Add all responses needed to do the copy of the object.
- # Should account for both ranged and nonranged copies.
- stubbed_responses = self.create_stubbed_responses()[1:]
-
- # If the length of copy responses is greater than one then it is
- # a multipart copy.
- copy_responses = stubbed_responses[0:1]
- if len(stubbed_responses) > 1:
- copy_responses = stubbed_responses[1:-1]
-
- # Add the expected create multipart upload params.
- if expected_create_mpu_params:
- stubbed_responses[0][
- 'expected_params'
- ] = expected_create_mpu_params
-
- # Add any expected copy parameters.
- if expected_copy_params:
- for i, copy_response in enumerate(copy_responses):
- if isinstance(expected_copy_params, list):
- copy_response['expected_params'] = expected_copy_params[i]
- else:
- copy_response['expected_params'] = expected_copy_params
-
- # Add the expected complete multipart upload params.
- if expected_complete_mpu_params:
- stubbed_responses[-1][
- 'expected_params'
- ] = expected_complete_mpu_params
-
- # Add the responses to the stubber.
- for stubbed_response in stubbed_responses:
- self.stubber.add_response(**stubbed_response)
-
- def test_can_provide_file_size(self):
- self.add_successful_copy_responses()
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))]
-
- future = self.manager.copy(**call_kwargs)
- future.result()
-
- # The HeadObject call should not have happened, and the copy should
- # have completed successfully.
- self.stubber.assert_no_pending_responses()
-
- def test_provide_copy_source_as_dict(self):
- self.copy_source['VersionId'] = 'mysourceversionid'
- expected_params = {
- 'Bucket': 'mysourcebucket',
- 'Key': 'mysourcekey',
- 'VersionId': 'mysourceversionid',
- }
-
- self.add_head_object_response(expected_params=expected_params)
- self.add_successful_copy_responses()
-
- future = self.manager.copy(**self.create_call_kwargs())
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_invalid_copy_source(self):
- self.copy_source = ['bucket', 'key']
- future = self.manager.copy(**self.create_call_kwargs())
- with self.assertRaises(TypeError):
- future.result()
-
- def test_provide_copy_source_client(self):
- source_client = self.session.create_client(
- 's3',
- 'eu-central-1',
- aws_access_key_id='foo',
- aws_secret_access_key='bar',
- )
- source_stubber = Stubber(source_client)
- source_stubber.activate()
- self.addCleanup(source_stubber.deactivate)
-
- self.add_head_object_response(stubber=source_stubber)
- self.add_successful_copy_responses()
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['source_client'] = source_client
- future = self.manager.copy(**call_kwargs)
- future.result()
-
- # Make sure that all of the responses were properly
- # used for both clients.
- source_stubber.assert_no_pending_responses()
- self.stubber.assert_no_pending_responses()
-
-
-class TestNonMultipartCopy(BaseCopyTest):
- __test__ = True
-
- def test_copy(self):
- expected_head_params = {
- 'Bucket': 'mysourcebucket',
- 'Key': 'mysourcekey',
- }
- expected_copy_object = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- }
- self.add_head_object_response(expected_params=expected_head_params)
- self.add_successful_copy_responses(
- expected_copy_params=expected_copy_object
- )
-
- future = self.manager.copy(**self.create_call_kwargs())
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_copy_with_extra_args(self):
- self.extra_args['MetadataDirective'] = 'REPLACE'
-
- expected_head_params = {
- 'Bucket': 'mysourcebucket',
- 'Key': 'mysourcekey',
- }
- expected_copy_object = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- 'MetadataDirective': 'REPLACE',
- }
-
- self.add_head_object_response(expected_params=expected_head_params)
- self.add_successful_copy_responses(
- expected_copy_params=expected_copy_object
- )
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['extra_args'] = self.extra_args
- future = self.manager.copy(**call_kwargs)
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_copy_maps_extra_args_to_head_object(self):
- self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256'
-
- expected_head_params = {
- 'Bucket': 'mysourcebucket',
- 'Key': 'mysourcekey',
- 'SSECustomerAlgorithm': 'AES256',
- }
- expected_copy_object = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- 'CopySourceSSECustomerAlgorithm': 'AES256',
- }
-
- self.add_head_object_response(expected_params=expected_head_params)
- self.add_successful_copy_responses(
- expected_copy_params=expected_copy_object
- )
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['extra_args'] = self.extra_args
- future = self.manager.copy(**call_kwargs)
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_allowed_copy_params_are_valid(self):
- op_model = self.client.meta.service_model.operation_model('CopyObject')
- for allowed_upload_arg in self._manager.ALLOWED_COPY_ARGS:
- self.assertIn(allowed_upload_arg, op_model.input_shape.members)
-
- def test_copy_with_tagging(self):
- extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'}
- self.add_head_object_response()
- self.add_successful_copy_responses(
- expected_copy_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- 'Tagging': 'tag1=val1',
- 'TaggingDirective': 'REPLACE',
- }
- )
- future = self.manager.copy(
- self.copy_source, self.bucket, self.key, extra_args
- )
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_raise_exception_on_s3_object_lambda_resource(self):
- s3_object_lambda_arn = (
- 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
- 'accesspoint:my-accesspoint'
- )
- with self.assertRaisesRegex(ValueError, 'methods do not support'):
- self.manager.copy(self.copy_source, s3_object_lambda_arn, self.key)
-
- def test_raise_exception_on_s3_object_lambda_resource_as_source(self):
- source = {
- 'Bucket': 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
- 'accesspoint:my-accesspoint'
- }
- with self.assertRaisesRegex(ValueError, 'methods do not support'):
- self.manager.copy(source, self.bucket, self.key)
-
-
-class TestMultipartCopy(BaseCopyTest):
- __test__ = True
-
- def setUp(self):
- super().setUp()
- self.config = TransferConfig(
- max_request_concurrency=1,
- multipart_threshold=1,
- multipart_chunksize=4,
- )
- self._manager = TransferManager(self.client, self.config)
-
- def create_stubbed_responses(self):
- return [
- {
- 'method': 'head_object',
- 'service_response': {'ContentLength': len(self.content)},
- },
- {
- 'method': 'create_multipart_upload',
- 'service_response': {'UploadId': 'my-upload-id'},
- },
- {
- 'method': 'upload_part_copy',
- 'service_response': {'CopyPartResult': {'ETag': 'etag-1'}},
- },
- {
- 'method': 'upload_part_copy',
- 'service_response': {'CopyPartResult': {'ETag': 'etag-2'}},
- },
- {
- 'method': 'upload_part_copy',
- 'service_response': {'CopyPartResult': {'ETag': 'etag-3'}},
- },
- {'method': 'complete_multipart_upload', 'service_response': {}},
- ]
-
- def create_expected_progress_callback_info(self):
- # Note that the last read is from the empty sentinel indicating
- # that the stream is done.
- return [
- {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE},
- {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE},
- {'bytes_transferred': self.half_chunksize},
- ]
-
- def add_create_multipart_upload_response(self):
- self.stubber.add_response(**self.create_stubbed_responses()[1])
-
- def _get_expected_params(self):
- upload_id = 'my-upload-id'
-
- # Add expected parameters to the head object
- expected_head_params = {
- 'Bucket': 'mysourcebucket',
- 'Key': 'mysourcekey',
- }
-
- # Add expected parameters for the create multipart
- expected_create_mpu_params = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- }
-
- expected_copy_params = []
- # Add expected parameters to the copy part
- ranges = [
- 'bytes=0-5242879',
- 'bytes=5242880-10485759',
- 'bytes=10485760-13107199',
- ]
- for i, range_val in enumerate(ranges):
- expected_copy_params.append(
- {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- 'UploadId': upload_id,
- 'PartNumber': i + 1,
- 'CopySourceRange': range_val,
- }
- )
-
- # Add expected parameters for the complete multipart
- expected_complete_mpu_params = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': upload_id,
- 'MultipartUpload': {
- 'Parts': [
- {'ETag': 'etag-1', 'PartNumber': 1},
- {'ETag': 'etag-2', 'PartNumber': 2},
- {'ETag': 'etag-3', 'PartNumber': 3},
- ]
- },
- }
-
- return expected_head_params, {
- 'expected_create_mpu_params': expected_create_mpu_params,
- 'expected_copy_params': expected_copy_params,
- 'expected_complete_mpu_params': expected_complete_mpu_params,
- }
-
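The CopySourceRange values above follow directly from the stubbed content length. A quick check of the arithmetic (assuming MIN_UPLOAD_CHUNKSIZE is 5 MiB, which matches the ranges used here)::

    chunk = 5 * 1024 ** 2                     # 5242880 bytes per part
    content_length = 2 * chunk + chunk // 2   # 13107200, i.e. 2.5 chunks
    ranges = []
    start = 0
    while start < content_length:
        end = min(start + chunk, content_length) - 1
        ranges.append(f'bytes={start}-{end}')
        start += chunk
    # ranges == ['bytes=0-5242879', 'bytes=5242880-10485759',
    #            'bytes=10485760-13107199']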
- def _add_params_to_expected_params(
- self, add_copy_kwargs, operation_types, new_params
- ):
-
- expected_params_to_update = []
- for operation_type in operation_types:
- add_copy_kwargs_key = 'expected_' + operation_type + '_params'
- expected_params = add_copy_kwargs[add_copy_kwargs_key]
- if isinstance(expected_params, list):
- expected_params_to_update.extend(expected_params)
- else:
- expected_params_to_update.append(expected_params)
-
- for expected_params in expected_params_to_update:
- expected_params.update(new_params)
-
- def test_copy(self):
- head_params, add_copy_kwargs = self._get_expected_params()
- self.add_head_object_response(expected_params=head_params)
- self.add_successful_copy_responses(**add_copy_kwargs)
-
- future = self.manager.copy(**self.create_call_kwargs())
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_copy_with_extra_args(self):
- # This extra argument should be added to the head object,
- # the create multipart upload, and upload part copy.
- self.extra_args['RequestPayer'] = 'requester'
-
- head_params, add_copy_kwargs = self._get_expected_params()
- head_params.update(self.extra_args)
- self.add_head_object_response(expected_params=head_params)
-
- self._add_params_to_expected_params(
- add_copy_kwargs,
- ['create_mpu', 'copy', 'complete_mpu'],
- self.extra_args,
- )
- self.add_successful_copy_responses(**add_copy_kwargs)
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['extra_args'] = self.extra_args
- future = self.manager.copy(**call_kwargs)
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_copy_blacklists_args_to_create_multipart(self):
- # This argument can never be used for multipart uploads
- self.extra_args['MetadataDirective'] = 'COPY'
-
- head_params, add_copy_kwargs = self._get_expected_params()
- self.add_head_object_response(expected_params=head_params)
- self.add_successful_copy_responses(**add_copy_kwargs)
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['extra_args'] = self.extra_args
- future = self.manager.copy(**call_kwargs)
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_copy_args_to_only_create_multipart(self):
- self.extra_args['ACL'] = 'private'
-
- head_params, add_copy_kwargs = self._get_expected_params()
- self.add_head_object_response(expected_params=head_params)
-
- self._add_params_to_expected_params(
- add_copy_kwargs, ['create_mpu'], self.extra_args
- )
- self.add_successful_copy_responses(**add_copy_kwargs)
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['extra_args'] = self.extra_args
- future = self.manager.copy(**call_kwargs)
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_copy_passes_args_to_create_multipart_and_upload_part(self):
- # This will only be used for the create multipart upload
- # and upload part copy.
- self.extra_args['SSECustomerAlgorithm'] = 'AES256'
-
- head_params, add_copy_kwargs = self._get_expected_params()
- self.add_head_object_response(expected_params=head_params)
-
- self._add_params_to_expected_params(
- add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args
- )
- self.add_successful_copy_responses(**add_copy_kwargs)
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['extra_args'] = self.extra_args
- future = self.manager.copy(**call_kwargs)
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_copy_maps_extra_args_to_head_object(self):
- self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256'
-
- head_params, add_copy_kwargs = self._get_expected_params()
-
- # The CopySourceSSECustomerAlgorithm needs to get mapped to
- # SSECustomerAlgorithm for HeadObject
- head_params['SSECustomerAlgorithm'] = 'AES256'
- self.add_head_object_response(expected_params=head_params)
-
- # However, it needs to remain the same for UploadPartCopy.
- self._add_params_to_expected_params(
- add_copy_kwargs, ['copy'], self.extra_args
- )
- self.add_successful_copy_responses(**add_copy_kwargs)
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['extra_args'] = self.extra_args
- future = self.manager.copy(**call_kwargs)
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_abort_on_failure(self):
- # First add the head object and create multipart upload
- self.add_head_object_response()
- self.add_create_multipart_upload_response()
-
- # Cause an error on upload_part_copy
- self.stubber.add_client_error('upload_part_copy', 'ArbitraryFailure')
-
- # Add the abort multipart to ensure it gets cleaned up on failure
- self.stubber.add_response(
- 'abort_multipart_upload',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': 'my-upload-id',
- },
- )
-
- future = self.manager.copy(**self.create_call_kwargs())
- with self.assertRaisesRegex(ClientError, 'ArbitraryFailure'):
- future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_mp_copy_with_tagging_directive(self):
- extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'}
- self.add_head_object_response()
- self.add_successful_copy_responses(
- expected_create_mpu_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'Tagging': 'tag1=val1',
- }
- )
- future = self.manager.copy(
- self.copy_source, self.bucket, self.key, extra_args
- )
- future.result()
- self.stubber.assert_no_pending_responses()
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from botocore.exceptions import ClientError
+from botocore.stub import Stubber
+
+from s3transfer.manager import TransferConfig, TransferManager
+from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE
+from __tests__ import BaseGeneralInterfaceTest, FileSizeProvider
+
+
+class BaseCopyTest(BaseGeneralInterfaceTest):
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig(
+ max_request_concurrency=1,
+ multipart_chunksize=MIN_UPLOAD_CHUNKSIZE,
+ multipart_threshold=MIN_UPLOAD_CHUNKSIZE * 4,
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ # Initialize some default arguments
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
+ self.extra_args = {}
+ self.subscribers = []
+
+ self.half_chunksize = int(MIN_UPLOAD_CHUNKSIZE / 2)
+ self.content = b'0' * (2 * MIN_UPLOAD_CHUNKSIZE + self.half_chunksize)
+
+ @property
+ def manager(self):
+ return self._manager
+
+ @property
+ def method(self):
+ return self.manager.copy
+
+ def create_call_kwargs(self):
+ return {
+ 'copy_source': self.copy_source,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ }
+
+ def create_invalid_extra_args(self):
+ return {'Foo': 'bar'}
+
+ def create_stubbed_responses(self):
+ return [
+ {
+ 'method': 'head_object',
+ 'service_response': {'ContentLength': len(self.content)},
+ },
+ {'method': 'copy_object', 'service_response': {}},
+ ]
+
+ def create_expected_progress_callback_info(self):
+ return [
+ {'bytes_transferred': len(self.content)},
+ ]
+
+ def add_head_object_response(self, expected_params=None, stubber=None):
+ if not stubber:
+ stubber = self.stubber
+ head_response = self.create_stubbed_responses()[0]
+ if expected_params:
+ head_response['expected_params'] = expected_params
+ stubber.add_response(**head_response)
+
+ def add_successful_copy_responses(
+ self,
+ expected_copy_params=None,
+ expected_create_mpu_params=None,
+ expected_complete_mpu_params=None,
+ ):
+
+ # Add all responses needed to do the copy of the object.
+ # Should account for both ranged and nonranged copies.
+ stubbed_responses = self.create_stubbed_responses()[1:]
+
+ # If the length of copy responses is greater than one then it is
+ # a multipart copy.
+ copy_responses = stubbed_responses[0:1]
+ if len(stubbed_responses) > 1:
+ copy_responses = stubbed_responses[1:-1]
+
+ # Add the expected create multipart upload params.
+ if expected_create_mpu_params:
+ stubbed_responses[0][
+ 'expected_params'
+ ] = expected_create_mpu_params
+
+ # Add any expected copy parameters.
+ if expected_copy_params:
+ for i, copy_response in enumerate(copy_responses):
+ if isinstance(expected_copy_params, list):
+ copy_response['expected_params'] = expected_copy_params[i]
+ else:
+ copy_response['expected_params'] = expected_copy_params
+
+ # Add the expected complete multipart upload params.
+ if expected_complete_mpu_params:
+ stubbed_responses[-1][
+ 'expected_params'
+ ] = expected_complete_mpu_params
+
+ # Add the responses to the stubber.
+ for stubbed_response in stubbed_responses:
+ self.stubber.add_response(**stubbed_response)
+
+ def test_can_provide_file_size(self):
+ self.add_successful_copy_responses()
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))]
+
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+
+ # The HeadObject call should not have happened, and the copy should
+ # have completed successfully.
+ self.stubber.assert_no_pending_responses()
+
+ def test_provide_copy_source_as_dict(self):
+ self.copy_source['VersionId'] = 'mysourceversionid'
+ expected_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ 'VersionId': 'mysourceversionid',
+ }
+
+ self.add_head_object_response(expected_params=expected_params)
+ self.add_successful_copy_responses()
+
+ future = self.manager.copy(**self.create_call_kwargs())
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_invalid_copy_source(self):
+ self.copy_source = ['bucket', 'key']
+ future = self.manager.copy(**self.create_call_kwargs())
+ with self.assertRaises(TypeError):
+ future.result()
+
+ def test_provide_copy_source_client(self):
+ source_client = self.session.create_client(
+ 's3',
+ 'eu-central-1',
+ aws_access_key_id='foo',
+ aws_secret_access_key='bar',
+ )
+ source_stubber = Stubber(source_client)
+ source_stubber.activate()
+ self.addCleanup(source_stubber.deactivate)
+
+ self.add_head_object_response(stubber=source_stubber)
+ self.add_successful_copy_responses()
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['source_client'] = source_client
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+
+ # Make sure that all of the responses were properly
+ # used for both clients.
+ source_stubber.assert_no_pending_responses()
+ self.stubber.assert_no_pending_responses()
+
+
+class TestNonMultipartCopy(BaseCopyTest):
+ __test__ = True
+
+ def test_copy(self):
+ expected_head_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ }
+ expected_copy_object = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ }
+ self.add_head_object_response(expected_params=expected_head_params)
+ self.add_successful_copy_responses(
+ expected_copy_params=expected_copy_object
+ )
+
+ future = self.manager.copy(**self.create_call_kwargs())
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_with_extra_args(self):
+ self.extra_args['MetadataDirective'] = 'REPLACE'
+
+ expected_head_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ }
+ expected_copy_object = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'MetadataDirective': 'REPLACE',
+ }
+
+ self.add_head_object_response(expected_params=expected_head_params)
+ self.add_successful_copy_responses(
+ expected_copy_params=expected_copy_object
+ )
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_maps_extra_args_to_head_object(self):
+ self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256'
+
+ expected_head_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ 'SSECustomerAlgorithm': 'AES256',
+ }
+ expected_copy_object = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'CopySourceSSECustomerAlgorithm': 'AES256',
+ }
+
+ self.add_head_object_response(expected_params=expected_head_params)
+ self.add_successful_copy_responses(
+ expected_copy_params=expected_copy_object
+ )
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_allowed_copy_params_are_valid(self):
+ op_model = self.client.meta.service_model.operation_model('CopyObject')
+ for allowed_upload_arg in self._manager.ALLOWED_COPY_ARGS:
+ self.assertIn(allowed_upload_arg, op_model.input_shape.members)
+
+ def test_copy_with_tagging(self):
+ extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'}
+ self.add_head_object_response()
+ self.add_successful_copy_responses(
+ expected_copy_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'Tagging': 'tag1=val1',
+ 'TaggingDirective': 'REPLACE',
+ }
+ )
+ future = self.manager.copy(
+ self.copy_source, self.bucket, self.key, extra_args
+ )
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_raise_exception_on_s3_object_lambda_resource(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.copy(self.copy_source, s3_object_lambda_arn, self.key)
+
+ def test_raise_exception_on_s3_object_lambda_resource_as_source(self):
+ source = {
+ 'Bucket': 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ }
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.copy(source, self.bucket, self.key)
+
+
+class TestMultipartCopy(BaseCopyTest):
+ __test__ = True
+
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig(
+ max_request_concurrency=1,
+ multipart_threshold=1,
+ multipart_chunksize=4,
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ def create_stubbed_responses(self):
+ return [
+ {
+ 'method': 'head_object',
+ 'service_response': {'ContentLength': len(self.content)},
+ },
+ {
+ 'method': 'create_multipart_upload',
+ 'service_response': {'UploadId': 'my-upload-id'},
+ },
+ {
+ 'method': 'upload_part_copy',
+ 'service_response': {'CopyPartResult': {'ETag': 'etag-1'}},
+ },
+ {
+ 'method': 'upload_part_copy',
+ 'service_response': {'CopyPartResult': {'ETag': 'etag-2'}},
+ },
+ {
+ 'method': 'upload_part_copy',
+ 'service_response': {'CopyPartResult': {'ETag': 'etag-3'}},
+ },
+ {'method': 'complete_multipart_upload', 'service_response': {}},
+ ]
+
+ def create_expected_progress_callback_info(self):
+        # Progress is reported once per copied part: two full-size parts
+        # followed by a final half-size part.
+ return [
+ {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE},
+ {'bytes_transferred': MIN_UPLOAD_CHUNKSIZE},
+ {'bytes_transferred': self.half_chunksize},
+ ]
+
+ def add_create_multipart_upload_response(self):
+ self.stubber.add_response(**self.create_stubbed_responses()[1])
+
+ def _get_expected_params(self):
+ upload_id = 'my-upload-id'
+
+ # Add expected parameters to the head object
+ expected_head_params = {
+ 'Bucket': 'mysourcebucket',
+ 'Key': 'mysourcekey',
+ }
+
+ # Add expected parameters for the create multipart
+ expected_create_mpu_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ }
+
+ expected_copy_params = []
+ # Add expected parameters to the copy part
+ ranges = [
+ 'bytes=0-5242879',
+ 'bytes=5242880-10485759',
+ 'bytes=10485760-13107199',
+ ]
+ for i, range_val in enumerate(ranges):
+ expected_copy_params.append(
+ {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': upload_id,
+ 'PartNumber': i + 1,
+ 'CopySourceRange': range_val,
+ }
+ )
+
+ # Add expected parameters for the complete multipart
+ expected_complete_mpu_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'MultipartUpload': {
+ 'Parts': [
+ {'ETag': 'etag-1', 'PartNumber': 1},
+ {'ETag': 'etag-2', 'PartNumber': 2},
+ {'ETag': 'etag-3', 'PartNumber': 3},
+ ]
+ },
+ }
+
+ return expected_head_params, {
+ 'expected_create_mpu_params': expected_create_mpu_params,
+ 'expected_copy_params': expected_copy_params,
+ 'expected_complete_mpu_params': expected_complete_mpu_params,
+ }
+
+ def _add_params_to_expected_params(
+ self, add_copy_kwargs, operation_types, new_params
+ ):
+
+ expected_params_to_update = []
+ for operation_type in operation_types:
+ add_copy_kwargs_key = 'expected_' + operation_type + '_params'
+ expected_params = add_copy_kwargs[add_copy_kwargs_key]
+ if isinstance(expected_params, list):
+ expected_params_to_update.extend(expected_params)
+ else:
+ expected_params_to_update.append(expected_params)
+
+ for expected_params in expected_params_to_update:
+ expected_params.update(new_params)
+
+ def test_copy(self):
+ head_params, add_copy_kwargs = self._get_expected_params()
+ self.add_head_object_response(expected_params=head_params)
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ future = self.manager.copy(**self.create_call_kwargs())
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_with_extra_args(self):
+ # This extra argument should be added to the head object,
+ # the create multipart upload, and upload part copy.
+ self.extra_args['RequestPayer'] = 'requester'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+ head_params.update(self.extra_args)
+ self.add_head_object_response(expected_params=head_params)
+
+ self._add_params_to_expected_params(
+ add_copy_kwargs,
+ ['create_mpu', 'copy', 'complete_mpu'],
+ self.extra_args,
+ )
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_blacklists_args_to_create_multipart(self):
+ # This argument can never be used for multipart uploads
+ self.extra_args['MetadataDirective'] = 'COPY'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+ self.add_head_object_response(expected_params=head_params)
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_args_to_only_create_multipart(self):
+ self.extra_args['ACL'] = 'private'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+ self.add_head_object_response(expected_params=head_params)
+
+ self._add_params_to_expected_params(
+ add_copy_kwargs, ['create_mpu'], self.extra_args
+ )
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_passes_args_to_create_multipart_and_upload_part(self):
+        # This will only be used for the create multipart upload
+        # and upload part copy.
+ self.extra_args['SSECustomerAlgorithm'] = 'AES256'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+ self.add_head_object_response(expected_params=head_params)
+
+ self._add_params_to_expected_params(
+ add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args
+ )
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_copy_maps_extra_args_to_head_object(self):
+ self.extra_args['CopySourceSSECustomerAlgorithm'] = 'AES256'
+
+ head_params, add_copy_kwargs = self._get_expected_params()
+
+ # The CopySourceSSECustomerAlgorithm needs to get mapped to
+ # SSECustomerAlgorithm for HeadObject
+ head_params['SSECustomerAlgorithm'] = 'AES256'
+ self.add_head_object_response(expected_params=head_params)
+
+ # However, it needs to remain the same for UploadPartCopy.
+ self._add_params_to_expected_params(
+ add_copy_kwargs, ['copy'], self.extra_args
+ )
+ self.add_successful_copy_responses(**add_copy_kwargs)
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['extra_args'] = self.extra_args
+ future = self.manager.copy(**call_kwargs)
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_abort_on_failure(self):
+ # First add the head object and create multipart upload
+ self.add_head_object_response()
+ self.add_create_multipart_upload_response()
+
+ # Cause an error on upload_part_copy
+ self.stubber.add_client_error('upload_part_copy', 'ArbitraryFailure')
+
+ # Add the abort multipart to ensure it gets cleaned up on failure
+ self.stubber.add_response(
+ 'abort_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': 'my-upload-id',
+ },
+ )
+
+ future = self.manager.copy(**self.create_call_kwargs())
+ with self.assertRaisesRegex(ClientError, 'ArbitraryFailure'):
+ future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_mp_copy_with_tagging_directive(self):
+ extra_args = {'Tagging': 'tag1=val1', 'TaggingDirective': 'REPLACE'}
+ self.add_head_object_response()
+ self.add_successful_copy_responses(
+ expected_create_mpu_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Tagging': 'tag1=val1',
+ }
+ )
+ future = self.manager.copy(
+ self.copy_source, self.bucket, self.key, extra_args
+ )
+ future.result()
+ self.stubber.assert_no_pending_responses()
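
A minimal standalone sketch of the stubbed copy pattern exercised above, assuming only that botocore and s3transfer are installed; the bucket, key, and credential values are placeholders and the snippet is illustrative rather than part of the test suite:

import botocore.session
from botocore.stub import Stubber

from s3transfer.manager import TransferManager

# Placeholder client; the Stubber intercepts every call before any network I/O.
client = botocore.session.get_session().create_client(
    's3',
    'us-west-2',
    aws_access_key_id='foo',
    aws_secret_access_key='bar',
)
stubber = Stubber(client)
# Queue the HeadObject used to discover the source size, then the CopyObject.
stubber.add_response(
    'head_object',
    {'ContentLength': 10},
    {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'},
)
stubber.add_response(
    'copy_object',
    {},
    {
        'Bucket': 'mybucket',
        'Key': 'mykey',
        'CopySource': {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'},
    },
)
stubber.activate()

with TransferManager(client) as manager:
    # A 10-byte object stays below the multipart threshold, so a single
    # CopyObject call completes the transfer.
    manager.copy(
        {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}, 'mybucket', 'mykey'
    ).result()
stubber.assert_no_pending_responses()
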
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_crt.py b/contrib/python/s3transfer/py3/tests/functional/test_crt.py
index 9ebd48650d..fad0f4b23b 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_crt.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_crt.py
@@ -1,267 +1,267 @@
-# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import fnmatch
-import threading
-import time
-from concurrent.futures import Future
-
-from botocore.session import Session
-
-from s3transfer.subscribers import BaseSubscriber
-from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
-
-if HAS_CRT:
- import awscrt
-
- import s3transfer.crt
-
-
-class submitThread(threading.Thread):
- def __init__(self, transfer_manager, futures, callargs):
- threading.Thread.__init__(self)
- self._transfer_manager = transfer_manager
- self._futures = futures
- self._callargs = callargs
-
- def run(self):
- self._futures.append(self._transfer_manager.download(*self._callargs))
-
-
-class RecordingSubscriber(BaseSubscriber):
- def __init__(self):
- self.on_queued_called = False
- self.on_done_called = False
- self.bytes_transferred = 0
- self.on_queued_future = None
- self.on_done_future = None
-
- def on_queued(self, future, **kwargs):
- self.on_queued_called = True
- self.on_queued_future = future
-
- def on_done(self, future, **kwargs):
- self.on_done_called = True
- self.on_done_future = future
-
-
-@requires_crt
-class TestCRTTransferManager(unittest.TestCase):
- def setUp(self):
- self.region = 'us-west-2'
- self.bucket = "test_bucket"
- self.key = "test_key"
- self.files = FileCreator()
- self.filename = self.files.create_file('myfile', 'my content')
- self.expected_path = "/" + self.bucket + "/" + self.key
- self.expected_host = "s3.%s.amazonaws.com" % (self.region)
- self.s3_request = mock.Mock(awscrt.s3.S3Request)
- self.s3_crt_client = mock.Mock(awscrt.s3.S3Client)
- self.s3_crt_client.make_request.return_value = self.s3_request
- self.session = Session()
- self.session.set_config_variable('region', self.region)
- self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
- self.session
- )
- self.transfer_manager = s3transfer.crt.CRTTransferManager(
- crt_s3_client=self.s3_crt_client,
- crt_request_serializer=self.request_serializer,
- )
- self.record_subscriber = RecordingSubscriber()
-
- def tearDown(self):
- self.files.remove_all()
-
- def _assert_subscribers_called(self, expected_future=None):
- self.assertTrue(self.record_subscriber.on_queued_called)
- self.assertTrue(self.record_subscriber.on_done_called)
- if expected_future:
- self.assertIs(
- self.record_subscriber.on_queued_future, expected_future
- )
- self.assertIs(
- self.record_subscriber.on_done_future, expected_future
- )
-
- def _invoke_done_callbacks(self, **kwargs):
- callargs = self.s3_crt_client.make_request.call_args
- callargs_kwargs = callargs[1]
- on_done = callargs_kwargs["on_done"]
- on_done(error=None)
-
- def _simulate_file_download(self, recv_filepath):
- self.files.create_file(recv_filepath, "fake response")
-
- def _simulate_make_request_side_effect(self, **kwargs):
- if kwargs.get('recv_filepath'):
- self._simulate_file_download(kwargs['recv_filepath'])
- self._invoke_done_callbacks()
- return mock.DEFAULT
-
- def test_upload(self):
- self.s3_crt_client.make_request.side_effect = (
- self._simulate_make_request_side_effect
- )
- future = self.transfer_manager.upload(
- self.filename, self.bucket, self.key, {}, [self.record_subscriber]
- )
- future.result()
-
- callargs = self.s3_crt_client.make_request.call_args
- callargs_kwargs = callargs[1]
- self.assertEqual(callargs_kwargs["send_filepath"], self.filename)
- self.assertIsNone(callargs_kwargs["recv_filepath"])
- self.assertEqual(
- callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT
- )
- crt_request = callargs_kwargs["request"]
- self.assertEqual("PUT", crt_request.method)
- self.assertEqual(self.expected_path, crt_request.path)
- self.assertEqual(self.expected_host, crt_request.headers.get("host"))
- self._assert_subscribers_called(future)
-
- def test_download(self):
- self.s3_crt_client.make_request.side_effect = (
- self._simulate_make_request_side_effect
- )
- future = self.transfer_manager.download(
- self.bucket, self.key, self.filename, {}, [self.record_subscriber]
- )
- future.result()
-
- callargs = self.s3_crt_client.make_request.call_args
- callargs_kwargs = callargs[1]
- # the recv_filepath will be set to a temporary file path with some
- # random suffix
- self.assertTrue(
- fnmatch.fnmatch(
- callargs_kwargs["recv_filepath"],
- f'{self.filename}.*',
- )
- )
- self.assertIsNone(callargs_kwargs["send_filepath"])
- self.assertEqual(
- callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT
- )
- crt_request = callargs_kwargs["request"]
- self.assertEqual("GET", crt_request.method)
- self.assertEqual(self.expected_path, crt_request.path)
- self.assertEqual(self.expected_host, crt_request.headers.get("host"))
- self._assert_subscribers_called(future)
- with open(self.filename, 'rb') as f:
- # Check the fake response overwrites the file because of download
- self.assertEqual(f.read(), b'fake response')
-
- def test_delete(self):
- self.s3_crt_client.make_request.side_effect = (
- self._simulate_make_request_side_effect
- )
- future = self.transfer_manager.delete(
- self.bucket, self.key, {}, [self.record_subscriber]
- )
- future.result()
-
- callargs = self.s3_crt_client.make_request.call_args
- callargs_kwargs = callargs[1]
- self.assertIsNone(callargs_kwargs["send_filepath"])
- self.assertIsNone(callargs_kwargs["recv_filepath"])
- self.assertEqual(
- callargs_kwargs["type"], awscrt.s3.S3RequestType.DEFAULT
- )
- crt_request = callargs_kwargs["request"]
- self.assertEqual("DELETE", crt_request.method)
- self.assertEqual(self.expected_path, crt_request.path)
- self.assertEqual(self.expected_host, crt_request.headers.get("host"))
- self._assert_subscribers_called(future)
-
- def test_blocks_when_max_requests_processes_reached(self):
- futures = []
- callargs = (self.bucket, self.key, self.filename, {}, [])
- max_request_processes = 128 # the hard coded max processes
- all_concurrent = max_request_processes + 1
- threads = []
- for i in range(0, all_concurrent):
- thread = submitThread(self.transfer_manager, futures, callargs)
- thread.start()
- threads.append(thread)
- # Sleep until the expected max requests has been reached
- while len(futures) < max_request_processes:
- time.sleep(0.05)
- self.assertLessEqual(
- self.s3_crt_client.make_request.call_count, max_request_processes
- )
- # Release lock
- callargs = self.s3_crt_client.make_request.call_args
- callargs_kwargs = callargs[1]
- on_done = callargs_kwargs["on_done"]
- on_done(error=None)
- for thread in threads:
- thread.join()
- self.assertEqual(
- self.s3_crt_client.make_request.call_count, all_concurrent
- )
-
- def _cancel_function(self):
- self.cancel_called = True
- self.s3_request.finished_future.set_exception(
- awscrt.exceptions.from_code(0)
- )
- self._invoke_done_callbacks()
-
- def test_cancel(self):
- self.s3_request.finished_future = Future()
- self.cancel_called = False
- self.s3_request.cancel = self._cancel_function
- try:
- with self.transfer_manager:
- future = self.transfer_manager.upload(
- self.filename, self.bucket, self.key, {}, []
- )
- raise KeyboardInterrupt()
- except KeyboardInterrupt:
- pass
-
- with self.assertRaises(awscrt.exceptions.AwsCrtError):
- future.result()
- self.assertTrue(self.cancel_called)
-
- def test_serializer_error_handling(self):
- class SerializationException(Exception):
- pass
-
- class ExceptionRaisingSerializer(
- s3transfer.crt.BaseCRTRequestSerializer
- ):
- def serialize_http_request(self, transfer_type, future):
- raise SerializationException()
-
- not_impl_serializer = ExceptionRaisingSerializer()
- transfer_manager = s3transfer.crt.CRTTransferManager(
- crt_s3_client=self.s3_crt_client,
- crt_request_serializer=not_impl_serializer,
- )
- future = transfer_manager.upload(
- self.filename, self.bucket, self.key, {}, []
- )
-
- with self.assertRaises(SerializationException):
- future.result()
-
- def test_crt_s3_client_error_handling(self):
- self.s3_crt_client.make_request.side_effect = (
- awscrt.exceptions.from_code(0)
- )
- future = self.transfer_manager.upload(
- self.filename, self.bucket, self.key, {}, []
- )
- with self.assertRaises(awscrt.exceptions.AwsCrtError):
- future.result()
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import fnmatch
+import threading
+import time
+from concurrent.futures import Future
+
+from botocore.session import Session
+
+from s3transfer.subscribers import BaseSubscriber
+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+if HAS_CRT:
+ import awscrt
+
+ import s3transfer.crt
+
+
+class submitThread(threading.Thread):
+ def __init__(self, transfer_manager, futures, callargs):
+ threading.Thread.__init__(self)
+ self._transfer_manager = transfer_manager
+ self._futures = futures
+ self._callargs = callargs
+
+ def run(self):
+ self._futures.append(self._transfer_manager.download(*self._callargs))
+
+
+class RecordingSubscriber(BaseSubscriber):
+ def __init__(self):
+ self.on_queued_called = False
+ self.on_done_called = False
+ self.bytes_transferred = 0
+ self.on_queued_future = None
+ self.on_done_future = None
+
+ def on_queued(self, future, **kwargs):
+ self.on_queued_called = True
+ self.on_queued_future = future
+
+ def on_done(self, future, **kwargs):
+ self.on_done_called = True
+ self.on_done_future = future
+
+
+@requires_crt
+class TestCRTTransferManager(unittest.TestCase):
+ def setUp(self):
+ self.region = 'us-west-2'
+ self.bucket = "test_bucket"
+ self.key = "test_key"
+ self.files = FileCreator()
+ self.filename = self.files.create_file('myfile', 'my content')
+ self.expected_path = "/" + self.bucket + "/" + self.key
+ self.expected_host = "s3.%s.amazonaws.com" % (self.region)
+ self.s3_request = mock.Mock(awscrt.s3.S3Request)
+ self.s3_crt_client = mock.Mock(awscrt.s3.S3Client)
+ self.s3_crt_client.make_request.return_value = self.s3_request
+ self.session = Session()
+ self.session.set_config_variable('region', self.region)
+ self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
+ self.session
+ )
+ self.transfer_manager = s3transfer.crt.CRTTransferManager(
+ crt_s3_client=self.s3_crt_client,
+ crt_request_serializer=self.request_serializer,
+ )
+ self.record_subscriber = RecordingSubscriber()
+
+ def tearDown(self):
+ self.files.remove_all()
+
+ def _assert_subscribers_called(self, expected_future=None):
+ self.assertTrue(self.record_subscriber.on_queued_called)
+ self.assertTrue(self.record_subscriber.on_done_called)
+ if expected_future:
+ self.assertIs(
+ self.record_subscriber.on_queued_future, expected_future
+ )
+ self.assertIs(
+ self.record_subscriber.on_done_future, expected_future
+ )
+
+ def _invoke_done_callbacks(self, **kwargs):
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ on_done = callargs_kwargs["on_done"]
+ on_done(error=None)
+
+ def _simulate_file_download(self, recv_filepath):
+ self.files.create_file(recv_filepath, "fake response")
+
+ def _simulate_make_request_side_effect(self, **kwargs):
+ if kwargs.get('recv_filepath'):
+ self._simulate_file_download(kwargs['recv_filepath'])
+ self._invoke_done_callbacks()
+ return mock.DEFAULT
+
+ def test_upload(self):
+ self.s3_crt_client.make_request.side_effect = (
+ self._simulate_make_request_side_effect
+ )
+ future = self.transfer_manager.upload(
+ self.filename, self.bucket, self.key, {}, [self.record_subscriber]
+ )
+ future.result()
+
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ self.assertEqual(callargs_kwargs["send_filepath"], self.filename)
+ self.assertIsNone(callargs_kwargs["recv_filepath"])
+ self.assertEqual(
+ callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT
+ )
+ crt_request = callargs_kwargs["request"]
+ self.assertEqual("PUT", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self._assert_subscribers_called(future)
+
+ def test_download(self):
+ self.s3_crt_client.make_request.side_effect = (
+ self._simulate_make_request_side_effect
+ )
+ future = self.transfer_manager.download(
+ self.bucket, self.key, self.filename, {}, [self.record_subscriber]
+ )
+ future.result()
+
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ # the recv_filepath will be set to a temporary file path with some
+ # random suffix
+ self.assertTrue(
+ fnmatch.fnmatch(
+ callargs_kwargs["recv_filepath"],
+ f'{self.filename}.*',
+ )
+ )
+ self.assertIsNone(callargs_kwargs["send_filepath"])
+ self.assertEqual(
+ callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT
+ )
+ crt_request = callargs_kwargs["request"]
+ self.assertEqual("GET", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self._assert_subscribers_called(future)
+ with open(self.filename, 'rb') as f:
+ # Check the fake response overwrites the file because of download
+ self.assertEqual(f.read(), b'fake response')
+
+ def test_delete(self):
+ self.s3_crt_client.make_request.side_effect = (
+ self._simulate_make_request_side_effect
+ )
+ future = self.transfer_manager.delete(
+ self.bucket, self.key, {}, [self.record_subscriber]
+ )
+ future.result()
+
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ self.assertIsNone(callargs_kwargs["send_filepath"])
+ self.assertIsNone(callargs_kwargs["recv_filepath"])
+ self.assertEqual(
+ callargs_kwargs["type"], awscrt.s3.S3RequestType.DEFAULT
+ )
+ crt_request = callargs_kwargs["request"]
+ self.assertEqual("DELETE", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self._assert_subscribers_called(future)
+
+ def test_blocks_when_max_requests_processes_reached(self):
+ futures = []
+ callargs = (self.bucket, self.key, self.filename, {}, [])
+ max_request_processes = 128 # the hard coded max processes
+ all_concurrent = max_request_processes + 1
+ threads = []
+ for i in range(0, all_concurrent):
+ thread = submitThread(self.transfer_manager, futures, callargs)
+ thread.start()
+ threads.append(thread)
+ # Sleep until the expected max requests has been reached
+ while len(futures) < max_request_processes:
+ time.sleep(0.05)
+ self.assertLessEqual(
+ self.s3_crt_client.make_request.call_count, max_request_processes
+ )
+ # Release lock
+ callargs = self.s3_crt_client.make_request.call_args
+ callargs_kwargs = callargs[1]
+ on_done = callargs_kwargs["on_done"]
+ on_done(error=None)
+ for thread in threads:
+ thread.join()
+ self.assertEqual(
+ self.s3_crt_client.make_request.call_count, all_concurrent
+ )
+
+ def _cancel_function(self):
+ self.cancel_called = True
+ self.s3_request.finished_future.set_exception(
+ awscrt.exceptions.from_code(0)
+ )
+ self._invoke_done_callbacks()
+
+ def test_cancel(self):
+ self.s3_request.finished_future = Future()
+ self.cancel_called = False
+ self.s3_request.cancel = self._cancel_function
+ try:
+ with self.transfer_manager:
+ future = self.transfer_manager.upload(
+ self.filename, self.bucket, self.key, {}, []
+ )
+ raise KeyboardInterrupt()
+ except KeyboardInterrupt:
+ pass
+
+ with self.assertRaises(awscrt.exceptions.AwsCrtError):
+ future.result()
+ self.assertTrue(self.cancel_called)
+
+ def test_serializer_error_handling(self):
+ class SerializationException(Exception):
+ pass
+
+ class ExceptionRaisingSerializer(
+ s3transfer.crt.BaseCRTRequestSerializer
+ ):
+ def serialize_http_request(self, transfer_type, future):
+ raise SerializationException()
+
+ not_impl_serializer = ExceptionRaisingSerializer()
+ transfer_manager = s3transfer.crt.CRTTransferManager(
+ crt_s3_client=self.s3_crt_client,
+ crt_request_serializer=not_impl_serializer,
+ )
+ future = transfer_manager.upload(
+ self.filename, self.bucket, self.key, {}, []
+ )
+
+ with self.assertRaises(SerializationException):
+ future.result()
+
+ def test_crt_s3_client_error_handling(self):
+ self.s3_crt_client.make_request.side_effect = (
+ awscrt.exceptions.from_code(0)
+ )
+ future = self.transfer_manager.upload(
+ self.filename, self.bucket, self.key, {}, []
+ )
+ with self.assertRaises(awscrt.exceptions.AwsCrtError):
+ future.result()
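
The CRT tests above never perform real network I/O; completion is simulated by having the mocked make_request invoke the caller-supplied on_done hook. A small, self-contained illustration of that callback pattern, using placeholder argument values and plain unittest.mock only:

from unittest import mock

crt_client = mock.Mock()

def fake_make_request(**kwargs):
    # Signal success immediately, mirroring _simulate_make_request_side_effect.
    kwargs['on_done'](error=None)
    # Returning DEFAULT makes the mock fall back to its configured return_value.
    return mock.DEFAULT

crt_client.make_request.side_effect = fake_make_request

results = []
crt_client.make_request(
    request='placeholder-request',
    type='placeholder-type',
    on_done=lambda error=None: results.append(error),
)
assert results == [None]
assert crt_client.make_request.call_args[1]['type'] == 'placeholder-type'
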
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_delete.py b/contrib/python/s3transfer/py3/tests/functional/test_delete.py
index 6d53537448..28587a47a4 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_delete.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_delete.py
@@ -1,76 +1,76 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-from s3transfer.manager import TransferManager
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.manager import TransferManager
from __tests__ import BaseGeneralInterfaceTest
-
-
-class TestDeleteObject(BaseGeneralInterfaceTest):
-
- __test__ = True
-
- def setUp(self):
- super().setUp()
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.manager = TransferManager(self.client)
-
- @property
- def method(self):
- """The transfer manager method to invoke i.e. upload()"""
- return self.manager.delete
-
- def create_call_kwargs(self):
- """The kwargs to be passed to the transfer manager method"""
- return {
- 'bucket': self.bucket,
- 'key': self.key,
- }
-
- def create_invalid_extra_args(self):
- return {
- 'BadKwargs': True,
- }
-
- def create_stubbed_responses(self):
- """A list of stubbed responses that will cause the request to succeed
-
-        Each element of this list is a dictionary that will be used as
-        keyword arguments to botocore.Stubber.add_response(). For example::
-
- [{'method': 'put_object', 'service_response': {}}]
- """
- return [
- {
- 'method': 'delete_object',
- 'service_response': {},
- 'expected_params': {'Bucket': self.bucket, 'Key': self.key},
- }
- ]
-
- def create_expected_progress_callback_info(self):
- return []
-
- def test_known_allowed_args_in_input_shape(self):
- op_model = self.client.meta.service_model.operation_model(
- 'DeleteObject'
- )
- for allowed_arg in self.manager.ALLOWED_DELETE_ARGS:
- self.assertIn(allowed_arg, op_model.input_shape.members)
-
- def test_raise_exception_on_s3_object_lambda_resource(self):
- s3_object_lambda_arn = (
- 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
- 'accesspoint:my-accesspoint'
- )
- with self.assertRaisesRegex(ValueError, 'methods do not support'):
- self.manager.delete(s3_object_lambda_arn, self.key)
+
+
+class TestDeleteObject(BaseGeneralInterfaceTest):
+
+ __test__ = True
+
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.manager = TransferManager(self.client)
+
+ @property
+ def method(self):
+ """The transfer manager method to invoke i.e. upload()"""
+ return self.manager.delete
+
+ def create_call_kwargs(self):
+ """The kwargs to be passed to the transfer manager method"""
+ return {
+ 'bucket': self.bucket,
+ 'key': self.key,
+ }
+
+ def create_invalid_extra_args(self):
+ return {
+ 'BadKwargs': True,
+ }
+
+ def create_stubbed_responses(self):
+ """A list of stubbed responses that will cause the request to succeed
+
+        Each element of this list is a dictionary that will be used as
+        keyword arguments to botocore.Stubber.add_response(). For example::
+
+ [{'method': 'put_object', 'service_response': {}}]
+ """
+ return [
+ {
+ 'method': 'delete_object',
+ 'service_response': {},
+ 'expected_params': {'Bucket': self.bucket, 'Key': self.key},
+ }
+ ]
+
+ def create_expected_progress_callback_info(self):
+ return []
+
+ def test_known_allowed_args_in_input_shape(self):
+ op_model = self.client.meta.service_model.operation_model(
+ 'DeleteObject'
+ )
+ for allowed_arg in self.manager.ALLOWED_DELETE_ARGS:
+ self.assertIn(allowed_arg, op_model.input_shape.members)
+
+ def test_raise_exception_on_s3_object_lambda_resource(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.delete(s3_object_lambda_arn, self.key)
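
Both the copy tests earlier and the download tests that follow use a FileSizeProvider subscriber from the shared test helpers to skip the HeadObject size lookup. Its behavior can be sketched roughly as follows; this is an assumed implementation, shown only to explain why no HeadObject stub is queued in those tests:

from s3transfer.subscribers import BaseSubscriber


class FileSizeProvider(BaseSubscriber):
    """Assumed shape of the helper: hand the transfer size to the manager."""

    def __init__(self, file_size):
        self.file_size = file_size

    def on_queued(self, future, **kwargs):
        # Providing the size up front lets the manager skip its HeadObject call.
        future.meta.provide_transfer_size(self.file_size)
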
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_download.py b/contrib/python/s3transfer/py3/tests/functional/test_download.py
index 746c040d42..64a8a1309d 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_download.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_download.py
@@ -1,497 +1,497 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import copy
-import glob
-import os
-import shutil
-import tempfile
-import time
-from io import BytesIO
-
-from botocore.exceptions import ClientError
-
-from s3transfer.compat import SOCKET_ERROR
-from s3transfer.exceptions import RetriesExceededError
-from s3transfer.manager import TransferConfig, TransferManager
-from __tests__ import (
- BaseGeneralInterfaceTest,
- FileSizeProvider,
- NonSeekableWriter,
- RecordingOSUtils,
- RecordingSubscriber,
- StreamWithError,
- skip_if_using_serial_implementation,
- skip_if_windows,
-)
-
-
-class BaseDownloadTest(BaseGeneralInterfaceTest):
- def setUp(self):
- super().setUp()
- self.config = TransferConfig(max_request_concurrency=1)
- self._manager = TransferManager(self.client, self.config)
-
- # Create a temporary directory to write to
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'myfile')
-
- # Initialize some default arguments
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.extra_args = {}
- self.subscribers = []
-
- # Create a stream to read from
- self.content = b'my content'
- self.stream = BytesIO(self.content)
-
- def tearDown(self):
- super().tearDown()
- shutil.rmtree(self.tempdir)
-
- @property
- def manager(self):
- return self._manager
-
- @property
- def method(self):
- return self.manager.download
-
- def create_call_kwargs(self):
- return {
- 'bucket': self.bucket,
- 'key': self.key,
- 'fileobj': self.filename,
- }
-
- def create_invalid_extra_args(self):
- return {'Foo': 'bar'}
-
- def create_stubbed_responses(self):
- # We want to make sure the beginning of the stream is always used
- # in case this gets called twice.
- self.stream.seek(0)
- return [
- {
- 'method': 'head_object',
- 'service_response': {'ContentLength': len(self.content)},
- },
- {
- 'method': 'get_object',
- 'service_response': {'Body': self.stream},
- },
- ]
-
- def create_expected_progress_callback_info(self):
- # Note that last read is from the empty sentinel indicating
- # that the stream is done.
- return [{'bytes_transferred': 10}]
-
- def add_head_object_response(self, expected_params=None):
- head_response = self.create_stubbed_responses()[0]
- if expected_params:
- head_response['expected_params'] = expected_params
- self.stubber.add_response(**head_response)
-
- def add_successful_get_object_responses(
- self, expected_params=None, expected_ranges=None
- ):
- # Add all get_object responses needed to complete the download.
- # Should account for both ranged and nonranged downloads.
- for i, stubbed_response in enumerate(
- self.create_stubbed_responses()[1:]
- ):
- if expected_params:
- stubbed_response['expected_params'] = copy.deepcopy(
- expected_params
- )
- if expected_ranges:
- stubbed_response['expected_params'][
- 'Range'
- ] = expected_ranges[i]
- self.stubber.add_response(**stubbed_response)
-
- def add_n_retryable_get_object_responses(self, n, num_reads=0):
- for _ in range(n):
- self.stubber.add_response(
- method='get_object',
- service_response={
- 'Body': StreamWithError(
- copy.deepcopy(self.stream), SOCKET_ERROR, num_reads
- )
- },
- )
-
- def test_download_temporary_file_does_not_exist(self):
- self.add_head_object_response()
- self.add_successful_get_object_responses()
-
- future = self.manager.download(**self.create_call_kwargs())
- future.result()
- # Make sure the file exists
- self.assertTrue(os.path.exists(self.filename))
- # Make sure the random temporary file does not exist
- possible_matches = glob.glob('%s*' % self.filename + os.extsep)
- self.assertEqual(possible_matches, [])
-
- def test_download_for_fileobj(self):
- self.add_head_object_response()
- self.add_successful_get_object_responses()
-
- with open(self.filename, 'wb') as f:
- future = self.manager.download(
- self.bucket, self.key, f, self.extra_args
- )
- future.result()
-
- # Ensure that the contents are correct
- with open(self.filename, 'rb') as f:
- self.assertEqual(self.content, f.read())
-
- def test_download_for_seekable_filelike_obj(self):
- self.add_head_object_response()
- self.add_successful_get_object_responses()
-
- # Create a file-like object to test. In this case, it is a BytesIO
- # object.
- bytes_io = BytesIO()
-
- future = self.manager.download(
- self.bucket, self.key, bytes_io, self.extra_args
- )
- future.result()
-
- # Ensure that the contents are correct
- bytes_io.seek(0)
- self.assertEqual(self.content, bytes_io.read())
-
- def test_download_for_nonseekable_filelike_obj(self):
- self.add_head_object_response()
- self.add_successful_get_object_responses()
-
- with open(self.filename, 'wb') as f:
- future = self.manager.download(
- self.bucket, self.key, NonSeekableWriter(f), self.extra_args
- )
- future.result()
-
- # Ensure that the contents are correct
- with open(self.filename, 'rb') as f:
- self.assertEqual(self.content, f.read())
-
- def test_download_cleanup_on_failure(self):
- self.add_head_object_response()
-
- # Throw an error on the download
- self.stubber.add_client_error('get_object')
-
- future = self.manager.download(**self.create_call_kwargs())
-
- with self.assertRaises(ClientError):
- future.result()
- # Make sure the actual file and the temporary do not exist
- # by globbing for the file and any of its extensions
- possible_matches = glob.glob('%s*' % self.filename)
- self.assertEqual(possible_matches, [])
-
- def test_download_with_nonexistent_directory(self):
- self.add_head_object_response()
- self.add_successful_get_object_responses()
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['fileobj'] = os.path.join(
- self.tempdir, 'missing-directory', 'myfile'
- )
- future = self.manager.download(**call_kwargs)
- with self.assertRaises(IOError):
- future.result()
-
- def test_retries_and_succeeds(self):
- self.add_head_object_response()
- # Insert a response that will trigger a retry.
- self.add_n_retryable_get_object_responses(1)
- # Add the normal responses to simulate the download proceeding
- # as normal after the retry.
- self.add_successful_get_object_responses()
-
- future = self.manager.download(**self.create_call_kwargs())
- future.result()
-
- # The retry should have been consumed and the process should have
- # continued using the successful responses.
- self.stubber.assert_no_pending_responses()
- with open(self.filename, 'rb') as f:
- self.assertEqual(self.content, f.read())
-
- def test_retry_failure(self):
- self.add_head_object_response()
-
- max_retries = 3
- self.config.num_download_attempts = max_retries
- self._manager = TransferManager(self.client, self.config)
- # Add responses that fill up the maximum number of retries.
- self.add_n_retryable_get_object_responses(max_retries)
-
- future = self.manager.download(**self.create_call_kwargs())
-
- # A retry exceeded error should have happened.
- with self.assertRaises(RetriesExceededError):
- future.result()
-
- # All of the retries should have been used up.
- self.stubber.assert_no_pending_responses()
-
- def test_retry_rewinds_callbacks(self):
- self.add_head_object_response()
- # Insert a response that will trigger a retry after one read of the
- # stream has been made.
- self.add_n_retryable_get_object_responses(1, num_reads=1)
- # Add the normal responses to simulate the download proceeding
- # as normal after the retry.
- self.add_successful_get_object_responses()
-
- recorder_subscriber = RecordingSubscriber()
-        # Use an I/O chunk size smaller than the data we provide so that a
-        # retry rewinds the progress callbacks.
- self.config.io_chunksize = 3
- future = self.manager.download(
- subscribers=[recorder_subscriber], **self.create_call_kwargs()
- )
- future.result()
-
- # Ensure that there is no more remaining responses and that contents
- # are correct.
- self.stubber.assert_no_pending_responses()
- with open(self.filename, 'rb') as f:
- self.assertEqual(self.content, f.read())
-
- # Assert that the number of bytes seen is equal to the length of
- # downloaded content.
- self.assertEqual(
- recorder_subscriber.calculate_bytes_seen(), len(self.content)
- )
-
- # Also ensure that the second progress invocation was negative three
- # because a retry happened on the second read of the stream and we
- # know that the chunk size for each read is 3.
- progress_byte_amts = [
- call['bytes_transferred']
- for call in recorder_subscriber.on_progress_calls
- ]
- self.assertEqual(-3, progress_byte_amts[1])
-
- def test_can_provide_file_size(self):
- self.add_successful_get_object_responses()
-
- call_kwargs = self.create_call_kwargs()
- call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))]
-
- future = self.manager.download(**call_kwargs)
- future.result()
-
-        # The HeadObject should not have happened, and the download should
-        # have completed successfully.
- self.stubber.assert_no_pending_responses()
- with open(self.filename, 'rb') as f:
- self.assertEqual(self.content, f.read())
-
- def test_uses_provided_osutil(self):
- osutil = RecordingOSUtils()
- # Use the recording os utility for the transfer manager
- self._manager = TransferManager(self.client, self.config, osutil)
-
- self.add_head_object_response()
- self.add_successful_get_object_responses()
-
- future = self.manager.download(**self.create_call_kwargs())
- future.result()
- # The osutil should have had its open() method invoked when opening
- # a temporary file and its rename_file() method invoked when the
-        # temporary file was moved to its final location.
- self.assertEqual(len(osutil.open_records), 1)
- self.assertEqual(len(osutil.rename_records), 1)
-
- @skip_if_windows('Windows does not support UNIX special files')
- @skip_if_using_serial_implementation(
- 'A separate thread is needed to read from the fifo'
- )
- def test_download_for_fifo_file(self):
- self.add_head_object_response()
- self.add_successful_get_object_responses()
-
- # Create the fifo file
- os.mkfifo(self.filename)
-
- future = self.manager.download(
- self.bucket, self.key, self.filename, self.extra_args
- )
-
- # The call to open a fifo will block until there is both a reader
- # and a writer, so we need to open it for reading after we've
- # started the transfer.
- with open(self.filename, 'rb') as fifo:
- future.result()
- self.assertEqual(fifo.read(), self.content)
-
- def test_raise_exception_on_s3_object_lambda_resource(self):
- s3_object_lambda_arn = (
- 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
- 'accesspoint:my-accesspoint'
- )
- with self.assertRaisesRegex(ValueError, 'methods do not support'):
- self.manager.download(
- s3_object_lambda_arn, self.key, self.filename, self.extra_args
- )
-
-
-class TestNonRangedDownload(BaseDownloadTest):
- # TODO: If you want to add tests outside of this test class and still
- # subclass from BaseDownloadTest you need to set ``__test__ = True``. If
- # you do not, your tests will not get picked up by the test runner! This
- # needs to be done until we find a better way to ignore running test cases
-    # from the general test base class, which we do not want run.
- __test__ = True
-
- def test_download(self):
- self.extra_args['RequestPayer'] = 'requester'
- expected_params = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'RequestPayer': 'requester',
- }
- self.add_head_object_response(expected_params)
- self.add_successful_get_object_responses(expected_params)
- future = self.manager.download(
- self.bucket, self.key, self.filename, self.extra_args
- )
- future.result()
-
- # Ensure that the contents are correct
- with open(self.filename, 'rb') as f:
- self.assertEqual(self.content, f.read())
-
- def test_allowed_copy_params_are_valid(self):
- op_model = self.client.meta.service_model.operation_model('GetObject')
- for allowed_upload_arg in self._manager.ALLOWED_DOWNLOAD_ARGS:
- self.assertIn(allowed_upload_arg, op_model.input_shape.members)
-
- def test_download_empty_object(self):
- self.content = b''
- self.stream = BytesIO(self.content)
- self.add_head_object_response()
- self.add_successful_get_object_responses()
- future = self.manager.download(
- self.bucket, self.key, self.filename, self.extra_args
- )
- future.result()
-
- # Ensure that the empty file exists
- with open(self.filename, 'rb') as f:
- self.assertEqual(b'', f.read())
-
- def test_uses_bandwidth_limiter(self):
- self.content = b'a' * 1024 * 1024
- self.stream = BytesIO(self.content)
- self.config = TransferConfig(
- max_request_concurrency=1, max_bandwidth=len(self.content) / 2
- )
- self._manager = TransferManager(self.client, self.config)
-
- self.add_head_object_response()
- self.add_successful_get_object_responses()
-
- start = time.time()
- future = self.manager.download(
- self.bucket, self.key, self.filename, self.extra_args
- )
- future.result()
-        # This is just a smoke test to make sure that the limiter is being
-        # used, not a check of its exactness. We set the maximum bandwidth to
-        # len(content)/2 per second and make sure the transfer is noticeably
-        # slower. Ideally it would take more than two seconds, but because
-        # bandwidth tracking is not entirely accurate at the very start of a
-        # transfer, we allow some flexibility by expecting only half of the
-        # theoretical time.
- self.assertGreaterEqual(time.time() - start, 1)
-
- # Ensure that the contents are correct
- with open(self.filename, 'rb') as f:
- self.assertEqual(self.content, f.read())
-
-
-class TestRangedDownload(BaseDownloadTest):
- # TODO: If you want to add tests outside of this test class and still
- # subclass from BaseDownloadTest you need to set ``__test__ = True``. If
- # you do not, your tests will not get picked up by the test runner! This
- # needs to be done until we find a better way to ignore running test cases
-    # from the general test base class, which we do not want run.
- __test__ = True
-
- def setUp(self):
- super().setUp()
- self.config = TransferConfig(
- max_request_concurrency=1,
- multipart_threshold=1,
- multipart_chunksize=4,
- )
- self._manager = TransferManager(self.client, self.config)
-
- def create_stubbed_responses(self):
- return [
- {
- 'method': 'head_object',
- 'service_response': {'ContentLength': len(self.content)},
- },
- {
- 'method': 'get_object',
- 'service_response': {'Body': BytesIO(self.content[0:4])},
- },
- {
- 'method': 'get_object',
- 'service_response': {'Body': BytesIO(self.content[4:8])},
- },
- {
- 'method': 'get_object',
- 'service_response': {'Body': BytesIO(self.content[8:])},
- },
- ]
-
- def create_expected_progress_callback_info(self):
- return [
- {'bytes_transferred': 4},
- {'bytes_transferred': 4},
- {'bytes_transferred': 2},
- ]
-
- def test_download(self):
- self.extra_args['RequestPayer'] = 'requester'
- expected_params = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'RequestPayer': 'requester',
- }
- expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-']
- self.add_head_object_response(expected_params)
- self.add_successful_get_object_responses(
- expected_params, expected_ranges
- )
-
- future = self.manager.download(
- self.bucket, self.key, self.filename, self.extra_args
- )
- future.result()
-
- # Ensure that the contents are correct
- with open(self.filename, 'rb') as f:
- self.assertEqual(self.content, f.read())
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import glob
+import os
+import shutil
+import tempfile
+import time
+from io import BytesIO
+
+from botocore.exceptions import ClientError
+
+from s3transfer.compat import SOCKET_ERROR
+from s3transfer.exceptions import RetriesExceededError
+from s3transfer.manager import TransferConfig, TransferManager
+from __tests__ import (
+ BaseGeneralInterfaceTest,
+ FileSizeProvider,
+ NonSeekableWriter,
+ RecordingOSUtils,
+ RecordingSubscriber,
+ StreamWithError,
+ skip_if_using_serial_implementation,
+ skip_if_windows,
+)
+
+
+class BaseDownloadTest(BaseGeneralInterfaceTest):
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig(max_request_concurrency=1)
+ self._manager = TransferManager(self.client, self.config)
+
+ # Create a temporary directory to write to
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ # Initialize some default arguments
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # Create a stream to read from
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ @property
+ def manager(self):
+ return self._manager
+
+ @property
+ def method(self):
+ return self.manager.download
+
+ def create_call_kwargs(self):
+ return {
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'fileobj': self.filename,
+ }
+
+ def create_invalid_extra_args(self):
+ return {'Foo': 'bar'}
+
+ def create_stubbed_responses(self):
+ # We want to make sure the beginning of the stream is always used
+ # in case this gets called twice.
+ self.stream.seek(0)
+ return [
+ {
+ 'method': 'head_object',
+ 'service_response': {'ContentLength': len(self.content)},
+ },
+ {
+ 'method': 'get_object',
+ 'service_response': {'Body': self.stream},
+ },
+ ]
+
+ def create_expected_progress_callback_info(self):
+ # Note that last read is from the empty sentinel indicating
+ # that the stream is done.
+ return [{'bytes_transferred': 10}]
+
+ def add_head_object_response(self, expected_params=None):
+ head_response = self.create_stubbed_responses()[0]
+ if expected_params:
+ head_response['expected_params'] = expected_params
+ self.stubber.add_response(**head_response)
+
+ def add_successful_get_object_responses(
+ self, expected_params=None, expected_ranges=None
+ ):
+ # Add all get_object responses needed to complete the download.
+ # Should account for both ranged and nonranged downloads.
+ for i, stubbed_response in enumerate(
+ self.create_stubbed_responses()[1:]
+ ):
+ if expected_params:
+ stubbed_response['expected_params'] = copy.deepcopy(
+ expected_params
+ )
+ if expected_ranges:
+ stubbed_response['expected_params'][
+ 'Range'
+ ] = expected_ranges[i]
+ self.stubber.add_response(**stubbed_response)
+
+ def add_n_retryable_get_object_responses(self, n, num_reads=0):
+ for _ in range(n):
+ self.stubber.add_response(
+ method='get_object',
+ service_response={
+ 'Body': StreamWithError(
+ copy.deepcopy(self.stream), SOCKET_ERROR, num_reads
+ )
+ },
+ )
+
+ def test_download_temporary_file_does_not_exist(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ future = self.manager.download(**self.create_call_kwargs())
+ future.result()
+ # Make sure the file exists
+ self.assertTrue(os.path.exists(self.filename))
+ # Make sure the random temporary file does not exist
+ possible_matches = glob.glob('%s*' % self.filename + os.extsep)
+ self.assertEqual(possible_matches, [])
+
+ def test_download_for_fileobj(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ with open(self.filename, 'wb') as f:
+ future = self.manager.download(
+ self.bucket, self.key, f, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_download_for_seekable_filelike_obj(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ # Create a file-like object to test. In this case, it is a BytesIO
+ # object.
+ bytes_io = BytesIO()
+
+ future = self.manager.download(
+ self.bucket, self.key, bytes_io, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ bytes_io.seek(0)
+ self.assertEqual(self.content, bytes_io.read())
+
+ def test_download_for_nonseekable_filelike_obj(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ with open(self.filename, 'wb') as f:
+ future = self.manager.download(
+ self.bucket, self.key, NonSeekableWriter(f), self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_download_cleanup_on_failure(self):
+ self.add_head_object_response()
+
+ # Throw an error on the download
+ self.stubber.add_client_error('get_object')
+
+ future = self.manager.download(**self.create_call_kwargs())
+
+ with self.assertRaises(ClientError):
+ future.result()
+ # Make sure the actual file and the temporary do not exist
+ # by globbing for the file and any of its extensions
+ possible_matches = glob.glob('%s*' % self.filename)
+ self.assertEqual(possible_matches, [])
+
+ def test_download_with_nonexistent_directory(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['fileobj'] = os.path.join(
+ self.tempdir, 'missing-directory', 'myfile'
+ )
+ future = self.manager.download(**call_kwargs)
+ with self.assertRaises(IOError):
+ future.result()
+
+ def test_retries_and_succeeds(self):
+ self.add_head_object_response()
+ # Insert a response that will trigger a retry.
+ self.add_n_retryable_get_object_responses(1)
+ # Add the normal responses to simulate the download proceeding
+ # as normal after the retry.
+ self.add_successful_get_object_responses()
+
+ future = self.manager.download(**self.create_call_kwargs())
+ future.result()
+
+ # The retry should have been consumed and the process should have
+ # continued using the successful responses.
+ self.stubber.assert_no_pending_responses()
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_retry_failure(self):
+ self.add_head_object_response()
+
+ max_retries = 3
+ self.config.num_download_attempts = max_retries
+ self._manager = TransferManager(self.client, self.config)
+ # Add responses that fill up the maximum number of retries.
+ self.add_n_retryable_get_object_responses(max_retries)
+
+ future = self.manager.download(**self.create_call_kwargs())
+
+ # A retry exceeded error should have happened.
+ with self.assertRaises(RetriesExceededError):
+ future.result()
+
+ # All of the retries should have been used up.
+ self.stubber.assert_no_pending_responses()
+
+ def test_retry_rewinds_callbacks(self):
+ self.add_head_object_response()
+ # Insert a response that will trigger a retry after one read of the
+ # stream has been made.
+ self.add_n_retryable_get_object_responses(1, num_reads=1)
+ # Add the normal responses to simulate the download proceeding
+ # as normal after the retry.
+ self.add_successful_get_object_responses()
+
+ recorder_subscriber = RecordingSubscriber()
+        # Use an I/O chunk size smaller than the data we provide so that a
+        # retry rewinds the progress callbacks.
+ self.config.io_chunksize = 3
+ future = self.manager.download(
+ subscribers=[recorder_subscriber], **self.create_call_kwargs()
+ )
+ future.result()
+
+ # Ensure that there is no more remaining responses and that contents
+ # are correct.
+ self.stubber.assert_no_pending_responses()
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ # Assert that the number of bytes seen is equal to the length of
+ # downloaded content.
+ self.assertEqual(
+ recorder_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ # Also ensure that the second progress invocation was negative three
+ # because a retry happened on the second read of the stream and we
+ # know that the chunk size for each read is 3.
+ progress_byte_amts = [
+ call['bytes_transferred']
+ for call in recorder_subscriber.on_progress_calls
+ ]
+ self.assertEqual(-3, progress_byte_amts[1])
+
+ def test_can_provide_file_size(self):
+ self.add_successful_get_object_responses()
+
+ call_kwargs = self.create_call_kwargs()
+ call_kwargs['subscribers'] = [FileSizeProvider(len(self.content))]
+
+ future = self.manager.download(**call_kwargs)
+ future.result()
+
+ # The HeadObject call should not have happened, and the download
+ # should still have completed successfully.
+ self.stubber.assert_no_pending_responses()
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_uses_provided_osutil(self):
+ osutil = RecordingOSUtils()
+ # Use the recording os utility for the transfer manager
+ self._manager = TransferManager(self.client, self.config, osutil)
+
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ future = self.manager.download(**self.create_call_kwargs())
+ future.result()
+ # The osutil should have had its open() method invoked when opening
+ # a temporary file and its rename_file() method invoked when the
+ # temporary file was moved to its final location.
+ self.assertEqual(len(osutil.open_records), 1)
+ self.assertEqual(len(osutil.rename_records), 1)
+
+ @skip_if_windows('Windows does not support UNIX special files')
+ @skip_if_using_serial_implementation(
+ 'A separate thread is needed to read from the fifo'
+ )
+ def test_download_for_fifo_file(self):
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ # Create the fifo file
+ os.mkfifo(self.filename)
+
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+
+ # The call to open a fifo will block until there is both a reader
+ # and a writer, so we need to open it for reading after we've
+ # started the transfer.
+ with open(self.filename, 'rb') as fifo:
+ future.result()
+ self.assertEqual(fifo.read(), self.content)
+
+ def test_raise_exception_on_s3_object_lambda_resource(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.download(
+ s3_object_lambda_arn, self.key, self.filename, self.extra_args
+ )
+
+
+class TestNonRangedDownload(BaseDownloadTest):
+ # TODO: If you want to add tests outside of this test class and still
+ # subclass from BaseDownloadTest you need to set ``__test__ = True``. If
+ # you do not, your tests will not get picked up by the test runner! This
+ # needs to be done until we find a better way to ignore running test cases
+ # from the general test base class, which we do not want run.
+ __test__ = True
+
+ def test_download(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ expected_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'RequestPayer': 'requester',
+ }
+ self.add_head_object_response(expected_params)
+ self.add_successful_get_object_responses(expected_params)
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+ def test_allowed_download_params_are_valid(self):
+ op_model = self.client.meta.service_model.operation_model('GetObject')
+ for allowed_download_arg in self._manager.ALLOWED_DOWNLOAD_ARGS:
+ self.assertIn(allowed_download_arg, op_model.input_shape.members)
+
+ def test_download_empty_object(self):
+ self.content = b''
+ self.stream = BytesIO(self.content)
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the empty file exists
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(b'', f.read())
+
+ def test_uses_bandwidth_limiter(self):
+ self.content = b'a' * 1024 * 1024
+ self.stream = BytesIO(self.content)
+ self.config = TransferConfig(
+ max_request_concurrency=1, max_bandwidth=len(self.content) / 2
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ self.add_head_object_response()
+ self.add_successful_get_object_responses()
+
+ start = time.time()
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+ future.result()
+ # This is just a smoke test to make sure that the limiter is being
+ # used; it does not check the limiter's exactness. We set the maximum
+ # bandwidth to len(content)/2 per second and verify that the transfer
+ # is noticeably slower. Ideally it would take more than two seconds,
+ # but because bandwidth tracking is not entirely accurate at the very
+ # start of a transfer, we allow some flexibility by only requiring
+ # half of the theoretical time.
+ self.assertGreaterEqual(time.time() - start, 1)
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
+
+
+class TestRangedDownload(BaseDownloadTest):
+ # TODO: If you want to add tests outside of this test class and still
+ # subclass from BaseDownloadTest you need to set ``__test__ = True``. If
+ # you do not, your tests will not get picked up by the test runner! This
+ # needs to be done until we find a better way to ignore running test cases
+ # from the general test base class, which we do not want run.
+ __test__ = True
+
+ def setUp(self):
+ super().setUp()
+ self.config = TransferConfig(
+ max_request_concurrency=1,
+ multipart_threshold=1,
+ multipart_chunksize=4,
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ def create_stubbed_responses(self):
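+ # The stubbed GetObject responses split the test content into ranged
+ # parts of 4, 4, and 2 bytes, matching the multipart_chunksize of 4
+ # configured in setUp().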
+ return [
+ {
+ 'method': 'head_object',
+ 'service_response': {'ContentLength': len(self.content)},
+ },
+ {
+ 'method': 'get_object',
+ 'service_response': {'Body': BytesIO(self.content[0:4])},
+ },
+ {
+ 'method': 'get_object',
+ 'service_response': {'Body': BytesIO(self.content[4:8])},
+ },
+ {
+ 'method': 'get_object',
+ 'service_response': {'Body': BytesIO(self.content[8:])},
+ },
+ ]
+
+ def create_expected_progress_callback_info(self):
+ return [
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 2},
+ ]
+
+ def test_download(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ expected_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'RequestPayer': 'requester',
+ }
+ expected_ranges = ['bytes=0-3', 'bytes=4-7', 'bytes=8-']
+ self.add_head_object_response(expected_params)
+ self.add_successful_get_object_responses(
+ expected_params, expected_ranges
+ )
+
+ future = self.manager.download(
+ self.bucket, self.key, self.filename, self.extra_args
+ )
+ future.result()
+
+ # Ensure that the contents are correct
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(self.content, f.read())
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_manager.py b/contrib/python/s3transfer/py3/tests/functional/test_manager.py
index bde2c10201..1c980e7bc6 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_manager.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_manager.py
@@ -1,191 +1,191 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License'). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the 'license' file accompanying this file. This file is
-# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-from io import BytesIO
-
-from botocore.awsrequest import create_request_object
-
-from s3transfer.exceptions import CancelledError, FatalError
-from s3transfer.futures import BaseExecutor
-from s3transfer.manager import TransferConfig, TransferManager
-from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation
-
-
-class ArbitraryException(Exception):
- pass
-
-
-class SignalTransferringBody(BytesIO):
- """A mocked body with the ability to signal when transfers occur"""
-
- def __init__(self):
- super().__init__()
- self.signal_transferring_call_count = 0
- self.signal_not_transferring_call_count = 0
-
- def signal_transferring(self):
- self.signal_transferring_call_count += 1
-
- def signal_not_transferring(self):
- self.signal_not_transferring_call_count += 1
-
- def seek(self, where, whence=0):
- pass
-
- def tell(self):
- return 0
-
- def read(self, amount=0):
- return b''
-
-
-class TestTransferManager(StubbedClientTest):
- @skip_if_using_serial_implementation(
- 'Exception is thrown once all transfers are submitted. '
- 'However for the serial implementation, transfers are performed '
- 'in main thread meaning all transfers will complete before the '
- 'exception being thrown.'
- )
- def test_error_in_context_manager_cancels_incomplete_transfers(self):
- # The purpose of this test is to make sure that if an error is raised
- # in the body of the context manager, incomplete transfers will be
- # cancelled with the value of the exception wrapped by a CancelledError
-
- # NOTE: The fact that delete() was chosen to test this is arbitrary
- # other than it is the easiest to set up for the stubber.
- # The specific operation is not important to the purpose of this test.
- num_transfers = 100
- futures = []
- ref_exception_msg = 'arbitrary exception'
-
- for _ in range(num_transfers):
- self.stubber.add_response('delete_object', {})
-
- manager = TransferManager(
- self.client,
- TransferConfig(
- max_request_concurrency=1, max_submission_concurrency=1
- ),
- )
- try:
- with manager:
- for i in range(num_transfers):
- futures.append(manager.delete('mybucket', 'mykey'))
- raise ArbitraryException(ref_exception_msg)
- except ArbitraryException:
- # At least one of the submitted futures should have been
- # cancelled.
- with self.assertRaisesRegex(FatalError, ref_exception_msg):
- for future in futures:
- future.result()
-
- @skip_if_using_serial_implementation(
- 'Exception is thrown once all transfers are submitted. '
- 'However for the serial implementation, transfers are performed '
- 'in main thread meaning all transfers will complete before the '
- 'exception being thrown.'
- )
- def test_cntrl_c_in_context_manager_cancels_incomplete_transfers(self):
- # The purpose of this test is to make sure that if an error is raised
- # in the body of the context manager, incomplete transfers will be
- # cancelled with the value of the exception wrapped by a CancelledError
-
- # NOTE: The fact that delete() was chosen to test this is arbitrary
- # other than it is the easiest to set up for the stubber.
- # The specific operation is not important to the purpose of this test.
- num_transfers = 100
- futures = []
-
- for _ in range(num_transfers):
- self.stubber.add_response('delete_object', {})
-
- manager = TransferManager(
- self.client,
- TransferConfig(
- max_request_concurrency=1, max_submission_concurrency=1
- ),
- )
- try:
- with manager:
- for i in range(num_transfers):
- futures.append(manager.delete('mybucket', 'mykey'))
- raise KeyboardInterrupt()
- except KeyboardInterrupt:
- # At least one of the submitted futures should have been
- # cancelled.
- with self.assertRaisesRegex(CancelledError, 'KeyboardInterrupt()'):
- for future in futures:
- future.result()
-
- def test_enable_disable_callbacks_only_ever_registered_once(self):
- body = SignalTransferringBody()
- request = create_request_object(
- {
- 'method': 'PUT',
- 'url': 'https://s3.amazonaws.com',
- 'body': body,
- 'headers': {},
- 'context': {},
- }
- )
- # Create two TransferManagers using the same client
- TransferManager(self.client)
- TransferManager(self.client)
- self.client.meta.events.emit(
- 'request-created.s3', request=request, operation_name='PutObject'
- )
- # The client should only have the enable/disable callback
- # handlers registered once despite being used for two different
- # TransferManagers.
- self.assertEqual(
- body.signal_transferring_call_count,
- 1,
- 'The enable_callback() should have only ever been registered once',
- )
- self.assertEqual(
- body.signal_not_transferring_call_count,
- 1,
- 'The disable_callback() should have only ever been registered '
- 'once',
- )
-
- def test_use_custom_executor_implementation(self):
- mocked_executor_cls = mock.Mock(BaseExecutor)
- transfer_manager = TransferManager(
- self.client, executor_cls=mocked_executor_cls
- )
- transfer_manager.delete('bucket', 'key')
- self.assertTrue(mocked_executor_cls.return_value.submit.called)
-
- def test_unicode_exception_in_context_manager(self):
- with self.assertRaises(ArbitraryException):
- with TransferManager(self.client):
- raise ArbitraryException('\u2713')
-
- def test_client_property(self):
- manager = TransferManager(self.client)
- self.assertIs(manager.client, self.client)
-
- def test_config_property(self):
- config = TransferConfig()
- manager = TransferManager(self.client, config)
- self.assertIs(manager.config, config)
-
- def test_can_disable_bucket_validation(self):
- s3_object_lambda_arn = (
- 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
- 'accesspoint:my-accesspoint'
- )
- config = TransferConfig()
- manager = TransferManager(self.client, config)
- manager.VALIDATE_SUPPORTED_BUCKET_VALUES = False
- manager.delete(s3_object_lambda_arn, 'my-key')
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from io import BytesIO
+
+from botocore.awsrequest import create_request_object
+
+from s3transfer.exceptions import CancelledError, FatalError
+from s3transfer.futures import BaseExecutor
+from s3transfer.manager import TransferConfig, TransferManager
+from __tests__ import StubbedClientTest, mock, skip_if_using_serial_implementation
+
+
+class ArbitraryException(Exception):
+ pass
+
+
+class SignalTransferringBody(BytesIO):
+ """A mocked body with the ability to signal when transfers occur"""
+
+ def __init__(self):
+ super().__init__()
+ self.signal_transferring_call_count = 0
+ self.signal_not_transferring_call_count = 0
+
+ def signal_transferring(self):
+ self.signal_transferring_call_count += 1
+
+ def signal_not_transferring(self):
+ self.signal_not_transferring_call_count += 1
+
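+ # seek(), tell(), and read() are intentionally no-ops: these tests only
+ # inspect how many times the signal_* callbacks are invoked.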
+ def seek(self, where, whence=0):
+ pass
+
+ def tell(self):
+ return 0
+
+ def read(self, amount=0):
+ return b''
+
+
+class TestTransferManager(StubbedClientTest):
+ @skip_if_using_serial_implementation(
+ 'Exception is thrown once all transfers are submitted. '
+ 'However for the serial implementation, transfers are performed '
+ 'in main thread meaning all transfers will complete before the '
+ 'exception being thrown.'
+ )
+ def test_error_in_context_manager_cancels_incomplete_transfers(self):
+ # The purpose of this test is to make sure that if an error is raised
+ # in the body of the context manager, incomplete transfers will be
+ # cancelled with the value of the exception wrapped by a CancelledError
+
+ # NOTE: The fact that delete() was chosen to test this is arbitrary
+ # other than it is the easiest to set up for the stubber.
+ # The specific operation is not important to the purpose of this test.
+ num_transfers = 100
+ futures = []
+ ref_exception_msg = 'arbitrary exception'
+
+ for _ in range(num_transfers):
+ self.stubber.add_response('delete_object', {})
+
+ manager = TransferManager(
+ self.client,
+ TransferConfig(
+ max_request_concurrency=1, max_submission_concurrency=1
+ ),
+ )
+ try:
+ with manager:
+ for i in range(num_transfers):
+ futures.append(manager.delete('mybucket', 'mykey'))
+ raise ArbitraryException(ref_exception_msg)
+ except ArbitraryException:
+ # At least one of the submitted futures should have been
+ # cancelled.
+ with self.assertRaisesRegex(FatalError, ref_exception_msg):
+ for future in futures:
+ future.result()
+
+ @skip_if_using_serial_implementation(
+ 'Exception is thrown once all transfers are submitted. '
+ 'However for the serial implementation, transfers are performed '
+ 'in main thread meaning all transfers will complete before the '
+ 'exception being thrown.'
+ )
+ def test_cntrl_c_in_context_manager_cancels_incomplete_transfers(self):
+ # The purpose of this test is to make sure that if an error is raised
+ # in the body of the context manager, incomplete transfers will be
+ # cancelled with the value of the exception wrapped by a CancelledError
+
+ # NOTE: The fact that delete() was chosen to test this is arbitrary
+ # other than it is the easiest to set up for the stubber.
+ # The specific operation is not important to the purpose of this test.
+ num_transfers = 100
+ futures = []
+
+ for _ in range(num_transfers):
+ self.stubber.add_response('delete_object', {})
+
+ manager = TransferManager(
+ self.client,
+ TransferConfig(
+ max_request_concurrency=1, max_submission_concurrency=1
+ ),
+ )
+ try:
+ with manager:
+ for i in range(num_transfers):
+ futures.append(manager.delete('mybucket', 'mykey'))
+ raise KeyboardInterrupt()
+ except KeyboardInterrupt:
+ # At least one of the submitted futures should have been
+ # cancelled.
+ with self.assertRaisesRegex(CancelledError, 'KeyboardInterrupt()'):
+ for future in futures:
+ future.result()
+
+ def test_enable_disable_callbacks_only_ever_registered_once(self):
+ body = SignalTransferringBody()
+ request = create_request_object(
+ {
+ 'method': 'PUT',
+ 'url': 'https://s3.amazonaws.com',
+ 'body': body,
+ 'headers': {},
+ 'context': {},
+ }
+ )
+ # Create two TransferManagers using the same client
+ TransferManager(self.client)
+ TransferManager(self.client)
+ self.client.meta.events.emit(
+ 'request-created.s3', request=request, operation_name='PutObject'
+ )
+ # The client should only have the enable/disable callback
+ # handlers registered once despite being used for two different
+ # TransferManagers.
+ self.assertEqual(
+ body.signal_transferring_call_count,
+ 1,
+ 'The enable_callback() should have only ever been registered once',
+ )
+ self.assertEqual(
+ body.signal_not_transferring_call_count,
+ 1,
+ 'The disable_callback() should have only ever been registered '
+ 'once',
+ )
+
+ def test_use_custom_executor_implementation(self):
+ mocked_executor_cls = mock.Mock(BaseExecutor)
+ transfer_manager = TransferManager(
+ self.client, executor_cls=mocked_executor_cls
+ )
+ transfer_manager.delete('bucket', 'key')
+ self.assertTrue(mocked_executor_cls.return_value.submit.called)
+
+ def test_unicode_exception_in_context_manager(self):
+ with self.assertRaises(ArbitraryException):
+ with TransferManager(self.client):
+ raise ArbitraryException('\u2713')
+
+ def test_client_property(self):
+ manager = TransferManager(self.client)
+ self.assertIs(manager.client, self.client)
+
+ def test_config_property(self):
+ config = TransferConfig()
+ manager = TransferManager(self.client, config)
+ self.assertIs(manager.config, config)
+
+ def test_can_disable_bucket_validation(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ config = TransferConfig()
+ manager = TransferManager(self.client, config)
+ manager.VALIDATE_SUPPORTED_BUCKET_VALUES = False
+ manager.delete(s3_object_lambda_arn, 'my-key')
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_processpool.py b/contrib/python/s3transfer/py3/tests/functional/test_processpool.py
index d347efa869..1396c919f2 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_processpool.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_processpool.py
@@ -1,281 +1,281 @@
-# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import glob
-import os
-from io import BytesIO
-from multiprocessing.managers import BaseManager
-
-import botocore.exceptions
-import botocore.session
-from botocore.stub import Stubber
-
-from s3transfer.exceptions import CancelledError
-from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig
-from __tests__ import FileCreator, mock, unittest
-
-
-class StubbedClient:
- def __init__(self):
- self._client = botocore.session.get_session().create_client(
- 's3',
- 'us-west-2',
- aws_access_key_id='foo',
- aws_secret_access_key='bar',
- )
- self._stubber = Stubber(self._client)
- self._stubber.activate()
- self._caught_stubber_errors = []
-
- def get_object(self, **kwargs):
- return self._client.get_object(**kwargs)
-
- def head_object(self, **kwargs):
- return self._client.head_object(**kwargs)
-
- def add_response(self, *args, **kwargs):
- self._stubber.add_response(*args, **kwargs)
-
- def add_client_error(self, *args, **kwargs):
- self._stubber.add_client_error(*args, **kwargs)
-
-
-class StubbedClientManager(BaseManager):
- pass
-
-
-StubbedClientManager.register('StubbedClient', StubbedClient)
-
-
-# Ideally a Mock would be used here. However, Mocks cannot be pickled
-# on Windows. So instead we define a factory class at the module level that
-# can return a stubbed client we initialized in the setUp.
-class StubbedClientFactory:
- def __init__(self, stubbed_client):
- self._stubbed_client = stubbed_client
-
- def __call__(self, *args, **kwargs):
- # The __call__ is defined so we can provide an instance of the
- # StubbedClientFactory to mock.patch() and have the instance be
- # returned when the patched class is instantiated.
- return self
-
- def create_client(self):
- return self._stubbed_client
-
-
-class TestProcessPoolDownloader(unittest.TestCase):
- def setUp(self):
- # The stubbed client needs to run in a manager so that it can be shared
- # across processes and properly consume its stubbed responses in each
- # process.
- self.manager = StubbedClientManager()
- self.manager.start()
- self.stubbed_client = self.manager.StubbedClient()
- self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client)
-
- self.client_factory_patch = mock.patch(
- 's3transfer.processpool.ClientFactory', self.stubbed_client_factory
- )
- self.client_factory_patch.start()
- self.files = FileCreator()
-
- self.config = ProcessTransferConfig(max_request_processes=1)
- self.downloader = ProcessPoolDownloader(config=self.config)
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.filename = self.files.full_path('filename')
- self.remote_contents = b'my content'
- self.stream = BytesIO(self.remote_contents)
-
- def tearDown(self):
- self.manager.shutdown()
- self.client_factory_patch.stop()
- self.files.remove_all()
-
- def assert_contents(self, filename, expected_contents):
- self.assertTrue(os.path.exists(filename))
- with open(filename, 'rb') as f:
- self.assertEqual(f.read(), expected_contents)
-
- def test_download_file(self):
- self.stubbed_client.add_response(
- 'head_object', {'ContentLength': len(self.remote_contents)}
- )
- self.stubbed_client.add_response('get_object', {'Body': self.stream})
- with self.downloader:
- self.downloader.download_file(self.bucket, self.key, self.filename)
- self.assert_contents(self.filename, self.remote_contents)
-
- def test_download_multiple_files(self):
- self.stubbed_client.add_response('get_object', {'Body': self.stream})
- self.stubbed_client.add_response(
- 'get_object', {'Body': BytesIO(self.remote_contents)}
- )
- with self.downloader:
- self.downloader.download_file(
- self.bucket,
- self.key,
- self.filename,
- expected_size=len(self.remote_contents),
- )
- other_file = self.files.full_path('filename2')
- self.downloader.download_file(
- self.bucket,
- self.key,
- other_file,
- expected_size=len(self.remote_contents),
- )
- self.assert_contents(self.filename, self.remote_contents)
- self.assert_contents(other_file, self.remote_contents)
-
- def test_download_file_ranged_download(self):
- half_of_content_length = int(len(self.remote_contents) / 2)
- self.stubbed_client.add_response(
- 'head_object', {'ContentLength': len(self.remote_contents)}
- )
- self.stubbed_client.add_response(
- 'get_object',
- {'Body': BytesIO(self.remote_contents[:half_of_content_length])},
- )
- self.stubbed_client.add_response(
- 'get_object',
- {'Body': BytesIO(self.remote_contents[half_of_content_length:])},
- )
- downloader = ProcessPoolDownloader(
- config=ProcessTransferConfig(
- multipart_chunksize=half_of_content_length,
- multipart_threshold=half_of_content_length,
- max_request_processes=1,
- )
- )
- with downloader:
- downloader.download_file(self.bucket, self.key, self.filename)
- self.assert_contents(self.filename, self.remote_contents)
-
- def test_download_file_extra_args(self):
- self.stubbed_client.add_response(
- 'head_object',
- {'ContentLength': len(self.remote_contents)},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'VersionId': 'versionid',
- },
- )
- self.stubbed_client.add_response(
- 'get_object',
- {'Body': self.stream},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'VersionId': 'versionid',
- },
- )
- with self.downloader:
- self.downloader.download_file(
- self.bucket,
- self.key,
- self.filename,
- extra_args={'VersionId': 'versionid'},
- )
- self.assert_contents(self.filename, self.remote_contents)
-
- def test_download_file_expected_size(self):
- self.stubbed_client.add_response('get_object', {'Body': self.stream})
- with self.downloader:
- self.downloader.download_file(
- self.bucket,
- self.key,
- self.filename,
- expected_size=len(self.remote_contents),
- )
- self.assert_contents(self.filename, self.remote_contents)
-
- def test_cleans_up_tempfile_on_failure(self):
- self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
- with self.downloader:
- self.downloader.download_file(
- self.bucket,
- self.key,
- self.filename,
- expected_size=len(self.remote_contents),
- )
- self.assertFalse(os.path.exists(self.filename))
- # Any tempfile should have been erased as well
- possible_matches = glob.glob('%s*' % self.filename + os.extsep)
- self.assertEqual(possible_matches, [])
-
- def test_validates_extra_args(self):
- with self.downloader:
- with self.assertRaises(ValueError):
- self.downloader.download_file(
- self.bucket,
- self.key,
- self.filename,
- extra_args={'NotSupported': 'NotSupported'},
- )
-
- def test_result_with_success(self):
- self.stubbed_client.add_response('get_object', {'Body': self.stream})
- with self.downloader:
- future = self.downloader.download_file(
- self.bucket,
- self.key,
- self.filename,
- expected_size=len(self.remote_contents),
- )
- self.assertIsNone(future.result())
-
- def test_result_with_exception(self):
- self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
- with self.downloader:
- future = self.downloader.download_file(
- self.bucket,
- self.key,
- self.filename,
- expected_size=len(self.remote_contents),
- )
- with self.assertRaises(botocore.exceptions.ClientError):
- future.result()
-
- def test_result_with_cancel(self):
- self.stubbed_client.add_response('get_object', {'Body': self.stream})
- with self.downloader:
- future = self.downloader.download_file(
- self.bucket,
- self.key,
- self.filename,
- expected_size=len(self.remote_contents),
- )
- future.cancel()
- with self.assertRaises(CancelledError):
- future.result()
-
- def test_shutdown_with_no_downloads(self):
- downloader = ProcessPoolDownloader()
- try:
- downloader.shutdown()
- except AttributeError:
- self.fail(
- 'The downloader should be able to be shutdown even though '
- 'the downloader was never started.'
- )
-
- def test_shutdown_with_no_downloads_and_ctrl_c(self):
- # Special shutdown logic happens if a KeyboardInterrupt is raised in
- # the context manager. However, this logic cannot happen if the
- # downloader was never started. So a KeyboardInterrupt should be
- # the only exception propagated.
- with self.assertRaises(KeyboardInterrupt):
- with self.downloader:
- raise KeyboardInterrupt()
+# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import glob
+import os
+from io import BytesIO
+from multiprocessing.managers import BaseManager
+
+import botocore.exceptions
+import botocore.session
+from botocore.stub import Stubber
+
+from s3transfer.exceptions import CancelledError
+from s3transfer.processpool import ProcessPoolDownloader, ProcessTransferConfig
+from __tests__ import FileCreator, mock, unittest
+
+
+class StubbedClient:
+ def __init__(self):
+ self._client = botocore.session.get_session().create_client(
+ 's3',
+ 'us-west-2',
+ aws_access_key_id='foo',
+ aws_secret_access_key='bar',
+ )
+ self._stubber = Stubber(self._client)
+ self._stubber.activate()
+ self._caught_stubber_errors = []
+
+ def get_object(self, **kwargs):
+ return self._client.get_object(**kwargs)
+
+ def head_object(self, **kwargs):
+ return self._client.head_object(**kwargs)
+
+ def add_response(self, *args, **kwargs):
+ self._stubber.add_response(*args, **kwargs)
+
+ def add_client_error(self, *args, **kwargs):
+ self._stubber.add_client_error(*args, **kwargs)
+
+
+class StubbedClientManager(BaseManager):
+ pass
+
+
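+# Registering StubbedClient with the manager exposes it as a proxy object,
+# so worker processes spawned by ProcessPoolDownloader all consume responses
+# from the same stubber instance.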
+StubbedClientManager.register('StubbedClient', StubbedClient)
+
+
+# Ideally a Mock would be used here. However, Mocks cannot be pickled
+# on Windows. So instead we define a factory class at the module level that
+# can return a stubbed client we initialized in the setUp.
+class StubbedClientFactory:
+ def __init__(self, stubbed_client):
+ self._stubbed_client = stubbed_client
+
+ def __call__(self, *args, **kwargs):
+ # The __call__ is defined so we can provide an instance of the
+ # StubbedClientFactory to mock.patch() and have the instance be
+ # returned when the patched class is instantiated.
+ return self
+
+ def create_client(self):
+ return self._stubbed_client
+
+
+class TestProcessPoolDownloader(unittest.TestCase):
+ def setUp(self):
+ # The stubbed client needs to run in a manager so that it can be shared
+ # across processes and properly consume its stubbed responses in each
+ # process.
+ self.manager = StubbedClientManager()
+ self.manager.start()
+ self.stubbed_client = self.manager.StubbedClient()
+ self.stubbed_client_factory = StubbedClientFactory(self.stubbed_client)
+
+ self.client_factory_patch = mock.patch(
+ 's3transfer.processpool.ClientFactory', self.stubbed_client_factory
+ )
+ self.client_factory_patch.start()
+ self.files = FileCreator()
+
+ self.config = ProcessTransferConfig(max_request_processes=1)
+ self.downloader = ProcessPoolDownloader(config=self.config)
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.filename = self.files.full_path('filename')
+ self.remote_contents = b'my content'
+ self.stream = BytesIO(self.remote_contents)
+
+ def tearDown(self):
+ self.manager.shutdown()
+ self.client_factory_patch.stop()
+ self.files.remove_all()
+
+ def assert_contents(self, filename, expected_contents):
+ self.assertTrue(os.path.exists(filename))
+ with open(filename, 'rb') as f:
+ self.assertEqual(f.read(), expected_contents)
+
+ def test_download_file(self):
+ self.stubbed_client.add_response(
+ 'head_object', {'ContentLength': len(self.remote_contents)}
+ )
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ with self.downloader:
+ self.downloader.download_file(self.bucket, self.key, self.filename)
+ self.assert_contents(self.filename, self.remote_contents)
+
+ def test_download_multiple_files(self):
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ self.stubbed_client.add_response(
+ 'get_object', {'Body': BytesIO(self.remote_contents)}
+ )
+ with self.downloader:
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ other_file = self.files.full_path('filename2')
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ other_file,
+ expected_size=len(self.remote_contents),
+ )
+ self.assert_contents(self.filename, self.remote_contents)
+ self.assert_contents(other_file, self.remote_contents)
+
+ def test_download_file_ranged_download(self):
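+ # Setting both the multipart threshold and chunksize to half of the
+ # content length forces the downloader to fetch the object as two
+ # ranged GetObject calls, which the stubbed responses below mirror.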
+ half_of_content_length = int(len(self.remote_contents) / 2)
+ self.stubbed_client.add_response(
+ 'head_object', {'ContentLength': len(self.remote_contents)}
+ )
+ self.stubbed_client.add_response(
+ 'get_object',
+ {'Body': BytesIO(self.remote_contents[:half_of_content_length])},
+ )
+ self.stubbed_client.add_response(
+ 'get_object',
+ {'Body': BytesIO(self.remote_contents[half_of_content_length:])},
+ )
+ downloader = ProcessPoolDownloader(
+ config=ProcessTransferConfig(
+ multipart_chunksize=half_of_content_length,
+ multipart_threshold=half_of_content_length,
+ max_request_processes=1,
+ )
+ )
+ with downloader:
+ downloader.download_file(self.bucket, self.key, self.filename)
+ self.assert_contents(self.filename, self.remote_contents)
+
+ def test_download_file_extra_args(self):
+ self.stubbed_client.add_response(
+ 'head_object',
+ {'ContentLength': len(self.remote_contents)},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ },
+ )
+ self.stubbed_client.add_response(
+ 'get_object',
+ {'Body': self.stream},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ },
+ )
+ with self.downloader:
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ extra_args={'VersionId': 'versionid'},
+ )
+ self.assert_contents(self.filename, self.remote_contents)
+
+ def test_download_file_expected_size(self):
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ with self.downloader:
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ self.assert_contents(self.filename, self.remote_contents)
+
+ def test_cleans_up_tempfile_on_failure(self):
+ self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
+ with self.downloader:
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ self.assertFalse(os.path.exists(self.filename))
+ # Any tempfile should have been erased as well
+ possible_matches = glob.glob('%s*' % self.filename + os.extsep)
+ self.assertEqual(possible_matches, [])
+
+ def test_validates_extra_args(self):
+ with self.downloader:
+ with self.assertRaises(ValueError):
+ self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ extra_args={'NotSupported': 'NotSupported'},
+ )
+
+ def test_result_with_success(self):
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ with self.downloader:
+ future = self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ self.assertIsNone(future.result())
+
+ def test_result_with_exception(self):
+ self.stubbed_client.add_client_error('get_object', 'NoSuchKey')
+ with self.downloader:
+ future = self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ with self.assertRaises(botocore.exceptions.ClientError):
+ future.result()
+
+ def test_result_with_cancel(self):
+ self.stubbed_client.add_response('get_object', {'Body': self.stream})
+ with self.downloader:
+ future = self.downloader.download_file(
+ self.bucket,
+ self.key,
+ self.filename,
+ expected_size=len(self.remote_contents),
+ )
+ future.cancel()
+ with self.assertRaises(CancelledError):
+ future.result()
+
+ def test_shutdown_with_no_downloads(self):
+ downloader = ProcessPoolDownloader()
+ try:
+ downloader.shutdown()
+ except AttributeError:
+ self.fail(
+ 'The downloader should be able to be shutdown even though '
+ 'the downloader was never started.'
+ )
+
+ def test_shutdown_with_no_downloads_and_ctrl_c(self):
+ # Special shutdown logic happens if a KeyboardInterrupt is raised in
+ # the context manager. However, this logic cannot happen if the
+ # downloader was never started. So a KeyboardInterrupt should be
+ # the only exception propagated.
+ with self.assertRaises(KeyboardInterrupt):
+ with self.downloader:
+ raise KeyboardInterrupt()
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_upload.py b/contrib/python/s3transfer/py3/tests/functional/test_upload.py
index dcb2a65241..4f294e85ad 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_upload.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_upload.py
@@ -1,538 +1,538 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import os
-import shutil
-import tempfile
-import time
-from io import BytesIO
-
-from botocore.awsrequest import AWSRequest
-from botocore.client import Config
-from botocore.exceptions import ClientError
-from botocore.stub import ANY
-
-from s3transfer.manager import TransferConfig, TransferManager
-from s3transfer.utils import ChunksizeAdjuster
-from __tests__ import (
- BaseGeneralInterfaceTest,
- NonSeekableReader,
- RecordingOSUtils,
- RecordingSubscriber,
- mock,
-)
-
-
-class BaseUploadTest(BaseGeneralInterfaceTest):
- def setUp(self):
- super().setUp()
- # TODO: We do not want to use the real MIN_UPLOAD_CHUNKSIZE
- # when we're adjusting parts.
- # This is really wasteful and fails CI builds because self.content
- # would normally use 10MB+ of memory.
- # Until there's an API to configure this, we're patching this with
- # a min size of 1. We can't patch MIN_UPLOAD_CHUNKSIZE directly
- # because it's already bound to a default value in the
- # chunksize adjuster. Instead we need to patch out the
- # chunksize adjuster class.
- self.adjuster_patch = mock.patch(
- 's3transfer.upload.ChunksizeAdjuster',
- lambda: ChunksizeAdjuster(min_size=1),
- )
- self.adjuster_patch.start()
- self.config = TransferConfig(max_request_concurrency=1)
- self._manager = TransferManager(self.client, self.config)
-
- # Create a temporary directory with files to read from
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'myfile')
- self.content = b'my content'
-
- with open(self.filename, 'wb') as f:
- f.write(self.content)
-
- # Initialize some default arguments
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.extra_args = {}
- self.subscribers = []
-
- # A list to keep track of all of the bodies sent over the wire
- # and their order.
- self.sent_bodies = []
- self.client.meta.events.register(
- 'before-parameter-build.s3.*', self.collect_body
- )
-
- def tearDown(self):
- super().tearDown()
- shutil.rmtree(self.tempdir)
- self.adjuster_patch.stop()
-
- def collect_body(self, params, model, **kwargs):
- # A handler to simulate the reading of the body including the
- # request-created event that signals to simulate the progress
- # callbacks
- if 'Body' in params:
- # TODO: This is not ideal. Need to figure out a better way of
- # simulating reading of the request across the wire to trigger
- # progress callbacks
- request = AWSRequest(
- method='PUT',
- url='https://s3.amazonaws.com',
- data=params['Body'],
- )
- self.client.meta.events.emit(
- 'request-created.s3.%s' % model.name,
- request=request,
- operation_name=model.name,
- )
- self.sent_bodies.append(self._stream_body(params['Body']))
-
- def _stream_body(self, body):
- read_amt = 8 * 1024
- data = body.read(read_amt)
- collected_body = data
- while data:
- data = body.read(read_amt)
- collected_body += data
- return collected_body
-
- @property
- def manager(self):
- return self._manager
-
- @property
- def method(self):
- return self.manager.upload
-
- def create_call_kwargs(self):
- return {
- 'fileobj': self.filename,
- 'bucket': self.bucket,
- 'key': self.key,
- }
-
- def create_invalid_extra_args(self):
- return {'Foo': 'bar'}
-
- def create_stubbed_responses(self):
- return [{'method': 'put_object', 'service_response': {}}]
-
- def create_expected_progress_callback_info(self):
- return [{'bytes_transferred': 10}]
-
- def assert_expected_client_calls_were_correct(self):
- # We assert that expected client calls were made by ensuring that
- # there are no more pending responses. If there are no more pending
- # responses, then all stubbed responses were consumed.
- self.stubber.assert_no_pending_responses()
-
-
-class TestNonMultipartUpload(BaseUploadTest):
- __test__ = True
-
- def add_put_object_response_with_default_expected_params(
- self, extra_expected_params=None
- ):
- expected_params = {'Body': ANY, 'Bucket': self.bucket, 'Key': self.key}
- if extra_expected_params:
- expected_params.update(extra_expected_params)
- upload_response = self.create_stubbed_responses()[0]
- upload_response['expected_params'] = expected_params
- self.stubber.add_response(**upload_response)
-
- def assert_put_object_body_was_correct(self):
- self.assertEqual(self.sent_bodies, [self.content])
-
- def test_upload(self):
- self.extra_args['RequestPayer'] = 'requester'
- self.add_put_object_response_with_default_expected_params(
- extra_expected_params={'RequestPayer': 'requester'}
- )
- future = self.manager.upload(
- self.filename, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assert_put_object_body_was_correct()
-
- def test_upload_for_fileobj(self):
- self.add_put_object_response_with_default_expected_params()
- with open(self.filename, 'rb') as f:
- future = self.manager.upload(
- f, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assert_put_object_body_was_correct()
-
- def test_upload_for_seekable_filelike_obj(self):
- self.add_put_object_response_with_default_expected_params()
- bytes_io = BytesIO(self.content)
- future = self.manager.upload(
- bytes_io, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assert_put_object_body_was_correct()
-
- def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self):
- self.add_put_object_response_with_default_expected_params()
- bytes_io = BytesIO(self.content)
- seek_pos = 5
- bytes_io.seek(seek_pos)
- future = self.manager.upload(
- bytes_io, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:])
-
- def test_upload_for_non_seekable_filelike_obj(self):
- self.add_put_object_response_with_default_expected_params()
- body = NonSeekableReader(self.content)
- future = self.manager.upload(
- body, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assert_put_object_body_was_correct()
-
- def test_sigv4_progress_callbacks_invoked_once(self):
- # Reset the client and manager to use sigv4
- self.reset_stubber_with_new_client(
- {'config': Config(signature_version='s3v4')}
- )
- self.client.meta.events.register(
- 'before-parameter-build.s3.*', self.collect_body
- )
- self._manager = TransferManager(self.client, self.config)
-
- # Add the stubbed response.
- self.add_put_object_response_with_default_expected_params()
-
- subscriber = RecordingSubscriber()
- future = self.manager.upload(
- self.filename, self.bucket, self.key, subscribers=[subscriber]
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
-
- # The number of bytes seen should be the same as the file size
- self.assertEqual(subscriber.calculate_bytes_seen(), len(self.content))
-
- def test_uses_provided_osutil(self):
- osutil = RecordingOSUtils()
- # Use the recording os utility for the transfer manager
- self._manager = TransferManager(self.client, self.config, osutil)
-
- self.add_put_object_response_with_default_expected_params()
-
- future = self.manager.upload(self.filename, self.bucket, self.key)
- future.result()
-
- # The upload should have used the os utility. We check this by making
- # sure that the recorded opens are as expected.
- expected_opens = [(self.filename, 'rb')]
- self.assertEqual(osutil.open_records, expected_opens)
-
- def test_allowed_upload_params_are_valid(self):
- op_model = self.client.meta.service_model.operation_model('PutObject')
- for allowed_upload_arg in self._manager.ALLOWED_UPLOAD_ARGS:
- self.assertIn(allowed_upload_arg, op_model.input_shape.members)
-
- def test_upload_with_bandwidth_limiter(self):
- self.content = b'a' * 1024 * 1024
- with open(self.filename, 'wb') as f:
- f.write(self.content)
- self.config = TransferConfig(
- max_request_concurrency=1, max_bandwidth=len(self.content) / 2
- )
- self._manager = TransferManager(self.client, self.config)
-
- self.add_put_object_response_with_default_expected_params()
- start = time.time()
- future = self.manager.upload(self.filename, self.bucket, self.key)
- future.result()
- # This is just a smoke test to make sure that the limiter is being
- # used; it does not check the limiter's exactness. We set the maximum
- # bandwidth to len(content)/2 per second and verify that the transfer
- # is noticeably slower. Ideally it would take more than two seconds,
- # but because bandwidth tracking is not entirely accurate at the very
- # start of a transfer, we allow some flexibility by only requiring
- # half of the theoretical time.
- self.assertGreaterEqual(time.time() - start, 1)
-
- self.assert_expected_client_calls_were_correct()
- self.assert_put_object_body_was_correct()
-
- def test_raise_exception_on_s3_object_lambda_resource(self):
- s3_object_lambda_arn = (
- 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
- 'accesspoint:my-accesspoint'
- )
- with self.assertRaisesRegex(ValueError, 'methods do not support'):
- self.manager.upload(self.filename, s3_object_lambda_arn, self.key)
-
-
-class TestMultipartUpload(BaseUploadTest):
- __test__ = True
-
- def setUp(self):
- super().setUp()
- self.chunksize = 4
- self.config = TransferConfig(
- max_request_concurrency=1,
- multipart_threshold=1,
- multipart_chunksize=self.chunksize,
- )
- self._manager = TransferManager(self.client, self.config)
- self.multipart_id = 'my-upload-id'
-
- def create_stubbed_responses(self):
- return [
- {
- 'method': 'create_multipart_upload',
- 'service_response': {'UploadId': self.multipart_id},
- },
- {'method': 'upload_part', 'service_response': {'ETag': 'etag-1'}},
- {'method': 'upload_part', 'service_response': {'ETag': 'etag-2'}},
- {'method': 'upload_part', 'service_response': {'ETag': 'etag-3'}},
- {'method': 'complete_multipart_upload', 'service_response': {}},
- ]
-
- def create_expected_progress_callback_info(self):
- return [
- {'bytes_transferred': 4},
- {'bytes_transferred': 4},
- {'bytes_transferred': 2},
- ]
-
- def assert_upload_part_bodies_were_correct(self):
- expected_contents = []
- for i in range(0, len(self.content), self.chunksize):
- end_i = i + self.chunksize
- if end_i > len(self.content):
- expected_contents.append(self.content[i:])
- else:
- expected_contents.append(self.content[i:end_i])
- self.assertEqual(self.sent_bodies, expected_contents)
-
- def add_create_multipart_response_with_default_expected_params(
- self, extra_expected_params=None
- ):
- expected_params = {'Bucket': self.bucket, 'Key': self.key}
- if extra_expected_params:
- expected_params.update(extra_expected_params)
- response = self.create_stubbed_responses()[0]
- response['expected_params'] = expected_params
- self.stubber.add_response(**response)
-
- def add_upload_part_responses_with_default_expected_params(
- self, extra_expected_params=None
- ):
- num_parts = 3
- upload_part_responses = self.create_stubbed_responses()[1:-1]
- for i in range(num_parts):
- upload_part_response = upload_part_responses[i]
- expected_params = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': self.multipart_id,
- 'Body': ANY,
- 'PartNumber': i + 1,
- }
- if extra_expected_params:
- expected_params.update(extra_expected_params)
- upload_part_response['expected_params'] = expected_params
- self.stubber.add_response(**upload_part_response)
-
- def add_complete_multipart_response_with_default_expected_params(
- self, extra_expected_params=None
- ):
- expected_params = {
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': self.multipart_id,
- 'MultipartUpload': {
- 'Parts': [
- {'ETag': 'etag-1', 'PartNumber': 1},
- {'ETag': 'etag-2', 'PartNumber': 2},
- {'ETag': 'etag-3', 'PartNumber': 3},
- ]
- },
- }
- if extra_expected_params:
- expected_params.update(extra_expected_params)
- response = self.create_stubbed_responses()[-1]
- response['expected_params'] = expected_params
- self.stubber.add_response(**response)
-
- def test_upload(self):
- self.extra_args['RequestPayer'] = 'requester'
-
- # Add requester pays to the create multipart upload and upload parts.
- self.add_create_multipart_response_with_default_expected_params(
- extra_expected_params={'RequestPayer': 'requester'}
- )
- self.add_upload_part_responses_with_default_expected_params(
- extra_expected_params={'RequestPayer': 'requester'}
- )
- self.add_complete_multipart_response_with_default_expected_params(
- extra_expected_params={'RequestPayer': 'requester'}
- )
-
- future = self.manager.upload(
- self.filename, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
-
- def test_upload_for_fileobj(self):
- self.add_create_multipart_response_with_default_expected_params()
- self.add_upload_part_responses_with_default_expected_params()
- self.add_complete_multipart_response_with_default_expected_params()
- with open(self.filename, 'rb') as f:
- future = self.manager.upload(
- f, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assert_upload_part_bodies_were_correct()
-
- def test_upload_for_seekable_filelike_obj(self):
- self.add_create_multipart_response_with_default_expected_params()
- self.add_upload_part_responses_with_default_expected_params()
- self.add_complete_multipart_response_with_default_expected_params()
- bytes_io = BytesIO(self.content)
- future = self.manager.upload(
- bytes_io, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assert_upload_part_bodies_were_correct()
-
- def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self):
- self.add_create_multipart_response_with_default_expected_params()
- self.add_upload_part_responses_with_default_expected_params()
- self.add_complete_multipart_response_with_default_expected_params()
- bytes_io = BytesIO(self.content)
- seek_pos = 1
- bytes_io.seek(seek_pos)
- future = self.manager.upload(
- bytes_io, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:])
-
- def test_upload_for_non_seekable_filelike_obj(self):
- self.add_create_multipart_response_with_default_expected_params()
- self.add_upload_part_responses_with_default_expected_params()
- self.add_complete_multipart_response_with_default_expected_params()
- stream = NonSeekableReader(self.content)
- future = self.manager.upload(
- stream, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
- self.assert_upload_part_bodies_were_correct()
-
- def test_limits_in_memory_chunks_for_fileobj(self):
- # Limit the maximum in memory chunks to one but make number of
- # threads more than one. This means that the upload will have to
- # happen sequentially despite having many threads available because
- # data is sequentially partitioned into chunks in memory and since
- # there can only ever be one in-memory chunk, each upload part will
- # have to happen one at a time.
- self.config.max_request_concurrency = 10
- self.config.max_in_memory_upload_chunks = 1
- self._manager = TransferManager(self.client, self.config)
-
- # Add some default stubbed responses.
- # These responses are added in order of part number, so if the
- # multipart upload is not done sequentially (which it should be,
- # because we limit the in-memory upload chunks to one), the stubber
- # will raise exceptions for mismatched PartNumber parameters once
- # the upload() method is called on the transfer manager.
- # If there is a mismatch, the stubber error will propagate on
- # the future.result()
- self.add_create_multipart_response_with_default_expected_params()
- self.add_upload_part_responses_with_default_expected_params()
- self.add_complete_multipart_response_with_default_expected_params()
- with open(self.filename, 'rb') as f:
- future = self.manager.upload(
- f, self.bucket, self.key, self.extra_args
- )
- future.result()
-
- # Make sure that the stubber had all of its stubbed responses consumed.
- self.assert_expected_client_calls_were_correct()
- # Ensure the contents were uploaded in sequential order by checking
- # the sent contents were in order.
- self.assert_upload_part_bodies_were_correct()
-
- def test_upload_failure_invokes_abort(self):
- self.stubber.add_response(
- method='create_multipart_upload',
- service_response={'UploadId': self.multipart_id},
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- self.stubber.add_response(
- method='upload_part',
- service_response={'ETag': 'etag-1'},
- expected_params={
- 'Bucket': self.bucket,
- 'Body': ANY,
- 'Key': self.key,
- 'UploadId': self.multipart_id,
- 'PartNumber': 1,
- },
- )
- # With the upload part failing this should immediately initiate
- # an abort multipart with no more upload parts called.
- self.stubber.add_client_error(method='upload_part')
-
- self.stubber.add_response(
- method='abort_multipart_upload',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': self.multipart_id,
- },
- )
-
- future = self.manager.upload(self.filename, self.bucket, self.key)
- # The exception should get propagated to the future and not be
- # a cancelled error or something.
- with self.assertRaises(ClientError):
- future.result()
- self.assert_expected_client_calls_were_correct()
-
- def test_upload_passes_select_extra_args(self):
- self.extra_args['Metadata'] = {'foo': 'bar'}
-
- # Add metadata to expected create multipart upload call
- self.add_create_multipart_response_with_default_expected_params(
- extra_expected_params={'Metadata': {'foo': 'bar'}}
- )
- self.add_upload_part_responses_with_default_expected_params()
- self.add_complete_multipart_response_with_default_expected_params()
-
- future = self.manager.upload(
- self.filename, self.bucket, self.key, self.extra_args
- )
- future.result()
- self.assert_expected_client_calls_were_correct()
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import tempfile
+import time
+from io import BytesIO
+
+from botocore.awsrequest import AWSRequest
+from botocore.client import Config
+from botocore.exceptions import ClientError
+from botocore.stub import ANY
+
+from s3transfer.manager import TransferConfig, TransferManager
+from s3transfer.utils import ChunksizeAdjuster
+from __tests__ import (
+ BaseGeneralInterfaceTest,
+ NonSeekableReader,
+ RecordingOSUtils,
+ RecordingSubscriber,
+ mock,
+)
+
+
+class BaseUploadTest(BaseGeneralInterfaceTest):
+ def setUp(self):
+ super().setUp()
+ # TODO: We do not want to use the real MIN_UPLOAD_CHUNKSIZE
+ # when we're adjusting parts.
+ # This is really wasteful and fails CI builds because self.content
+ # would normally use 10MB+ of memory.
+ # Until there's an API to configure this, we're patching this with
+ # a min size of 1. We can't patch MIN_UPLOAD_CHUNKSIZE directly
+ # because it's already bound to a default value in the
+ # chunksize adjuster. Instead we need to patch out the
+ # chunksize adjuster class.
+ self.adjuster_patch = mock.patch(
+ 's3transfer.upload.ChunksizeAdjuster',
+ lambda: ChunksizeAdjuster(min_size=1),
+ )
+ self.adjuster_patch.start()
+ self.config = TransferConfig(max_request_concurrency=1)
+ self._manager = TransferManager(self.client, self.config)
+
+ # Create a temporary directory with files to read from
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ self.content = b'my content'
+
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+
+ # Initialize some default arguments
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+ self.sent_bodies = []
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+ self.adjuster_patch.stop()
+
+ def collect_body(self, params, model, **kwargs):
+ # A handler to simulate the reading of the body, including emitting
+ # the request-created event that signals the simulated progress
+ # callbacks.
+ if 'Body' in params:
+ # TODO: This is not ideal. Need to figure out a better way of
+ # simulating reading of the request across the wire to trigger
+ # progress callbacks.
+ request = AWSRequest(
+ method='PUT',
+ url='https://s3.amazonaws.com',
+ data=params['Body'],
+ )
+ self.client.meta.events.emit(
+ 'request-created.s3.%s' % model.name,
+ request=request,
+ operation_name=model.name,
+ )
+ self.sent_bodies.append(self._stream_body(params['Body']))
+
+ def _stream_body(self, body):
+ read_amt = 8 * 1024
+ data = body.read(read_amt)
+ collected_body = data
+ while data:
+ data = body.read(read_amt)
+ collected_body += data
+ return collected_body
+
+ @property
+ def manager(self):
+ return self._manager
+
+ @property
+ def method(self):
+ return self.manager.upload
+
+ def create_call_kwargs(self):
+ return {
+ 'fileobj': self.filename,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ }
+
+ def create_invalid_extra_args(self):
+ return {'Foo': 'bar'}
+
+ def create_stubbed_responses(self):
+ return [{'method': 'put_object', 'service_response': {}}]
+
+ def create_expected_progress_callback_info(self):
+ return [{'bytes_transferred': 10}]
+
+ def assert_expected_client_calls_were_correct(self):
+ # We assert that expected client calls were made by ensuring that
+ # there are no more pending responses. If there are no more pending
+ # responses, then all stubbed responses were consumed.
+ self.stubber.assert_no_pending_responses()
+
+
+class TestNonMultipartUpload(BaseUploadTest):
+ __test__ = True
+
+ def add_put_object_response_with_default_expected_params(
+ self, extra_expected_params=None
+ ):
+ expected_params = {'Body': ANY, 'Bucket': self.bucket, 'Key': self.key}
+ if extra_expected_params:
+ expected_params.update(extra_expected_params)
+ upload_response = self.create_stubbed_responses()[0]
+ upload_response['expected_params'] = expected_params
+ self.stubber.add_response(**upload_response)
+
+ def assert_put_object_body_was_correct(self):
+ self.assertEqual(self.sent_bodies, [self.content])
+
+ def test_upload(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ self.add_put_object_response_with_default_expected_params(
+ extra_expected_params={'RequestPayer': 'requester'}
+ )
+ future = self.manager.upload(
+ self.filename, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_upload_for_fileobj(self):
+ self.add_put_object_response_with_default_expected_params()
+ with open(self.filename, 'rb') as f:
+ future = self.manager.upload(
+ f, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_upload_for_seekable_filelike_obj(self):
+ self.add_put_object_response_with_default_expected_params()
+ bytes_io = BytesIO(self.content)
+ future = self.manager.upload(
+ bytes_io, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self):
+ self.add_put_object_response_with_default_expected_params()
+ bytes_io = BytesIO(self.content)
+ seek_pos = 5
+ bytes_io.seek(seek_pos)
+ future = self.manager.upload(
+ bytes_io, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:])
+
+ def test_upload_for_non_seekable_filelike_obj(self):
+ self.add_put_object_response_with_default_expected_params()
+ body = NonSeekableReader(self.content)
+ future = self.manager.upload(
+ body, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_sigv4_progress_callbacks_invoked_once(self):
+ # Reset the client and manager to use sigv4
+ self.reset_stubber_with_new_client(
+ {'config': Config(signature_version='s3v4')}
+ )
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ # Add the stubbed response.
+ self.add_put_object_response_with_default_expected_params()
+
+ subscriber = RecordingSubscriber()
+ future = self.manager.upload(
+ self.filename, self.bucket, self.key, subscribers=[subscriber]
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+
+ # The amount of bytes seen should be the same as the file size
+ self.assertEqual(subscriber.calculate_bytes_seen(), len(self.content))
+
+ def test_uses_provided_osutil(self):
+ osutil = RecordingOSUtils()
+ # Use the recording os utility for the transfer manager
+ self._manager = TransferManager(self.client, self.config, osutil)
+
+ self.add_put_object_response_with_default_expected_params()
+
+ future = self.manager.upload(self.filename, self.bucket, self.key)
+ future.result()
+
+ # The upload should have used the os utility. We check this by making
+ # sure that the recorded opens are as expected.
+ expected_opens = [(self.filename, 'rb')]
+ self.assertEqual(osutil.open_records, expected_opens)
+
+ def test_allowed_upload_params_are_valid(self):
+ op_model = self.client.meta.service_model.operation_model('PutObject')
+ for allowed_upload_arg in self._manager.ALLOWED_UPLOAD_ARGS:
+ self.assertIn(allowed_upload_arg, op_model.input_shape.members)
+
+ def test_upload_with_bandwidth_limiter(self):
+ self.content = b'a' * 1024 * 1024
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+ self.config = TransferConfig(
+ max_request_concurrency=1, max_bandwidth=len(self.content) / 2
+ )
+ self._manager = TransferManager(self.client, self.config)
+
+ self.add_put_object_response_with_default_expected_params()
+ start = time.time()
+ future = self.manager.upload(self.filename, self.bucket, self.key)
+ future.result()
+ # This is just a smoke test to make sure that the limiter is
+ # being used, not a check of its exactness. We set the maximum
+ # bandwidth to len(content)/2 per second and make sure the upload is
+ # noticeably slower. Ideally it would take more than two seconds, but
+ # because rate tracking at the very beginning of a transfer is not
+ # entirely accurate, we give ourselves some flexibility by setting
+ # the expected time to half of the theoretical time it should take.
+ self.assertGreaterEqual(time.time() - start, 1)
+
+ self.assert_expected_client_calls_were_correct()
+ self.assert_put_object_body_was_correct()
+
+ def test_raise_exception_on_s3_object_lambda_resource(self):
+ s3_object_lambda_arn = (
+ 'arn:aws:s3-object-lambda:us-west-2:123456789012:'
+ 'accesspoint:my-accesspoint'
+ )
+ with self.assertRaisesRegex(ValueError, 'methods do not support'):
+ self.manager.upload(self.filename, s3_object_lambda_arn, self.key)
+
+
+class TestMultipartUpload(BaseUploadTest):
+ __test__ = True
+
+ def setUp(self):
+ super().setUp()
+ self.chunksize = 4
+ self.config = TransferConfig(
+ max_request_concurrency=1,
+ multipart_threshold=1,
+ multipart_chunksize=self.chunksize,
+ )
+ self._manager = TransferManager(self.client, self.config)
+ self.multipart_id = 'my-upload-id'
+
+ def create_stubbed_responses(self):
+ return [
+ {
+ 'method': 'create_multipart_upload',
+ 'service_response': {'UploadId': self.multipart_id},
+ },
+ {'method': 'upload_part', 'service_response': {'ETag': 'etag-1'}},
+ {'method': 'upload_part', 'service_response': {'ETag': 'etag-2'}},
+ {'method': 'upload_part', 'service_response': {'ETag': 'etag-3'}},
+ {'method': 'complete_multipart_upload', 'service_response': {}},
+ ]
+
+ def create_expected_progress_callback_info(self):
+ return [
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 4},
+ {'bytes_transferred': 2},
+ ]
+
+ def assert_upload_part_bodies_were_correct(self):
+ expected_contents = []
+ for i in range(0, len(self.content), self.chunksize):
+ end_i = i + self.chunksize
+ if end_i > len(self.content):
+ expected_contents.append(self.content[i:])
+ else:
+ expected_contents.append(self.content[i:end_i])
+ self.assertEqual(self.sent_bodies, expected_contents)
+
+ def add_create_multipart_response_with_default_expected_params(
+ self, extra_expected_params=None
+ ):
+ expected_params = {'Bucket': self.bucket, 'Key': self.key}
+ if extra_expected_params:
+ expected_params.update(extra_expected_params)
+ response = self.create_stubbed_responses()[0]
+ response['expected_params'] = expected_params
+ self.stubber.add_response(**response)
+
+ def add_upload_part_responses_with_default_expected_params(
+ self, extra_expected_params=None
+ ):
+ num_parts = 3
+ upload_part_responses = self.create_stubbed_responses()[1:-1]
+ for i in range(num_parts):
+ upload_part_response = upload_part_responses[i]
+ expected_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': self.multipart_id,
+ 'Body': ANY,
+ 'PartNumber': i + 1,
+ }
+ if extra_expected_params:
+ expected_params.update(extra_expected_params)
+ upload_part_response['expected_params'] = expected_params
+ self.stubber.add_response(**upload_part_response)
+
+ def add_complete_multipart_response_with_default_expected_params(
+ self, extra_expected_params=None
+ ):
+ expected_params = {
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': self.multipart_id,
+ 'MultipartUpload': {
+ 'Parts': [
+ {'ETag': 'etag-1', 'PartNumber': 1},
+ {'ETag': 'etag-2', 'PartNumber': 2},
+ {'ETag': 'etag-3', 'PartNumber': 3},
+ ]
+ },
+ }
+ if extra_expected_params:
+ expected_params.update(extra_expected_params)
+ response = self.create_stubbed_responses()[-1]
+ response['expected_params'] = expected_params
+ self.stubber.add_response(**response)
+
+ def test_upload(self):
+ self.extra_args['RequestPayer'] = 'requester'
+
+ # Add requester pays to the create multipart upload and upload parts.
+ self.add_create_multipart_response_with_default_expected_params(
+ extra_expected_params={'RequestPayer': 'requester'}
+ )
+ self.add_upload_part_responses_with_default_expected_params(
+ extra_expected_params={'RequestPayer': 'requester'}
+ )
+ self.add_complete_multipart_response_with_default_expected_params(
+ extra_expected_params={'RequestPayer': 'requester'}
+ )
+
+ future = self.manager.upload(
+ self.filename, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+
+ def test_upload_for_fileobj(self):
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ with open(self.filename, 'rb') as f:
+ future = self.manager.upload(
+ f, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_upload_part_bodies_were_correct()
+
+ def test_upload_for_seekable_filelike_obj(self):
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ bytes_io = BytesIO(self.content)
+ future = self.manager.upload(
+ bytes_io, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_upload_part_bodies_were_correct()
+
+ def test_upload_for_seekable_filelike_obj_that_has_been_seeked(self):
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ bytes_io = BytesIO(self.content)
+ seek_pos = 1
+ bytes_io.seek(seek_pos)
+ future = self.manager.upload(
+ bytes_io, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assertEqual(b''.join(self.sent_bodies), self.content[seek_pos:])
+
+ def test_upload_for_non_seekable_filelike_obj(self):
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ stream = NonSeekableReader(self.content)
+ future = self.manager.upload(
+ stream, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+ self.assert_upload_part_bodies_were_correct()
+
+ def test_limits_in_memory_chunks_for_fileobj(self):
+ # Limit the maximum number of in-memory chunks to one but make the
+ # number of threads more than one. This means that the upload will
+ # have to happen sequentially despite having many threads available
+ # because data is sequentially partitioned into chunks in memory and,
+ # since there can only ever be one in-memory chunk, each upload part
+ # will have to happen one at a time.
+ self.config.max_request_concurrency = 10
+ self.config.max_in_memory_upload_chunks = 1
+ self._manager = TransferManager(self.client, self.config)
+
+ # Add some default stubbed responses.
+ # These responses are added in order of part number, so if the
+ # multipart upload is not done sequentially, which it should be
+ # because we limit the in-memory upload chunks to one, the stubber
+ # will raise exceptions for mismatching PartNumber parameters once
+ # the upload() method is called on the transfer manager.
+ # If there is a mismatch, the stubber error will propagate through
+ # future.result()
+ self.add_create_multipart_response_with_default_expected_params()
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+ with open(self.filename, 'rb') as f:
+ future = self.manager.upload(
+ f, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+
+ # Make sure that the stubber had all of its stubbed responses consumed.
+ self.assert_expected_client_calls_were_correct()
+ # Ensure the contents were uploaded in sequential order by checking
+ # that the sent contents were in order.
+ self.assert_upload_part_bodies_were_correct()
+
+ def test_upload_failure_invokes_abort(self):
+ self.stubber.add_response(
+ method='create_multipart_upload',
+ service_response={'UploadId': self.multipart_id},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.stubber.add_response(
+ method='upload_part',
+ service_response={'ETag': 'etag-1'},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Body': ANY,
+ 'Key': self.key,
+ 'UploadId': self.multipart_id,
+ 'PartNumber': 1,
+ },
+ )
+ # With the upload part failing, this should immediately initiate
+ # an abort of the multipart upload with no more upload parts called.
+ self.stubber.add_client_error(method='upload_part')
+
+ self.stubber.add_response(
+ method='abort_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': self.multipart_id,
+ },
+ )
+
+ future = self.manager.upload(self.filename, self.bucket, self.key)
+ # The exception should get propagated to the future and not be
+ # a cancelled error or something.
+ with self.assertRaises(ClientError):
+ future.result()
+ self.assert_expected_client_calls_were_correct()
+
+ def test_upload_passes_select_extra_args(self):
+ self.extra_args['Metadata'] = {'foo': 'bar'}
+
+ # Add metadata to expected create multipart upload call
+ self.add_create_multipart_response_with_default_expected_params(
+ extra_expected_params={'Metadata': {'foo': 'bar'}}
+ )
+ self.add_upload_part_responses_with_default_expected_params()
+ self.add_complete_multipart_response_with_default_expected_params()
+
+ future = self.manager.upload(
+ self.filename, self.bucket, self.key, self.extra_args
+ )
+ future.result()
+ self.assert_expected_client_calls_were_correct()
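The multipart tests above stub three upload_part calls for a ten-byte payload and a four-byte chunksize. As a minimal, self-contained sketch of the part-splitting those tests assert against (plain Python, no AWS calls; the helper name is hypothetical and not part of s3transfer):

def split_into_parts(content, chunksize):
    """Cut content into chunksize-sized part bodies; the last part keeps the remainder."""
    return [content[i:i + chunksize] for i in range(0, len(content), chunksize)]

# With the fixture's payload and chunksize this yields parts of 4, 4 and 2
# bytes, matching the stubbed upload_part responses and the expected
# progress callback amounts.
assert split_into_parts(b'my content', 4) == [b'my c', b'onte', b'nt']
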
diff --git a/contrib/python/s3transfer/py3/tests/functional/test_utils.py b/contrib/python/s3transfer/py3/tests/functional/test_utils.py
index cc51b955d9..fd4a232ecc 100644
--- a/contrib/python/s3transfer/py3/tests/functional/test_utils.py
+++ b/contrib/python/s3transfer/py3/tests/functional/test_utils.py
@@ -1,41 +1,41 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import os
-import shutil
-import socket
-import tempfile
-
-from s3transfer.utils import OSUtils
-from __tests__ import skip_if_windows, unittest
-
-
-@skip_if_windows('Windows does not support UNIX special files')
-class TestOSUtilsSpecialFiles(unittest.TestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'myfile')
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
-
- def test_character_device(self):
- self.assertTrue(OSUtils().is_special_file('/dev/null'))
-
- def test_fifo(self):
- os.mkfifo(self.filename)
- self.assertTrue(OSUtils().is_special_file(self.filename))
-
- def test_socket(self):
- sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
- sock.bind(self.filename)
- self.assertTrue(OSUtils().is_special_file(self.filename))
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import socket
+import tempfile
+
+from s3transfer.utils import OSUtils
+from __tests__ import skip_if_windows, unittest
+
+
+@skip_if_windows('Windows does not support UNIX special files')
+class TestOSUtilsSpecialFiles(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_character_device(self):
+ self.assertTrue(OSUtils().is_special_file('/dev/null'))
+
+ def test_fifo(self):
+ os.mkfifo(self.filename)
+ self.assertTrue(OSUtils().is_special_file(self.filename))
+
+ def test_socket(self):
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ sock.bind(self.filename)
+ self.assertTrue(OSUtils().is_special_file(self.filename))
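The special-file tests above only assert that OSUtils().is_special_file() recognizes character devices, FIFOs and sockets. A brief usage sketch (POSIX only, and purely illustrative; the guard shown is an assumed caller-side check, not something these tests or s3transfer perform):

import os
import tempfile

from s3transfer.utils import OSUtils

fifo_path = os.path.join(tempfile.mkdtemp(), 'myfifo')
os.mkfifo(fifo_path)  # create a named pipe, as test_fifo does

# Reading a FIFO or socket can block indefinitely, so a caller may want to
# skip such sources before starting a transfer.
if OSUtils().is_special_file(fifo_path):
    print('skipping special file:', fifo_path)
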
diff --git a/contrib/python/s3transfer/py3/tests/unit/__init__.py b/contrib/python/s3transfer/py3/tests/unit/__init__.py
index 5e791f8b92..79ef91c6a2 100644
--- a/contrib/python/s3transfer/py3/tests/unit/__init__.py
+++ b/contrib/python/s3transfer/py3/tests/unit/__init__.py
@@ -1,12 +1,12 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License'). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the 'license' file accompanying this file. This file is
-# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
index 0e53ec157a..b796f8f24c 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_bandwidth.py
@@ -1,452 +1,452 @@
-# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import os
-import shutil
-import tempfile
-
-from s3transfer.bandwidth import (
- BandwidthLimitedStream,
- BandwidthLimiter,
- BandwidthRateTracker,
- ConsumptionScheduler,
- LeakyBucket,
- RequestExceededException,
- RequestToken,
- TimeUtils,
-)
-from s3transfer.futures import TransferCoordinator
-from __tests__ import mock, unittest
-
-
-class FixedIncrementalTickTimeUtils(TimeUtils):
- def __init__(self, seconds_per_tick=1.0):
- self._count = 0
- self._seconds_per_tick = seconds_per_tick
-
- def time(self):
- current_count = self._count
- self._count += self._seconds_per_tick
- return current_count
-
-
-class TestTimeUtils(unittest.TestCase):
- @mock.patch('time.time')
- def test_time(self, mock_time):
- mock_return_val = 1
- mock_time.return_value = mock_return_val
- time_utils = TimeUtils()
- self.assertEqual(time_utils.time(), mock_return_val)
-
- @mock.patch('time.sleep')
- def test_sleep(self, mock_sleep):
- time_utils = TimeUtils()
- time_utils.sleep(1)
- self.assertEqual(mock_sleep.call_args_list, [mock.call(1)])
-
-
-class BaseBandwidthLimitTest(unittest.TestCase):
- def setUp(self):
- self.leaky_bucket = mock.Mock(LeakyBucket)
- self.time_utils = mock.Mock(TimeUtils)
- self.tempdir = tempfile.mkdtemp()
- self.content = b'a' * 1024 * 1024
- self.filename = os.path.join(self.tempdir, 'myfile')
- with open(self.filename, 'wb') as f:
- f.write(self.content)
- self.coordinator = TransferCoordinator()
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
-
- def assert_consume_calls(self, amts):
- expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts]
- self.assertEqual(
- self.leaky_bucket.consume.call_args_list, expected_consume_args
- )
-
-
-class TestBandwidthLimiter(BaseBandwidthLimitTest):
- def setUp(self):
- super().setUp()
- self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket)
-
- def test_get_bandwidth_limited_stream(self):
- with open(self.filename, 'rb') as f:
- stream = self.bandwidth_limiter.get_bandwith_limited_stream(
- f, self.coordinator
- )
- self.assertIsInstance(stream, BandwidthLimitedStream)
- self.assertEqual(stream.read(len(self.content)), self.content)
- self.assert_consume_calls(amts=[len(self.content)])
-
- def test_get_disabled_bandwidth_limited_stream(self):
- with open(self.filename, 'rb') as f:
- stream = self.bandwidth_limiter.get_bandwith_limited_stream(
- f, self.coordinator, enabled=False
- )
- self.assertIsInstance(stream, BandwidthLimitedStream)
- self.assertEqual(stream.read(len(self.content)), self.content)
- self.leaky_bucket.consume.assert_not_called()
-
-
-class TestBandwidthLimitedStream(BaseBandwidthLimitTest):
- def setUp(self):
- super().setUp()
- self.bytes_threshold = 1
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
-
- def get_bandwidth_limited_stream(self, f):
- return BandwidthLimitedStream(
- f,
- self.leaky_bucket,
- self.coordinator,
- self.time_utils,
- self.bytes_threshold,
- )
-
- def assert_sleep_calls(self, amts):
- expected_sleep_args_list = [mock.call(amt) for amt in amts]
- self.assertEqual(
- self.time_utils.sleep.call_args_list, expected_sleep_args_list
- )
-
- def get_unique_consume_request_tokens(self):
- return {
- call_args[0][1]
- for call_args in self.leaky_bucket.consume.call_args_list
- }
-
- def test_read(self):
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- data = stream.read(len(self.content))
- self.assertEqual(self.content, data)
- self.assert_consume_calls(amts=[len(self.content)])
- self.assert_sleep_calls(amts=[])
-
- def test_retries_on_request_exceeded(self):
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- retry_time = 1
- amt_requested = len(self.content)
- self.leaky_bucket.consume.side_effect = [
- RequestExceededException(amt_requested, retry_time),
- len(self.content),
- ]
- data = stream.read(len(self.content))
- self.assertEqual(self.content, data)
- self.assert_consume_calls(amts=[amt_requested, amt_requested])
- self.assert_sleep_calls(amts=[retry_time])
-
- def test_with_transfer_coordinator_exception(self):
- self.coordinator.set_exception(ValueError())
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- with self.assertRaises(ValueError):
- stream.read(len(self.content))
-
- def test_read_when_bandwidth_limiting_disabled(self):
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- stream.disable_bandwidth_limiting()
- data = stream.read(len(self.content))
- self.assertEqual(self.content, data)
- self.assertFalse(self.leaky_bucket.consume.called)
-
- def test_read_toggle_disable_enable_bandwidth_limiting(self):
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- stream.disable_bandwidth_limiting()
- data = stream.read(1)
- self.assertEqual(self.content[:1], data)
- self.assert_consume_calls(amts=[])
- stream.enable_bandwidth_limiting()
- data = stream.read(len(self.content) - 1)
- self.assertEqual(self.content[1:], data)
- self.assert_consume_calls(amts=[len(self.content) - 1])
-
- def test_seek(self):
- mock_fileobj = mock.Mock()
- stream = self.get_bandwidth_limited_stream(mock_fileobj)
- stream.seek(1)
- self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)])
-
- def test_tell(self):
- mock_fileobj = mock.Mock()
- stream = self.get_bandwidth_limited_stream(mock_fileobj)
- stream.tell()
- self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()])
-
- def test_close(self):
- mock_fileobj = mock.Mock()
- stream = self.get_bandwidth_limited_stream(mock_fileobj)
- stream.close()
- self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
-
- def test_context_manager(self):
- mock_fileobj = mock.Mock()
- stream = self.get_bandwidth_limited_stream(mock_fileobj)
- with stream as stream_handle:
- self.assertIs(stream_handle, stream)
- self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
-
- def test_reuses_request_token(self):
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- stream.read(1)
- stream.read(1)
- self.assertEqual(len(self.get_unique_consume_request_tokens()), 1)
-
- def test_request_tokens_unique_per_stream(self):
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- stream.read(1)
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- stream.read(1)
- self.assertEqual(len(self.get_unique_consume_request_tokens()), 2)
-
- def test_call_consume_after_reaching_threshold(self):
- self.bytes_threshold = 2
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- self.assertEqual(stream.read(1), self.content[:1])
- self.assert_consume_calls(amts=[])
- self.assertEqual(stream.read(1), self.content[1:2])
- self.assert_consume_calls(amts=[2])
-
- def test_resets_after_reaching_threshold(self):
- self.bytes_threshold = 2
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- self.assertEqual(stream.read(2), self.content[:2])
- self.assert_consume_calls(amts=[2])
- self.assertEqual(stream.read(1), self.content[2:3])
- self.assert_consume_calls(amts=[2])
-
- def test_pending_bytes_seen_on_close(self):
- self.bytes_threshold = 2
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- self.assertEqual(stream.read(1), self.content[:1])
- self.assert_consume_calls(amts=[])
- stream.close()
- self.assert_consume_calls(amts=[1])
-
- def test_no_bytes_remaining_on(self):
- self.bytes_threshold = 2
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- self.assertEqual(stream.read(2), self.content[:2])
- self.assert_consume_calls(amts=[2])
- stream.close()
- # There should have been no more consume() calls made
- # as all bytes have been accounted for in the previous
- # consume() call.
- self.assert_consume_calls(amts=[2])
-
- def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self):
- self.bytes_threshold = 2
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- self.assertEqual(stream.read(1), self.content[:1])
- self.assert_consume_calls(amts=[])
- stream.disable_bandwidth_limiting()
- stream.close()
- self.assert_consume_calls(amts=[])
-
- def test_signal_transferring(self):
- with open(self.filename, 'rb') as f:
- stream = self.get_bandwidth_limited_stream(f)
- stream.signal_not_transferring()
- data = stream.read(1)
- self.assertEqual(self.content[:1], data)
- self.assert_consume_calls(amts=[])
- stream.signal_transferring()
- data = stream.read(len(self.content) - 1)
- self.assertEqual(self.content[1:], data)
- self.assert_consume_calls(amts=[len(self.content) - 1])
-
-
-class TestLeakyBucket(unittest.TestCase):
- def setUp(self):
- self.max_rate = 1
- self.time_now = 1.0
- self.time_utils = mock.Mock(TimeUtils)
- self.time_utils.time.return_value = self.time_now
- self.scheduler = mock.Mock(ConsumptionScheduler)
- self.scheduler.is_scheduled.return_value = False
- self.rate_tracker = mock.Mock(BandwidthRateTracker)
- self.leaky_bucket = LeakyBucket(
- self.max_rate, self.time_utils, self.rate_tracker, self.scheduler
- )
-
- def set_projected_rate(self, rate):
- self.rate_tracker.get_projected_rate.return_value = rate
-
- def set_retry_time(self, retry_time):
- self.scheduler.schedule_consumption.return_value = retry_time
-
- def assert_recorded_consumed_amt(self, expected_amt):
- self.assertEqual(
- self.rate_tracker.record_consumption_rate.call_args,
- mock.call(expected_amt, self.time_utils.time.return_value),
- )
-
- def assert_was_scheduled(self, amt, token):
- self.assertEqual(
- self.scheduler.schedule_consumption.call_args,
- mock.call(amt, token, amt / (self.max_rate)),
- )
-
- def assert_nothing_scheduled(self):
- self.assertFalse(self.scheduler.schedule_consumption.called)
-
- def assert_processed_request_token(self, request_token):
- self.assertEqual(
- self.scheduler.process_scheduled_consumption.call_args,
- mock.call(request_token),
- )
-
- def test_consume_under_max_rate(self):
- amt = 1
- self.set_projected_rate(self.max_rate / 2)
- self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
- self.assert_recorded_consumed_amt(amt)
- self.assert_nothing_scheduled()
-
- def test_consume_at_max_rate(self):
- amt = 1
- self.set_projected_rate(self.max_rate)
- self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
- self.assert_recorded_consumed_amt(amt)
- self.assert_nothing_scheduled()
-
- def test_consume_over_max_rate(self):
- amt = 1
- retry_time = 2.0
- self.set_projected_rate(self.max_rate + 1)
- self.set_retry_time(retry_time)
- request_token = RequestToken()
- try:
- self.leaky_bucket.consume(amt, request_token)
- self.fail('A RequestExceededException should have been thrown')
- except RequestExceededException as e:
- self.assertEqual(e.requested_amt, amt)
- self.assertEqual(e.retry_time, retry_time)
- self.assert_was_scheduled(amt, request_token)
-
- def test_consume_with_scheduled_retry(self):
- amt = 1
- self.set_projected_rate(self.max_rate + 1)
- self.scheduler.is_scheduled.return_value = True
- request_token = RequestToken()
- self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt)
- # Nothing new should have been scheduled but the request token
- # should have been processed.
- self.assert_nothing_scheduled()
- self.assert_processed_request_token(request_token)
-
-
-class TestConsumptionScheduler(unittest.TestCase):
- def setUp(self):
- self.scheduler = ConsumptionScheduler()
-
- def test_schedule_consumption(self):
- token = RequestToken()
- consume_time = 5
- actual_wait_time = self.scheduler.schedule_consumption(
- 1, token, consume_time
- )
- self.assertEqual(consume_time, actual_wait_time)
-
- def test_schedule_consumption_for_multiple_requests(self):
- token = RequestToken()
- consume_time = 5
- actual_wait_time = self.scheduler.schedule_consumption(
- 1, token, consume_time
- )
- self.assertEqual(consume_time, actual_wait_time)
-
- other_consume_time = 3
- other_token = RequestToken()
- next_wait_time = self.scheduler.schedule_consumption(
- 1, other_token, other_consume_time
- )
-
- # This wait time should be the previous request's wait time plus
- # its own desired wait time.
- self.assertEqual(next_wait_time, consume_time + other_consume_time)
-
- def test_is_scheduled(self):
- token = RequestToken()
- consume_time = 5
- self.scheduler.schedule_consumption(1, token, consume_time)
- self.assertTrue(self.scheduler.is_scheduled(token))
-
- def test_is_not_scheduled(self):
- self.assertFalse(self.scheduler.is_scheduled(RequestToken()))
-
- def test_process_scheduled_consumption(self):
- token = RequestToken()
- consume_time = 5
- self.scheduler.schedule_consumption(1, token, consume_time)
- self.scheduler.process_scheduled_consumption(token)
- self.assertFalse(self.scheduler.is_scheduled(token))
- different_time = 7
- # The previous consume time should have no effect on the next wait
- # time as it has been completed.
- self.assertEqual(
- self.scheduler.schedule_consumption(1, token, different_time),
- different_time,
- )
-
-
-class TestBandwidthRateTracker(unittest.TestCase):
- def setUp(self):
- self.alpha = 0.8
- self.rate_tracker = BandwidthRateTracker(self.alpha)
-
- def test_current_rate_at_initilizations(self):
- self.assertEqual(self.rate_tracker.current_rate, 0.0)
-
- def test_current_rate_after_one_recorded_point(self):
- self.rate_tracker.record_consumption_rate(1, 1)
- # There is no previous time point to diff against, so a current
- # rate of 0.0 is returned.
- self.assertEqual(self.rate_tracker.current_rate, 0.0)
-
- def test_current_rate(self):
- self.rate_tracker.record_consumption_rate(1, 1)
- self.rate_tracker.record_consumption_rate(1, 2)
- self.rate_tracker.record_consumption_rate(1, 3)
- self.assertEqual(self.rate_tracker.current_rate, 0.96)
-
- def test_get_projected_rate_at_initilizations(self):
- self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0)
-
- def test_get_projected_rate(self):
- self.rate_tracker.record_consumption_rate(1, 1)
- self.rate_tracker.record_consumption_rate(1, 2)
- projected_rate = self.rate_tracker.get_projected_rate(1, 3)
- self.assertEqual(projected_rate, 0.96)
- self.rate_tracker.record_consumption_rate(1, 3)
- self.assertEqual(self.rate_tracker.current_rate, projected_rate)
-
- def test_get_projected_rate_for_same_timestamp(self):
- self.rate_tracker.record_consumption_rate(1, 1)
- self.assertEqual(
- self.rate_tracker.get_projected_rate(1, 1), float('inf')
- )
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import tempfile
+
+from s3transfer.bandwidth import (
+ BandwidthLimitedStream,
+ BandwidthLimiter,
+ BandwidthRateTracker,
+ ConsumptionScheduler,
+ LeakyBucket,
+ RequestExceededException,
+ RequestToken,
+ TimeUtils,
+)
+from s3transfer.futures import TransferCoordinator
+from __tests__ import mock, unittest
+
+
+class FixedIncrementalTickTimeUtils(TimeUtils):
+ def __init__(self, seconds_per_tick=1.0):
+ self._count = 0
+ self._seconds_per_tick = seconds_per_tick
+
+ def time(self):
+ current_count = self._count
+ self._count += self._seconds_per_tick
+ return current_count
+
+
+class TestTimeUtils(unittest.TestCase):
+ @mock.patch('time.time')
+ def test_time(self, mock_time):
+ mock_return_val = 1
+ mock_time.return_value = mock_return_val
+ time_utils = TimeUtils()
+ self.assertEqual(time_utils.time(), mock_return_val)
+
+ @mock.patch('time.sleep')
+ def test_sleep(self, mock_sleep):
+ time_utils = TimeUtils()
+ time_utils.sleep(1)
+ self.assertEqual(mock_sleep.call_args_list, [mock.call(1)])
+
+
+class BaseBandwidthLimitTest(unittest.TestCase):
+ def setUp(self):
+ self.leaky_bucket = mock.Mock(LeakyBucket)
+ self.time_utils = mock.Mock(TimeUtils)
+ self.tempdir = tempfile.mkdtemp()
+ self.content = b'a' * 1024 * 1024
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+ self.coordinator = TransferCoordinator()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def assert_consume_calls(self, amts):
+ expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts]
+ self.assertEqual(
+ self.leaky_bucket.consume.call_args_list, expected_consume_args
+ )
+
+
+class TestBandwidthLimiter(BaseBandwidthLimitTest):
+ def setUp(self):
+ super().setUp()
+ self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket)
+
+ def test_get_bandwidth_limited_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.bandwidth_limiter.get_bandwith_limited_stream(
+ f, self.coordinator
+ )
+ self.assertIsInstance(stream, BandwidthLimitedStream)
+ self.assertEqual(stream.read(len(self.content)), self.content)
+ self.assert_consume_calls(amts=[len(self.content)])
+
+ def test_get_disabled_bandwidth_limited_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.bandwidth_limiter.get_bandwith_limited_stream(
+ f, self.coordinator, enabled=False
+ )
+ self.assertIsInstance(stream, BandwidthLimitedStream)
+ self.assertEqual(stream.read(len(self.content)), self.content)
+ self.leaky_bucket.consume.assert_not_called()
+
+
+class TestBandwidthLimitedStream(BaseBandwidthLimitTest):
+ def setUp(self):
+ super().setUp()
+ self.bytes_threshold = 1
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def get_bandwidth_limited_stream(self, f):
+ return BandwidthLimitedStream(
+ f,
+ self.leaky_bucket,
+ self.coordinator,
+ self.time_utils,
+ self.bytes_threshold,
+ )
+
+ def assert_sleep_calls(self, amts):
+ expected_sleep_args_list = [mock.call(amt) for amt in amts]
+ self.assertEqual(
+ self.time_utils.sleep.call_args_list, expected_sleep_args_list
+ )
+
+ def get_unique_consume_request_tokens(self):
+ return {
+ call_args[0][1]
+ for call_args in self.leaky_bucket.consume.call_args_list
+ }
+
+ def test_read(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assert_consume_calls(amts=[len(self.content)])
+ self.assert_sleep_calls(amts=[])
+
+ def test_retries_on_request_exceeded(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ retry_time = 1
+ amt_requested = len(self.content)
+ self.leaky_bucket.consume.side_effect = [
+ RequestExceededException(amt_requested, retry_time),
+ len(self.content),
+ ]
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assert_consume_calls(amts=[amt_requested, amt_requested])
+ self.assert_sleep_calls(amts=[retry_time])
+
+ def test_with_transfer_coordinator_exception(self):
+ self.coordinator.set_exception(ValueError())
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ with self.assertRaises(ValueError):
+ stream.read(len(self.content))
+
+ def test_read_when_bandwidth_limiting_disabled(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.disable_bandwidth_limiting()
+ data = stream.read(len(self.content))
+ self.assertEqual(self.content, data)
+ self.assertFalse(self.leaky_bucket.consume.called)
+
+ def test_read_toggle_disable_enable_bandwidth_limiting(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.disable_bandwidth_limiting()
+ data = stream.read(1)
+ self.assertEqual(self.content[:1], data)
+ self.assert_consume_calls(amts=[])
+ stream.enable_bandwidth_limiting()
+ data = stream.read(len(self.content) - 1)
+ self.assertEqual(self.content[1:], data)
+ self.assert_consume_calls(amts=[len(self.content) - 1])
+
+ def test_seek(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.seek(1)
+ self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)])
+
+ def test_tell(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.tell()
+ self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()])
+
+ def test_close(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ stream.close()
+ self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
+
+ def test_context_manager(self):
+ mock_fileobj = mock.Mock()
+ stream = self.get_bandwidth_limited_stream(mock_fileobj)
+ with stream as stream_handle:
+ self.assertIs(stream_handle, stream)
+ self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
+
+ def test_reuses_request_token(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ stream.read(1)
+ self.assertEqual(len(self.get_unique_consume_request_tokens()), 1)
+
+ def test_request_tokens_unique_per_stream(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.read(1)
+ self.assertEqual(len(self.get_unique_consume_request_tokens()), 2)
+
+ def test_call_consume_after_reaching_threshold(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ self.assertEqual(stream.read(1), self.content[1:2])
+ self.assert_consume_calls(amts=[2])
+
+ def test_resets_after_reaching_threshold(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(2), self.content[:2])
+ self.assert_consume_calls(amts=[2])
+ self.assertEqual(stream.read(1), self.content[2:3])
+ self.assert_consume_calls(amts=[2])
+
+ def test_pending_bytes_seen_on_close(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ stream.close()
+ self.assert_consume_calls(amts=[1])
+
+ def test_no_bytes_remaining_on(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(2), self.content[:2])
+ self.assert_consume_calls(amts=[2])
+ stream.close()
+ # There should have been no more consume() calls made
+ # as all bytes have been accounted for in the previous
+ # consume() call.
+ self.assert_consume_calls(amts=[2])
+
+ def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self):
+ self.bytes_threshold = 2
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ self.assertEqual(stream.read(1), self.content[:1])
+ self.assert_consume_calls(amts=[])
+ stream.disable_bandwidth_limiting()
+ stream.close()
+ self.assert_consume_calls(amts=[])
+
+ def test_signal_transferring(self):
+ with open(self.filename, 'rb') as f:
+ stream = self.get_bandwidth_limited_stream(f)
+ stream.signal_not_transferring()
+ data = stream.read(1)
+ self.assertEqual(self.content[:1], data)
+ self.assert_consume_calls(amts=[])
+ stream.signal_transferring()
+ data = stream.read(len(self.content) - 1)
+ self.assertEqual(self.content[1:], data)
+ self.assert_consume_calls(amts=[len(self.content) - 1])
+
+
+class TestLeakyBucket(unittest.TestCase):
+ def setUp(self):
+ self.max_rate = 1
+ self.time_now = 1.0
+ self.time_utils = mock.Mock(TimeUtils)
+ self.time_utils.time.return_value = self.time_now
+ self.scheduler = mock.Mock(ConsumptionScheduler)
+ self.scheduler.is_scheduled.return_value = False
+ self.rate_tracker = mock.Mock(BandwidthRateTracker)
+ self.leaky_bucket = LeakyBucket(
+ self.max_rate, self.time_utils, self.rate_tracker, self.scheduler
+ )
+
+ def set_projected_rate(self, rate):
+ self.rate_tracker.get_projected_rate.return_value = rate
+
+ def set_retry_time(self, retry_time):
+ self.scheduler.schedule_consumption.return_value = retry_time
+
+ def assert_recorded_consumed_amt(self, expected_amt):
+ self.assertEqual(
+ self.rate_tracker.record_consumption_rate.call_args,
+ mock.call(expected_amt, self.time_utils.time.return_value),
+ )
+
+ def assert_was_scheduled(self, amt, token):
+ self.assertEqual(
+ self.scheduler.schedule_consumption.call_args,
+ mock.call(amt, token, amt / (self.max_rate)),
+ )
+
+ def assert_nothing_scheduled(self):
+ self.assertFalse(self.scheduler.schedule_consumption.called)
+
+ def assert_processed_request_token(self, request_token):
+ self.assertEqual(
+ self.scheduler.process_scheduled_consumption.call_args,
+ mock.call(request_token),
+ )
+
+ def test_consume_under_max_rate(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate / 2)
+ self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
+ self.assert_recorded_consumed_amt(amt)
+ self.assert_nothing_scheduled()
+
+ def test_consume_at_max_rate(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate)
+ self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
+ self.assert_recorded_consumed_amt(amt)
+ self.assert_nothing_scheduled()
+
+ def test_consume_over_max_rate(self):
+ amt = 1
+ retry_time = 2.0
+ self.set_projected_rate(self.max_rate + 1)
+ self.set_retry_time(retry_time)
+ request_token = RequestToken()
+ try:
+ self.leaky_bucket.consume(amt, request_token)
+ self.fail('A RequestExceededException should have been thrown')
+ except RequestExceededException as e:
+ self.assertEqual(e.requested_amt, amt)
+ self.assertEqual(e.retry_time, retry_time)
+ self.assert_was_scheduled(amt, request_token)
+
+ def test_consume_with_scheduled_retry(self):
+ amt = 1
+ self.set_projected_rate(self.max_rate + 1)
+ self.scheduler.is_scheduled.return_value = True
+ request_token = RequestToken()
+ self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt)
+ # Nothing new should have been scheduled but the request token
+ # should have been processed.
+ self.assert_nothing_scheduled()
+ self.assert_processed_request_token(request_token)
+
+
+class TestConsumptionScheduler(unittest.TestCase):
+ def setUp(self):
+ self.scheduler = ConsumptionScheduler()
+
+ def test_schedule_consumption(self):
+ token = RequestToken()
+ consume_time = 5
+ actual_wait_time = self.scheduler.schedule_consumption(
+ 1, token, consume_time
+ )
+ self.assertEqual(consume_time, actual_wait_time)
+
+ def test_schedule_consumption_for_multiple_requests(self):
+ token = RequestToken()
+ consume_time = 5
+ actual_wait_time = self.scheduler.schedule_consumption(
+ 1, token, consume_time
+ )
+ self.assertEqual(consume_time, actual_wait_time)
+
+ other_consume_time = 3
+ other_token = RequestToken()
+ next_wait_time = self.scheduler.schedule_consumption(
+ 1, other_token, other_consume_time
+ )
+
+ # This wait time should be the previous request's wait time plus
+ # its own desired wait time.
+ self.assertEqual(next_wait_time, consume_time + other_consume_time)
+
+ def test_is_scheduled(self):
+ token = RequestToken()
+ consume_time = 5
+ self.scheduler.schedule_consumption(1, token, consume_time)
+ self.assertTrue(self.scheduler.is_scheduled(token))
+
+ def test_is_not_scheduled(self):
+ self.assertFalse(self.scheduler.is_scheduled(RequestToken()))
+
+ def test_process_scheduled_consumption(self):
+ token = RequestToken()
+ consume_time = 5
+ self.scheduler.schedule_consumption(1, token, consume_time)
+ self.scheduler.process_scheduled_consumption(token)
+ self.assertFalse(self.scheduler.is_scheduled(token))
+ different_time = 7
+ # The previous consume time should have no effect on the next wait
+ # time as it has been completed.
+ self.assertEqual(
+ self.scheduler.schedule_consumption(1, token, different_time),
+ different_time,
+ )
+
+
+class TestBandwidthRateTracker(unittest.TestCase):
+ def setUp(self):
+ self.alpha = 0.8
+ self.rate_tracker = BandwidthRateTracker(self.alpha)
+
+ def test_current_rate_at_initilizations(self):
+ self.assertEqual(self.rate_tracker.current_rate, 0.0)
+
+ def test_current_rate_after_one_recorded_point(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ # There is no previous time point to diff against, so a current
+ # rate of 0.0 is returned.
+ self.assertEqual(self.rate_tracker.current_rate, 0.0)
+
+ def test_current_rate(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.rate_tracker.record_consumption_rate(1, 2)
+ self.rate_tracker.record_consumption_rate(1, 3)
+ self.assertEqual(self.rate_tracker.current_rate, 0.96)
+
+ def test_get_projected_rate_at_initilizations(self):
+ self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0)
+
+ def test_get_projected_rate(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.rate_tracker.record_consumption_rate(1, 2)
+ projected_rate = self.rate_tracker.get_projected_rate(1, 3)
+ self.assertEqual(projected_rate, 0.96)
+ self.rate_tracker.record_consumption_rate(1, 3)
+ self.assertEqual(self.rate_tracker.current_rate, projected_rate)
+
+ def test_get_projected_rate_for_same_timestamp(self):
+ self.rate_tracker.record_consumption_rate(1, 1)
+ self.assertEqual(
+ self.rate_tracker.get_projected_rate(1, 1), float('inf')
+ )
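The rate-tracker expectations above (a rate of 0.96 after recording one byte per second at timestamps 1, 2 and 3 with alpha=0.8) are consistent with an exponential moving average over instantaneous rates. A worked sketch under that assumption follows; the real BandwidthRateTracker may differ in detail, and the helper below is illustrative only:

def projected_rate(points, alpha=0.8):
    """Exponentially smoothed rate over (amount, timestamp) samples."""
    rate = 0.0
    last_time = None
    for amt, timestamp in points:
        if last_time is not None:
            instantaneous = amt / (timestamp - last_time)
            rate = alpha * instantaneous + (1 - alpha) * rate
        last_time = timestamp
    return rate

# A single sample leaves the rate at 0.0, as
# test_current_rate_after_one_recorded_point expects; adding (1, 2) and
# (1, 3) gives 0.8 and then 0.96, matching test_current_rate.
assert abs(projected_rate([(1, 1), (1, 2), (1, 3)]) - 0.96) < 1e-9
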
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_compat.py b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
index 8cd9ff1b9d..78fdc25845 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_compat.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_compat.py
@@ -1,105 +1,105 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import os
-import shutil
-import signal
-import tempfile
-from io import BytesIO
-
-from s3transfer.compat import BaseManager, readable, seekable
-from __tests__ import skip_if_windows, unittest
-
-
-class ErrorRaisingSeekWrapper:
- """An object wrapper that throws an error when seeked on
-
- :param fileobj: The fileobj that it wraps
- :param exception: The exception to raise when seeked on.
- """
-
- def __init__(self, fileobj, exception):
- self._fileobj = fileobj
- self._exception = exception
-
- def seek(self, offset, whence=0):
- raise self._exception
-
- def tell(self):
- return self._fileobj.tell()
-
-
-class TestSeekable(unittest.TestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'foo')
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
-
- def test_seekable_fileobj(self):
- with open(self.filename, 'w') as f:
- self.assertTrue(seekable(f))
-
- def test_non_file_like_obj(self):
- # Fails because there is no seekable(), seek(), nor tell()
- self.assertFalse(seekable(object()))
-
- def test_non_seekable_ioerror(self):
- # Should return False if IOError is thrown.
- with open(self.filename, 'w') as f:
- self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, IOError())))
-
- def test_non_seekable_oserror(self):
- # Should return False if OSError is thrown.
- with open(self.filename, 'w') as f:
- self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, OSError())))
-
-
-class TestReadable(unittest.TestCase):
- def test_readable_fileobj(self):
- with tempfile.TemporaryFile() as f:
- self.assertTrue(readable(f))
-
- def test_readable_file_like_obj(self):
- self.assertTrue(readable(BytesIO()))
-
- def test_non_file_like_obj(self):
- self.assertFalse(readable(object()))
-
-
-class TestBaseManager(unittest.TestCase):
- def create_pid_manager(self):
- class PIDManager(BaseManager):
- pass
-
- PIDManager.register('getpid', os.getpid)
- return PIDManager()
-
- def get_pid(self, pid_manager):
- pid = pid_manager.getpid()
- # A proxy object is returned. The needed value can be acquired by
- # taking the repr and converting it to an integer.
- return int(str(pid))
-
- @skip_if_windows('os.kill() with SIGINT not supported on Windows')
- def test_can_provide_signal_handler_initializers_to_start(self):
- manager = self.create_pid_manager()
- manager.start(signal.signal, (signal.SIGINT, signal.SIG_IGN))
- pid = self.get_pid(manager)
- try:
- os.kill(pid, signal.SIGINT)
- except KeyboardInterrupt:
- pass
- # Try using the manager after the os.kill on the parent process. The
- # manager should not have died and should still be usable.
- self.assertEqual(pid, self.get_pid(manager))
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import signal
+import tempfile
+from io import BytesIO
+
+from s3transfer.compat import BaseManager, readable, seekable
+from __tests__ import skip_if_windows, unittest
+
+
+class ErrorRaisingSeekWrapper:
+ """An object wrapper that throws an error when seeked on
+
+ :param fileobj: The fileobj that it wraps
+ :param exception: The exception to raise when seeked on.
+ """
+
+ def __init__(self, fileobj, exception):
+ self._fileobj = fileobj
+ self._exception = exception
+
+ def seek(self, offset, whence=0):
+ raise self._exception
+
+ def tell(self):
+ return self._fileobj.tell()
+
+
+class TestSeekable(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_seekable_fileobj(self):
+ with open(self.filename, 'w') as f:
+ self.assertTrue(seekable(f))
+
+ def test_non_file_like_obj(self):
+ # Fails because there is no seekable(), seek(), nor tell()
+ self.assertFalse(seekable(object()))
+
+ def test_non_seekable_ioerror(self):
+ # Should return False if IOError is thrown.
+ with open(self.filename, 'w') as f:
+ self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, IOError())))
+
+ def test_non_seekable_oserror(self):
+ # Should return False if OSError is thrown.
+ with open(self.filename, 'w') as f:
+ self.assertFalse(seekable(ErrorRaisingSeekWrapper(f, OSError())))
+
+
+class TestReadable(unittest.TestCase):
+ def test_readable_fileobj(self):
+ with tempfile.TemporaryFile() as f:
+ self.assertTrue(readable(f))
+
+ def test_readable_file_like_obj(self):
+ self.assertTrue(readable(BytesIO()))
+
+ def test_non_file_like_obj(self):
+ self.assertFalse(readable(object()))
+
+
+class TestBaseManager(unittest.TestCase):
+ def create_pid_manager(self):
+ class PIDManager(BaseManager):
+ pass
+
+ PIDManager.register('getpid', os.getpid)
+ return PIDManager()
+
+ def get_pid(self, pid_manager):
+ pid = pid_manager.getpid()
+        # A proxy object is returned. The needed value can be acquired
+        # by taking its repr and converting that to an integer.
+ return int(str(pid))
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_can_provide_signal_handler_initializers_to_start(self):
+ manager = self.create_pid_manager()
+ manager.start(signal.signal, (signal.SIGINT, signal.SIG_IGN))
+ pid = self.get_pid(manager)
+ try:
+ os.kill(pid, signal.SIGINT)
+ except KeyboardInterrupt:
+ pass
+ # Try using the manager after the os.kill on the parent process. The
+ # manager should not have died and should still be usable.
+ self.assertEqual(pid, self.get_pid(manager))
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_copies.py b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
index 9d43284382..3681f69b94 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_copies.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_copies.py
@@ -1,177 +1,177 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-from s3transfer.copies import CopyObjectTask, CopyPartTask
-from __tests__ import BaseTaskTest, RecordingSubscriber
-
-
-class BaseCopyTaskTest(BaseTaskTest):
- def setUp(self):
- super().setUp()
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
- self.extra_args = {}
- self.callbacks = []
- self.size = 5
-
-
-class TestCopyObjectTask(BaseCopyTaskTest):
- def get_copy_task(self, **kwargs):
- default_kwargs = {
- 'client': self.client,
- 'copy_source': self.copy_source,
- 'bucket': self.bucket,
- 'key': self.key,
- 'extra_args': self.extra_args,
- 'callbacks': self.callbacks,
- 'size': self.size,
- }
- default_kwargs.update(kwargs)
- return self.get_task(CopyObjectTask, main_kwargs=default_kwargs)
-
- def test_main(self):
- self.stubber.add_response(
- 'copy_object',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- },
- )
- task = self.get_copy_task()
- task()
-
- self.stubber.assert_no_pending_responses()
-
- def test_extra_args(self):
- self.extra_args['ACL'] = 'private'
- self.stubber.add_response(
- 'copy_object',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- 'ACL': 'private',
- },
- )
- task = self.get_copy_task()
- task()
-
- self.stubber.assert_no_pending_responses()
-
- def test_callbacks_invoked(self):
- subscriber = RecordingSubscriber()
- self.callbacks.append(subscriber.on_progress)
- self.stubber.add_response(
- 'copy_object',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- },
- )
- task = self.get_copy_task()
- task()
-
- self.stubber.assert_no_pending_responses()
- self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
-
-
-class TestCopyPartTask(BaseCopyTaskTest):
- def setUp(self):
- super().setUp()
- self.copy_source_range = 'bytes=5-9'
- self.extra_args['CopySourceRange'] = self.copy_source_range
- self.upload_id = 'myuploadid'
- self.part_number = 1
- self.result_etag = 'my-etag'
-
- def get_copy_task(self, **kwargs):
- default_kwargs = {
- 'client': self.client,
- 'copy_source': self.copy_source,
- 'bucket': self.bucket,
- 'key': self.key,
- 'upload_id': self.upload_id,
- 'part_number': self.part_number,
- 'extra_args': self.extra_args,
- 'callbacks': self.callbacks,
- 'size': self.size,
- }
- default_kwargs.update(kwargs)
- return self.get_task(CopyPartTask, main_kwargs=default_kwargs)
-
- def test_main(self):
- self.stubber.add_response(
- 'upload_part_copy',
- service_response={'CopyPartResult': {'ETag': self.result_etag}},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- 'UploadId': self.upload_id,
- 'PartNumber': self.part_number,
- 'CopySourceRange': self.copy_source_range,
- },
- )
- task = self.get_copy_task()
- self.assertEqual(
- task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
- )
- self.stubber.assert_no_pending_responses()
-
- def test_extra_args(self):
- self.extra_args['RequestPayer'] = 'requester'
- self.stubber.add_response(
- 'upload_part_copy',
- service_response={'CopyPartResult': {'ETag': self.result_etag}},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- 'UploadId': self.upload_id,
- 'PartNumber': self.part_number,
- 'CopySourceRange': self.copy_source_range,
- 'RequestPayer': 'requester',
- },
- )
- task = self.get_copy_task()
- self.assertEqual(
- task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
- )
- self.stubber.assert_no_pending_responses()
-
- def test_callbacks_invoked(self):
- subscriber = RecordingSubscriber()
- self.callbacks.append(subscriber.on_progress)
- self.stubber.add_response(
- 'upload_part_copy',
- service_response={'CopyPartResult': {'ETag': self.result_etag}},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'CopySource': self.copy_source,
- 'UploadId': self.upload_id,
- 'PartNumber': self.part_number,
- 'CopySourceRange': self.copy_source_range,
- },
- )
- task = self.get_copy_task()
- self.assertEqual(
- task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
- )
- self.stubber.assert_no_pending_responses()
- self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.copies import CopyObjectTask, CopyPartTask
+from __tests__ import BaseTaskTest, RecordingSubscriber
+
+
+class BaseCopyTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.copy_source = {'Bucket': 'mysourcebucket', 'Key': 'mysourcekey'}
+ self.extra_args = {}
+ self.callbacks = []
+ self.size = 5
+
+
+class TestCopyObjectTask(BaseCopyTaskTest):
+ def get_copy_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'copy_source': self.copy_source,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'size': self.size,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(CopyObjectTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['ACL'] = 'private'
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'ACL': 'private',
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_callbacks_invoked(self):
+ subscriber = RecordingSubscriber()
+ self.callbacks.append(subscriber.on_progress)
+ self.stubber.add_response(
+ 'copy_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ },
+ )
+ task = self.get_copy_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
+
+
+class TestCopyPartTask(BaseCopyTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.copy_source_range = 'bytes=5-9'
+ self.extra_args['CopySourceRange'] = self.copy_source_range
+ self.upload_id = 'myuploadid'
+ self.part_number = 1
+ self.result_etag = 'my-etag'
+
+ def get_copy_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'copy_source': self.copy_source,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': self.upload_id,
+ 'part_number': self.part_number,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'size': self.size,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(CopyPartTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['RequestPayer'] = 'requester'
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ 'RequestPayer': 'requester',
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+
+ def test_callbacks_invoked(self):
+ subscriber = RecordingSubscriber()
+ self.callbacks.append(subscriber.on_progress)
+ self.stubber.add_response(
+ 'upload_part_copy',
+ service_response={'CopyPartResult': {'ETag': self.result_etag}},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'CopySource': self.copy_source,
+ 'UploadId': self.upload_id,
+ 'PartNumber': self.part_number,
+ 'CopySourceRange': self.copy_source_range,
+ },
+ )
+ task = self.get_copy_task()
+ self.assertEqual(
+ task(), {'PartNumber': self.part_number, 'ETag': self.result_etag}
+ )
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(subscriber.calculate_bytes_seen(), self.size)
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_crt.py b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
index 923b0f4c66..8c32668eab 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_crt.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_crt.py
@@ -1,173 +1,173 @@
-# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-from botocore.credentials import CredentialResolver, ReadOnlyCredentials
-from botocore.session import Session
-
-from s3transfer.exceptions import TransferNotDoneError
-from s3transfer.utils import CallArgs
-from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
-
-if HAS_CRT:
- import awscrt.s3
-
- import s3transfer.crt
-
-
-class CustomFutureException(Exception):
- pass
-
-
-@requires_crt
-class TestBotocoreCRTRequestSerializer(unittest.TestCase):
- def setUp(self):
- self.region = 'us-west-2'
- self.session = Session()
- self.session.set_config_variable('region', self.region)
- self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
- self.session
- )
- self.bucket = "test_bucket"
- self.key = "test_key"
- self.files = FileCreator()
- self.filename = self.files.create_file('myfile', 'my content')
- self.expected_path = "/" + self.bucket + "/" + self.key
- self.expected_host = "s3.%s.amazonaws.com" % (self.region)
-
- def tearDown(self):
- self.files.remove_all()
-
- def test_upload_request(self):
- callargs = CallArgs(
- bucket=self.bucket,
- key=self.key,
- fileobj=self.filename,
- extra_args={},
- subscribers=[],
- )
- coordinator = s3transfer.crt.CRTTransferCoordinator()
- future = s3transfer.crt.CRTTransferFuture(
- s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
- )
- crt_request = self.request_serializer.serialize_http_request(
- "put_object", future
- )
- self.assertEqual("PUT", crt_request.method)
- self.assertEqual(self.expected_path, crt_request.path)
- self.assertEqual(self.expected_host, crt_request.headers.get("host"))
- self.assertIsNone(crt_request.headers.get("Authorization"))
-
- def test_download_request(self):
- callargs = CallArgs(
- bucket=self.bucket,
- key=self.key,
- fileobj=self.filename,
- extra_args={},
- subscribers=[],
- )
- coordinator = s3transfer.crt.CRTTransferCoordinator()
- future = s3transfer.crt.CRTTransferFuture(
- s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
- )
- crt_request = self.request_serializer.serialize_http_request(
- "get_object", future
- )
- self.assertEqual("GET", crt_request.method)
- self.assertEqual(self.expected_path, crt_request.path)
- self.assertEqual(self.expected_host, crt_request.headers.get("host"))
- self.assertIsNone(crt_request.headers.get("Authorization"))
-
- def test_delete_request(self):
- callargs = CallArgs(
- bucket=self.bucket, key=self.key, extra_args={}, subscribers=[]
- )
- coordinator = s3transfer.crt.CRTTransferCoordinator()
- future = s3transfer.crt.CRTTransferFuture(
- s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
- )
- crt_request = self.request_serializer.serialize_http_request(
- "delete_object", future
- )
- self.assertEqual("DELETE", crt_request.method)
- self.assertEqual(self.expected_path, crt_request.path)
- self.assertEqual(self.expected_host, crt_request.headers.get("host"))
- self.assertIsNone(crt_request.headers.get("Authorization"))
-
-
-@requires_crt
-class TestCRTCredentialProviderAdapter(unittest.TestCase):
- def setUp(self):
- self.botocore_credential_provider = mock.Mock(CredentialResolver)
- self.access_key = "access_key"
- self.secret_key = "secret_key"
- self.token = "token"
- self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials(
- self.access_key, self.secret_key, self.token
- )
-
- def _call_adapter_and_check(self, credentails_provider_adapter):
- credentials = credentails_provider_adapter()
- self.assertEqual(credentials.access_key_id, self.access_key)
- self.assertEqual(credentials.secret_access_key, self.secret_key)
- self.assertEqual(credentials.session_token, self.token)
-
- def test_fetch_crt_credentials_successfully(self):
- credentails_provider_adapter = (
- s3transfer.crt.CRTCredentialProviderAdapter(
- self.botocore_credential_provider
- )
- )
- self._call_adapter_and_check(credentails_provider_adapter)
-
- def test_load_credentials_once(self):
- credentails_provider_adapter = (
- s3transfer.crt.CRTCredentialProviderAdapter(
- self.botocore_credential_provider
- )
- )
- called_times = 5
- for i in range(called_times):
- self._call_adapter_and_check(credentails_provider_adapter)
-        # Assert that load_credentials of the botocore credential provider
-        # will only be called once.
- self.assertEqual(
- self.botocore_credential_provider.load_credentials.call_count, 1
- )
-
-
-@requires_crt
-class TestCRTTransferFuture(unittest.TestCase):
- def setUp(self):
- self.mock_s3_request = mock.Mock(awscrt.s3.S3RequestType)
- self.mock_crt_future = mock.Mock(awscrt.s3.Future)
- self.mock_s3_request.finished_future = self.mock_crt_future
- self.coordinator = s3transfer.crt.CRTTransferCoordinator()
- self.coordinator.set_s3_request(self.mock_s3_request)
- self.future = s3transfer.crt.CRTTransferFuture(
- coordinator=self.coordinator
- )
-
- def test_set_exception(self):
- self.future.set_exception(CustomFutureException())
- with self.assertRaises(CustomFutureException):
- self.future.result()
-
- def test_set_exception_raises_error_when_not_done(self):
- self.mock_crt_future.done.return_value = False
- with self.assertRaises(TransferNotDoneError):
- self.future.set_exception(CustomFutureException())
-
- def test_set_exception_can_override_previous_exception(self):
- self.future.set_exception(Exception())
- self.future.set_exception(CustomFutureException())
- with self.assertRaises(CustomFutureException):
- self.future.result()
+# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from botocore.credentials import CredentialResolver, ReadOnlyCredentials
+from botocore.session import Session
+
+from s3transfer.exceptions import TransferNotDoneError
+from s3transfer.utils import CallArgs
+from __tests__ import HAS_CRT, FileCreator, mock, requires_crt, unittest
+
+if HAS_CRT:
+ import awscrt.s3
+
+ import s3transfer.crt
+
+
+class CustomFutureException(Exception):
+ pass
+
+
+@requires_crt
+class TestBotocoreCRTRequestSerializer(unittest.TestCase):
+ def setUp(self):
+ self.region = 'us-west-2'
+ self.session = Session()
+ self.session.set_config_variable('region', self.region)
+ self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
+ self.session
+ )
+ self.bucket = "test_bucket"
+ self.key = "test_key"
+ self.files = FileCreator()
+ self.filename = self.files.create_file('myfile', 'my content')
+ self.expected_path = "/" + self.bucket + "/" + self.key
+ self.expected_host = "s3.%s.amazonaws.com" % (self.region)
+
+ def tearDown(self):
+ self.files.remove_all()
+
+ def test_upload_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket,
+ key=self.key,
+ fileobj=self.filename,
+ extra_args={},
+ subscribers=[],
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "put_object", future
+ )
+ self.assertEqual("PUT", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+ def test_download_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket,
+ key=self.key,
+ fileobj=self.filename,
+ extra_args={},
+ subscribers=[],
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "get_object", future
+ )
+ self.assertEqual("GET", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+ def test_delete_request(self):
+ callargs = CallArgs(
+ bucket=self.bucket, key=self.key, extra_args={}, subscribers=[]
+ )
+ coordinator = s3transfer.crt.CRTTransferCoordinator()
+ future = s3transfer.crt.CRTTransferFuture(
+ s3transfer.crt.CRTTransferMeta(call_args=callargs), coordinator
+ )
+ crt_request = self.request_serializer.serialize_http_request(
+ "delete_object", future
+ )
+ self.assertEqual("DELETE", crt_request.method)
+ self.assertEqual(self.expected_path, crt_request.path)
+ self.assertEqual(self.expected_host, crt_request.headers.get("host"))
+ self.assertIsNone(crt_request.headers.get("Authorization"))
+
+
+@requires_crt
+class TestCRTCredentialProviderAdapter(unittest.TestCase):
+ def setUp(self):
+ self.botocore_credential_provider = mock.Mock(CredentialResolver)
+ self.access_key = "access_key"
+ self.secret_key = "secret_key"
+ self.token = "token"
+ self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials(
+ self.access_key, self.secret_key, self.token
+ )
+
+ def _call_adapter_and_check(self, credentails_provider_adapter):
+ credentials = credentails_provider_adapter()
+ self.assertEqual(credentials.access_key_id, self.access_key)
+ self.assertEqual(credentials.secret_access_key, self.secret_key)
+ self.assertEqual(credentials.session_token, self.token)
+
+ def test_fetch_crt_credentials_successfully(self):
+ credentails_provider_adapter = (
+ s3transfer.crt.CRTCredentialProviderAdapter(
+ self.botocore_credential_provider
+ )
+ )
+ self._call_adapter_and_check(credentails_provider_adapter)
+
+ def test_load_credentials_once(self):
+ credentails_provider_adapter = (
+ s3transfer.crt.CRTCredentialProviderAdapter(
+ self.botocore_credential_provider
+ )
+ )
+ called_times = 5
+ for i in range(called_times):
+ self._call_adapter_and_check(credentails_provider_adapter)
+        # Assert that load_credentials of the botocore credential provider
+        # will only be called once.
+ self.assertEqual(
+ self.botocore_credential_provider.load_credentials.call_count, 1
+ )
+
+
+@requires_crt
+class TestCRTTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.mock_s3_request = mock.Mock(awscrt.s3.S3RequestType)
+ self.mock_crt_future = mock.Mock(awscrt.s3.Future)
+ self.mock_s3_request.finished_future = self.mock_crt_future
+ self.coordinator = s3transfer.crt.CRTTransferCoordinator()
+ self.coordinator.set_s3_request(self.mock_s3_request)
+ self.future = s3transfer.crt.CRTTransferFuture(
+ coordinator=self.coordinator
+ )
+
+ def test_set_exception(self):
+ self.future.set_exception(CustomFutureException())
+ with self.assertRaises(CustomFutureException):
+ self.future.result()
+
+ def test_set_exception_raises_error_when_not_done(self):
+ self.mock_crt_future.done.return_value = False
+ with self.assertRaises(TransferNotDoneError):
+ self.future.set_exception(CustomFutureException())
+
+ def test_set_exception_can_override_previous_exception(self):
+ self.future.set_exception(Exception())
+ self.future.set_exception(CustomFutureException())
+ with self.assertRaises(CustomFutureException):
+ self.future.result()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_delete.py b/contrib/python/s3transfer/py3/tests/unit/test_delete.py
index 5f960adbc5..23b77112f2 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_delete.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_delete.py
@@ -1,67 +1,67 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-from s3transfer.delete import DeleteObjectTask
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.delete import DeleteObjectTask
from __tests__ import BaseTaskTest
-
-
-class TestDeleteObjectTask(BaseTaskTest):
- def setUp(self):
- super().setUp()
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.extra_args = {}
- self.callbacks = []
-
- def get_delete_task(self, **kwargs):
- default_kwargs = {
- 'client': self.client,
- 'bucket': self.bucket,
- 'key': self.key,
- 'extra_args': self.extra_args,
- }
- default_kwargs.update(kwargs)
- return self.get_task(DeleteObjectTask, main_kwargs=default_kwargs)
-
- def test_main(self):
- self.stubber.add_response(
- 'delete_object',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- },
- )
- task = self.get_delete_task()
- task()
-
- self.stubber.assert_no_pending_responses()
-
- def test_extra_args(self):
- self.extra_args['MFA'] = 'mfa-code'
- self.extra_args['VersionId'] = '12345'
- self.stubber.add_response(
- 'delete_object',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- # These extra_args should be injected into the
- # expected params for the delete_object call.
- 'MFA': 'mfa-code',
- 'VersionId': '12345',
- },
- )
- task = self.get_delete_task()
- task()
-
- self.stubber.assert_no_pending_responses()
+
+
+class TestDeleteObjectTask(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.callbacks = []
+
+ def get_delete_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ }
+ default_kwargs.update(kwargs)
+ return self.get_task(DeleteObjectTask, main_kwargs=default_kwargs)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'delete_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ },
+ )
+ task = self.get_delete_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+
+ def test_extra_args(self):
+ self.extra_args['MFA'] = 'mfa-code'
+ self.extra_args['VersionId'] = '12345'
+ self.stubber.add_response(
+ 'delete_object',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ # These extra_args should be injected into the
+ # expected params for the delete_object call.
+ 'MFA': 'mfa-code',
+ 'VersionId': '12345',
+ },
+ )
+ task = self.get_delete_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_download.py b/contrib/python/s3transfer/py3/tests/unit/test_download.py
index 46f346347c..2bd095f867 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_download.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_download.py
@@ -1,999 +1,999 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import copy
-import os
-import shutil
-import tempfile
-from io import BytesIO
-
-from s3transfer.bandwidth import BandwidthLimiter
-from s3transfer.compat import SOCKET_ERROR
-from s3transfer.download import (
- CompleteDownloadNOOPTask,
- DeferQueue,
- DownloadChunkIterator,
- DownloadFilenameOutputManager,
- DownloadNonSeekableOutputManager,
- DownloadSeekableOutputManager,
- DownloadSpecialFilenameOutputManager,
- DownloadSubmissionTask,
- GetObjectTask,
- ImmediatelyWriteIOGetObjectTask,
- IOCloseTask,
- IORenameFileTask,
- IOStreamingWriteTask,
- IOWriteTask,
-)
-from s3transfer.exceptions import RetriesExceededError
-from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor
-from s3transfer.utils import CallArgs, OSUtils
-from __tests__ import (
- BaseSubmissionTaskTest,
- BaseTaskTest,
- FileCreator,
- NonSeekableWriter,
- RecordingExecutor,
- StreamWithError,
- mock,
- unittest,
-)
-
-
-class DownloadException(Exception):
- pass
-
-
-class WriteCollector:
- """A utility to collect information about writes and seeks"""
-
- def __init__(self):
- self._pos = 0
- self.writes = []
-
- def seek(self, pos, whence=0):
- self._pos = pos
-
- def write(self, data):
- self.writes.append((self._pos, data))
- self._pos += len(data)
-
-
-class AlwaysIndicatesSpecialFileOSUtils(OSUtils):
- """OSUtil that always returns True for is_special_file"""
-
- def is_special_file(self, filename):
- return True
-
-
-class CancelledStreamWrapper:
- """A wrapper to trigger a cancellation while stream reading
-
-    Forces the transfer coordinator to cancel after a certain number of reads.
- :param stream: The underlying stream to read from
- :param transfer_coordinator: The coordinator for the transfer
- :param num_reads: On which read to signal a cancellation. 0 is the first
- read.
- """
-
- def __init__(self, stream, transfer_coordinator, num_reads=0):
- self._stream = stream
- self._transfer_coordinator = transfer_coordinator
- self._num_reads = num_reads
- self._count = 0
-
- def read(self, *args, **kwargs):
- if self._num_reads == self._count:
- self._transfer_coordinator.cancel()
- self._stream.read(*args, **kwargs)
- self._count += 1
-
-
-class BaseDownloadOutputManagerTest(BaseTaskTest):
- def setUp(self):
- super().setUp()
- self.osutil = OSUtils()
-
- # Create a file to write to
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'myfile')
-
- self.call_args = CallArgs(fileobj=self.filename)
- self.future = self.get_transfer_future(self.call_args)
- self.io_executor = BoundedExecutor(1000, 1)
-
- def tearDown(self):
- super().tearDown()
- shutil.rmtree(self.tempdir)
-
-
-class TestDownloadFilenameOutputManager(BaseDownloadOutputManagerTest):
- def setUp(self):
- super().setUp()
- self.download_output_manager = DownloadFilenameOutputManager(
- self.osutil,
- self.transfer_coordinator,
- io_executor=self.io_executor,
- )
-
- def test_is_compatible(self):
- self.assertTrue(
- self.download_output_manager.is_compatible(
- self.filename, self.osutil
- )
- )
-
- def test_get_download_task_tag(self):
- self.assertIsNone(self.download_output_manager.get_download_task_tag())
-
- def test_get_fileobj_for_io_writes(self):
- with self.download_output_manager.get_fileobj_for_io_writes(
- self.future
- ) as f:
-            # Ensure a file-like object is returned
- self.assertTrue(hasattr(f, 'read'))
- self.assertTrue(hasattr(f, 'seek'))
- # Make sure the name of the file returned is not the same as the
- # final filename as we should be writing to a temporary file.
- self.assertNotEqual(f.name, self.filename)
-
- def test_get_final_io_task(self):
- ref_contents = b'my_contents'
- with self.download_output_manager.get_fileobj_for_io_writes(
- self.future
- ) as f:
- temp_filename = f.name
- # Write some data to test that the data gets moved over to the
- # final location.
- f.write(ref_contents)
- final_task = self.download_output_manager.get_final_io_task()
- # Make sure it is the appropriate task.
- self.assertIsInstance(final_task, IORenameFileTask)
- final_task()
- # Make sure the temp_file gets removed
- self.assertFalse(os.path.exists(temp_filename))
-            # Make sure whatever was written to the temp file got moved to
-            # the final filename.
- with open(self.filename, 'rb') as f:
- self.assertEqual(f.read(), ref_contents)
-
- def test_can_queue_file_io_task(self):
- fileobj = WriteCollector()
- self.download_output_manager.queue_file_io_task(
- fileobj=fileobj, data='foo', offset=0
- )
- self.download_output_manager.queue_file_io_task(
- fileobj=fileobj, data='bar', offset=3
- )
- self.io_executor.shutdown()
- self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
-
- def test_get_file_io_write_task(self):
- fileobj = WriteCollector()
- io_write_task = self.download_output_manager.get_io_write_task(
- fileobj=fileobj, data='foo', offset=3
- )
- self.assertIsInstance(io_write_task, IOWriteTask)
-
- io_write_task()
- self.assertEqual(fileobj.writes, [(3, 'foo')])
-
-
-class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest):
- def setUp(self):
- super().setUp()
- self.osutil = AlwaysIndicatesSpecialFileOSUtils()
- self.download_output_manager = DownloadSpecialFilenameOutputManager(
- self.osutil,
- self.transfer_coordinator,
- io_executor=self.io_executor,
- )
-
- def test_is_compatible_for_special_file(self):
- self.assertTrue(
- self.download_output_manager.is_compatible(
- self.filename, AlwaysIndicatesSpecialFileOSUtils()
- )
- )
-
- def test_is_not_compatible_for_non_special_file(self):
- self.assertFalse(
- self.download_output_manager.is_compatible(
- self.filename, OSUtils()
- )
- )
-
- def test_get_fileobj_for_io_writes(self):
- with self.download_output_manager.get_fileobj_for_io_writes(
- self.future
- ) as f:
-            # Ensure a file-like object is returned
- self.assertTrue(hasattr(f, 'read'))
- # Make sure the name of the file returned is the same as the
- # final filename as we should not be writing to a temporary file.
- self.assertEqual(f.name, self.filename)
-
- def test_get_final_io_task(self):
- self.assertIsInstance(
- self.download_output_manager.get_final_io_task(), IOCloseTask
- )
-
- def test_can_queue_file_io_task(self):
- fileobj = WriteCollector()
- self.download_output_manager.queue_file_io_task(
- fileobj=fileobj, data='foo', offset=0
- )
- self.download_output_manager.queue_file_io_task(
- fileobj=fileobj, data='bar', offset=3
- )
- self.io_executor.shutdown()
- self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
-
-
-class TestDownloadSeekableOutputManager(BaseDownloadOutputManagerTest):
- def setUp(self):
- super().setUp()
- self.download_output_manager = DownloadSeekableOutputManager(
- self.osutil,
- self.transfer_coordinator,
- io_executor=self.io_executor,
- )
-
- # Create a fileobj to write to
- self.fileobj = open(self.filename, 'wb')
-
- self.call_args = CallArgs(fileobj=self.fileobj)
- self.future = self.get_transfer_future(self.call_args)
-
- def tearDown(self):
- self.fileobj.close()
- super().tearDown()
-
- def test_is_compatible(self):
- self.assertTrue(
- self.download_output_manager.is_compatible(
- self.fileobj, self.osutil
- )
- )
-
- def test_is_compatible_bytes_io(self):
- self.assertTrue(
- self.download_output_manager.is_compatible(BytesIO(), self.osutil)
- )
-
- def test_not_compatible_for_non_filelike_obj(self):
- self.assertFalse(
- self.download_output_manager.is_compatible(object(), self.osutil)
- )
-
- def test_get_download_task_tag(self):
- self.assertIsNone(self.download_output_manager.get_download_task_tag())
-
- def test_get_fileobj_for_io_writes(self):
- self.assertIs(
- self.download_output_manager.get_fileobj_for_io_writes(
- self.future
- ),
- self.fileobj,
- )
-
- def test_get_final_io_task(self):
- self.assertIsInstance(
- self.download_output_manager.get_final_io_task(),
- CompleteDownloadNOOPTask,
- )
-
- def test_can_queue_file_io_task(self):
- fileobj = WriteCollector()
- self.download_output_manager.queue_file_io_task(
- fileobj=fileobj, data='foo', offset=0
- )
- self.download_output_manager.queue_file_io_task(
- fileobj=fileobj, data='bar', offset=3
- )
- self.io_executor.shutdown()
- self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
-
- def test_get_file_io_write_task(self):
- fileobj = WriteCollector()
- io_write_task = self.download_output_manager.get_io_write_task(
- fileobj=fileobj, data='foo', offset=3
- )
- self.assertIsInstance(io_write_task, IOWriteTask)
-
- io_write_task()
- self.assertEqual(fileobj.writes, [(3, 'foo')])
-
-
-class TestDownloadNonSeekableOutputManager(BaseDownloadOutputManagerTest):
- def setUp(self):
- super().setUp()
- self.download_output_manager = DownloadNonSeekableOutputManager(
- self.osutil, self.transfer_coordinator, io_executor=None
- )
-
- def test_is_compatible_with_seekable_stream(self):
- with open(self.filename, 'wb') as f:
- self.assertTrue(
- self.download_output_manager.is_compatible(f, self.osutil)
- )
-
- def test_not_compatible_with_filename(self):
- self.assertFalse(
- self.download_output_manager.is_compatible(
- self.filename, self.osutil
- )
- )
-
- def test_compatible_with_non_seekable_stream(self):
- class NonSeekable:
- def write(self, data):
- pass
-
- f = NonSeekable()
- self.assertTrue(
- self.download_output_manager.is_compatible(f, self.osutil)
- )
-
- def test_is_compatible_with_bytesio(self):
- self.assertTrue(
- self.download_output_manager.is_compatible(BytesIO(), self.osutil)
- )
-
- def test_get_download_task_tag(self):
- self.assertIs(
- self.download_output_manager.get_download_task_tag(),
- IN_MEMORY_DOWNLOAD_TAG,
- )
-
- def test_submit_writes_from_internal_queue(self):
- class FakeQueue:
- def request_writes(self, offset, data):
- return [
- {'offset': 0, 'data': 'foo'},
- {'offset': 3, 'data': 'bar'},
- ]
-
- q = FakeQueue()
- io_executor = BoundedExecutor(1000, 1)
- manager = DownloadNonSeekableOutputManager(
- self.osutil,
- self.transfer_coordinator,
- io_executor=io_executor,
- defer_queue=q,
- )
- fileobj = WriteCollector()
- manager.queue_file_io_task(fileobj=fileobj, data='foo', offset=1)
- io_executor.shutdown()
- self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
-
- def test_get_file_io_write_task(self):
- fileobj = WriteCollector()
- io_write_task = self.download_output_manager.get_io_write_task(
- fileobj=fileobj, data='foo', offset=1
- )
- self.assertIsInstance(io_write_task, IOStreamingWriteTask)
-
- io_write_task()
- self.assertEqual(fileobj.writes, [(0, 'foo')])
-
-
-class TestDownloadSubmissionTask(BaseSubmissionTaskTest):
- def setUp(self):
- super().setUp()
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'myfile')
-
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.extra_args = {}
- self.subscribers = []
-
- # Create a stream to read from
- self.content = b'my content'
- self.stream = BytesIO(self.content)
-
- # A list to keep track of all of the bodies sent over the wire
- # and their order.
-
- self.call_args = self.get_call_args()
- self.transfer_future = self.get_transfer_future(self.call_args)
- self.io_executor = BoundedExecutor(1000, 1)
- self.submission_main_kwargs = {
- 'client': self.client,
- 'config': self.config,
- 'osutil': self.osutil,
- 'request_executor': self.executor,
- 'io_executor': self.io_executor,
- 'transfer_future': self.transfer_future,
- }
- self.submission_task = self.get_download_submission_task()
-
- def tearDown(self):
- super().tearDown()
- shutil.rmtree(self.tempdir)
-
- def get_call_args(self, **kwargs):
- default_call_args = {
- 'fileobj': self.filename,
- 'bucket': self.bucket,
- 'key': self.key,
- 'extra_args': self.extra_args,
- 'subscribers': self.subscribers,
- }
- default_call_args.update(kwargs)
- return CallArgs(**default_call_args)
-
- def wrap_executor_in_recorder(self):
- self.executor = RecordingExecutor(self.executor)
- self.submission_main_kwargs['request_executor'] = self.executor
-
- def use_fileobj_in_call_args(self, fileobj):
- self.call_args = self.get_call_args(fileobj=fileobj)
- self.transfer_future = self.get_transfer_future(self.call_args)
- self.submission_main_kwargs['transfer_future'] = self.transfer_future
-
- def assert_tag_for_get_object(self, tag_value):
- submissions_to_compare = self.executor.submissions
- if len(submissions_to_compare) > 1:
-            # If it was a ranged get, make sure we do not include the join task.
- submissions_to_compare = submissions_to_compare[:-1]
- for submission in submissions_to_compare:
- self.assertEqual(submission['tag'], tag_value)
-
- def add_head_object_response(self):
- self.stubber.add_response(
- 'head_object', {'ContentLength': len(self.content)}
- )
-
- def add_get_responses(self):
- chunksize = self.config.multipart_chunksize
- for i in range(0, len(self.content), chunksize):
- if i + chunksize > len(self.content):
- stream = BytesIO(self.content[i:])
- self.stubber.add_response('get_object', {'Body': stream})
- else:
- stream = BytesIO(self.content[i : i + chunksize])
- self.stubber.add_response('get_object', {'Body': stream})
-
- def configure_for_ranged_get(self):
- self.config.multipart_threshold = 1
- self.config.multipart_chunksize = 4
-
- def get_download_submission_task(self):
- return self.get_task(
- DownloadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
-
- def wait_and_assert_completed_successfully(self, submission_task):
- submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- def test_submits_no_tag_for_get_object_filename(self):
- self.wrap_executor_in_recorder()
- self.add_head_object_response()
- self.add_get_responses()
-
- self.submission_task = self.get_download_submission_task()
- self.wait_and_assert_completed_successfully(self.submission_task)
-
-        # Make sure no tag to limit that task specifically was associated
-        # with that task submission.
- self.assert_tag_for_get_object(None)
-
- def test_submits_no_tag_for_ranged_get_filename(self):
- self.wrap_executor_in_recorder()
- self.configure_for_ranged_get()
- self.add_head_object_response()
- self.add_get_responses()
-
- self.submission_task = self.get_download_submission_task()
- self.wait_and_assert_completed_successfully(self.submission_task)
-
-        # Make sure no tag to limit that task specifically was associated
-        # with that task submission.
- self.assert_tag_for_get_object(None)
-
- def test_submits_no_tag_for_get_object_fileobj(self):
- self.wrap_executor_in_recorder()
- self.add_head_object_response()
- self.add_get_responses()
-
- with open(self.filename, 'wb') as f:
- self.use_fileobj_in_call_args(f)
- self.submission_task = self.get_download_submission_task()
- self.wait_and_assert_completed_successfully(self.submission_task)
-
-        # Make sure no tag to limit that task specifically was associated
-        # with that task submission.
- self.assert_tag_for_get_object(None)
-
- def test_submits_no_tag_for_ranged_get_object_fileobj(self):
- self.wrap_executor_in_recorder()
- self.configure_for_ranged_get()
- self.add_head_object_response()
- self.add_get_responses()
-
- with open(self.filename, 'wb') as f:
- self.use_fileobj_in_call_args(f)
- self.submission_task = self.get_download_submission_task()
- self.wait_and_assert_completed_successfully(self.submission_task)
-
-        # Make sure no tag to limit that task specifically was associated
-        # with that task submission.
- self.assert_tag_for_get_object(None)
-
- def tests_submits_tag_for_get_object_nonseekable_fileobj(self):
- self.wrap_executor_in_recorder()
- self.add_head_object_response()
- self.add_get_responses()
-
- with open(self.filename, 'wb') as f:
- self.use_fileobj_in_call_args(NonSeekableWriter(f))
- self.submission_task = self.get_download_submission_task()
- self.wait_and_assert_completed_successfully(self.submission_task)
-
-        # Make sure the tag to limit in-memory downloads was associated
-        # with that task submission.
- self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
-
- def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self):
- self.wrap_executor_in_recorder()
- self.configure_for_ranged_get()
- self.add_head_object_response()
- self.add_get_responses()
-
- with open(self.filename, 'wb') as f:
- self.use_fileobj_in_call_args(NonSeekableWriter(f))
- self.submission_task = self.get_download_submission_task()
- self.wait_and_assert_completed_successfully(self.submission_task)
-
-        # Make sure the tag to limit in-memory downloads was associated
-        # with that task submission.
- self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
-
-
-class TestGetObjectTask(BaseTaskTest):
- def setUp(self):
- super().setUp()
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.extra_args = {}
- self.callbacks = []
- self.max_attempts = 5
- self.io_executor = BoundedExecutor(1000, 1)
- self.content = b'my content'
- self.stream = BytesIO(self.content)
- self.fileobj = WriteCollector()
- self.osutil = OSUtils()
- self.io_chunksize = 64 * (1024 ** 2)
- self.task_cls = GetObjectTask
- self.download_output_manager = DownloadSeekableOutputManager(
- self.osutil, self.transfer_coordinator, self.io_executor
- )
-
- def get_download_task(self, **kwargs):
- default_kwargs = {
- 'client': self.client,
- 'bucket': self.bucket,
- 'key': self.key,
- 'fileobj': self.fileobj,
- 'extra_args': self.extra_args,
- 'callbacks': self.callbacks,
- 'max_attempts': self.max_attempts,
- 'download_output_manager': self.download_output_manager,
- 'io_chunksize': self.io_chunksize,
- }
- default_kwargs.update(kwargs)
- self.transfer_coordinator.set_status_to_queued()
- return self.get_task(self.task_cls, main_kwargs=default_kwargs)
-
- def assert_io_writes(self, expected_writes):
- # Let the io executor process all of the writes before checking
- # what writes were sent to it.
- self.io_executor.shutdown()
- self.assertEqual(self.fileobj.writes, expected_writes)
-
- def test_main(self):
- self.stubber.add_response(
- 'get_object',
- service_response={'Body': self.stream},
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- task = self.get_download_task()
- task()
-
- self.stubber.assert_no_pending_responses()
- self.assert_io_writes([(0, self.content)])
-
- def test_extra_args(self):
- self.stubber.add_response(
- 'get_object',
- service_response={'Body': self.stream},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'Range': 'bytes=0-',
- },
- )
- self.extra_args['Range'] = 'bytes=0-'
- task = self.get_download_task()
- task()
-
- self.stubber.assert_no_pending_responses()
- self.assert_io_writes([(0, self.content)])
-
- def test_control_chunk_size(self):
- self.stubber.add_response(
- 'get_object',
- service_response={'Body': self.stream},
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- task = self.get_download_task(io_chunksize=1)
- task()
-
- self.stubber.assert_no_pending_responses()
- expected_contents = []
- for i in range(len(self.content)):
- expected_contents.append((i, bytes(self.content[i : i + 1])))
-
- self.assert_io_writes(expected_contents)
-
- def test_start_index(self):
- self.stubber.add_response(
- 'get_object',
- service_response={'Body': self.stream},
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- task = self.get_download_task(start_index=5)
- task()
-
- self.stubber.assert_no_pending_responses()
- self.assert_io_writes([(5, self.content)])
-
- def test_uses_bandwidth_limiter(self):
- bandwidth_limiter = mock.Mock(BandwidthLimiter)
-
- self.stubber.add_response(
- 'get_object',
- service_response={'Body': self.stream},
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- task = self.get_download_task(bandwidth_limiter=bandwidth_limiter)
- task()
-
- self.stubber.assert_no_pending_responses()
- self.assertEqual(
- bandwidth_limiter.get_bandwith_limited_stream.call_args_list,
- [mock.call(mock.ANY, self.transfer_coordinator)],
- )
-
- def test_retries_succeeds(self):
- self.stubber.add_response(
- 'get_object',
- service_response={
- 'Body': StreamWithError(self.stream, SOCKET_ERROR)
- },
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- self.stubber.add_response(
- 'get_object',
- service_response={'Body': self.stream},
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- task = self.get_download_task()
- task()
-
-        # The retryable error should not have affected the bytes placed into
-        # the io queue.
- self.stubber.assert_no_pending_responses()
- self.assert_io_writes([(0, self.content)])
-
- def test_retries_failure(self):
- for _ in range(self.max_attempts):
- self.stubber.add_response(
- 'get_object',
- service_response={
- 'Body': StreamWithError(self.stream, SOCKET_ERROR)
- },
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
-
- task = self.get_download_task()
- task()
- self.transfer_coordinator.announce_done()
-
- # Should have failed out on a RetriesExceededError
- with self.assertRaises(RetriesExceededError):
- self.transfer_coordinator.result()
- self.stubber.assert_no_pending_responses()
-
- def test_retries_in_middle_of_streaming(self):
- # After the first read a retryable error will be thrown
- self.stubber.add_response(
- 'get_object',
- service_response={
- 'Body': StreamWithError(
- copy.deepcopy(self.stream), SOCKET_ERROR, 1
- )
- },
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- self.stubber.add_response(
- 'get_object',
- service_response={'Body': self.stream},
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- task = self.get_download_task(io_chunksize=1)
- task()
-
- self.stubber.assert_no_pending_responses()
- expected_contents = []
- # This is the content initially read in before the retry hit on the
- # second read()
- expected_contents.append((0, bytes(self.content[0:1])))
-
- # The rest of the content should be the entire set of data partitioned
- # out based on the one byte stream chunk size. Note the second
- # element in the list should be a copy of the first element since
- # a retryable exception happened in between.
- for i in range(len(self.content)):
- expected_contents.append((i, bytes(self.content[i : i + 1])))
- self.assert_io_writes(expected_contents)
-
- def test_cancels_out_of_queueing(self):
- self.stubber.add_response(
- 'get_object',
- service_response={
- 'Body': CancelledStreamWrapper(
- self.stream, self.transfer_coordinator
- )
- },
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- task = self.get_download_task()
- task()
-
- self.stubber.assert_no_pending_responses()
- # Make sure that no contents were added to the queue because the task
- # should have been canceled before trying to add the contents to the
- # io queue.
- self.assert_io_writes([])
-
- def test_handles_callback_on_initial_error(self):
-        # We can't use the stubber for this because we need to raise one of
-        # the S3_RETRYABLE_DOWNLOAD_ERRORS, and the stubber only allows
-        # you to raise a ClientError.
- self.client.get_object = mock.Mock(side_effect=SOCKET_ERROR())
- task = self.get_download_task()
- task()
- self.transfer_coordinator.announce_done()
- # Should have failed out on a RetriesExceededError because
- # get_object keeps raising a socket error.
- with self.assertRaises(RetriesExceededError):
- self.transfer_coordinator.result()
-
-
-class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask):
- def setUp(self):
- super().setUp()
- self.task_cls = ImmediatelyWriteIOGetObjectTask
-        # When data is written out, it should not use the io executor at all.
-        # If it does use the io executor, that is a deviation from expected
-        # behavior, as the data should be written immediately to the file
-        # object once downloaded.
- self.io_executor = None
- self.download_output_manager = DownloadSeekableOutputManager(
- self.osutil, self.transfer_coordinator, self.io_executor
- )
-
- def assert_io_writes(self, expected_writes):
- self.assertEqual(self.fileobj.writes, expected_writes)
-
-
-class BaseIOTaskTest(BaseTaskTest):
- def setUp(self):
- super().setUp()
- self.files = FileCreator()
- self.osutil = OSUtils()
- self.temp_filename = os.path.join(self.files.rootdir, 'mytempfile')
- self.final_filename = os.path.join(self.files.rootdir, 'myfile')
-
- def tearDown(self):
- super().tearDown()
- self.files.remove_all()
-
-
-class TestIOStreamingWriteTask(BaseIOTaskTest):
- def test_main(self):
- with open(self.temp_filename, 'wb') as f:
- task = self.get_task(
- IOStreamingWriteTask,
- main_kwargs={'fileobj': f, 'data': b'foobar'},
- )
- task()
- task2 = self.get_task(
- IOStreamingWriteTask,
- main_kwargs={'fileobj': f, 'data': b'baz'},
- )
- task2()
- with open(self.temp_filename, 'rb') as f:
- # We should just have written to the file in the order
- # the tasks were executed.
- self.assertEqual(f.read(), b'foobarbaz')
-
-
-class TestIOWriteTask(BaseIOTaskTest):
- def test_main(self):
- with open(self.temp_filename, 'wb') as f:
- # Write once to the file
- task = self.get_task(
- IOWriteTask,
- main_kwargs={'fileobj': f, 'data': b'foo', 'offset': 0},
- )
- task()
-
- # Write again to the file
- task = self.get_task(
- IOWriteTask,
- main_kwargs={'fileobj': f, 'data': b'bar', 'offset': 3},
- )
- task()
-
- with open(self.temp_filename, 'rb') as f:
- self.assertEqual(f.read(), b'foobar')
-
-
-class TestIORenameFileTask(BaseIOTaskTest):
- def test_main(self):
- with open(self.temp_filename, 'wb') as f:
- task = self.get_task(
- IORenameFileTask,
- main_kwargs={
- 'fileobj': f,
- 'final_filename': self.final_filename,
- 'osutil': self.osutil,
- },
- )
- task()
- self.assertTrue(os.path.exists(self.final_filename))
- self.assertFalse(os.path.exists(self.temp_filename))
-
-
-class TestIOCloseTask(BaseIOTaskTest):
- def test_main(self):
- with open(self.temp_filename, 'w') as f:
- task = self.get_task(IOCloseTask, main_kwargs={'fileobj': f})
- task()
- self.assertTrue(f.closed)
-
-
-class TestDownloadChunkIterator(unittest.TestCase):
- def test_iter(self):
- content = b'my content'
- body = BytesIO(content)
- ref_chunks = []
- for chunk in DownloadChunkIterator(body, len(content)):
- ref_chunks.append(chunk)
- self.assertEqual(ref_chunks, [b'my content'])
-
- def test_iter_chunksize(self):
- content = b'1234'
- body = BytesIO(content)
- ref_chunks = []
- for chunk in DownloadChunkIterator(body, 3):
- ref_chunks.append(chunk)
- self.assertEqual(ref_chunks, [b'123', b'4'])
-
- def test_empty_content(self):
- body = BytesIO(b'')
- ref_chunks = []
- for chunk in DownloadChunkIterator(body, 3):
- ref_chunks.append(chunk)
- self.assertEqual(ref_chunks, [b''])
-
-
-class TestDeferQueue(unittest.TestCase):
- def setUp(self):
- self.q = DeferQueue()
-
- def test_no_writes_when_not_lowest_block(self):
- writes = self.q.request_writes(offset=1, data='bar')
- self.assertEqual(writes, [])
-
- def test_writes_returned_in_order(self):
- self.assertEqual(self.q.request_writes(offset=3, data='d'), [])
- self.assertEqual(self.q.request_writes(offset=2, data='c'), [])
- self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
-
- # Everything at this point has been deferred, but as soon as we
- # send offset=0, that will unlock offsets 0-3.
- writes = self.q.request_writes(offset=0, data='a')
- self.assertEqual(
- writes,
- [
- {'offset': 0, 'data': 'a'},
- {'offset': 1, 'data': 'b'},
- {'offset': 2, 'data': 'c'},
- {'offset': 3, 'data': 'd'},
- ],
- )
-
- def test_unlocks_partial_range(self):
- self.assertEqual(self.q.request_writes(offset=5, data='f'), [])
- self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
-
- # offset=0 unlocks 0-1, but offset=5 still needs to see 2-4 first.
- writes = self.q.request_writes(offset=0, data='a')
- self.assertEqual(
- writes,
- [
- {'offset': 0, 'data': 'a'},
- {'offset': 1, 'data': 'b'},
- ],
- )
-
- def test_data_can_be_any_size(self):
- self.q.request_writes(offset=5, data='hello world')
- writes = self.q.request_writes(offset=0, data='abcde')
- self.assertEqual(
- writes,
- [
- {'offset': 0, 'data': 'abcde'},
- {'offset': 5, 'data': 'hello world'},
- ],
- )
-
- def test_data_queued_in_order(self):
- # This immediately gets returned because offset=0 is the
- # next range we're waiting on.
- writes = self.q.request_writes(offset=0, data='hello world')
- self.assertEqual(writes, [{'offset': 0, 'data': 'hello world'}])
- # Same thing here but with offset
- writes = self.q.request_writes(offset=11, data='hello again')
- self.assertEqual(writes, [{'offset': 11, 'data': 'hello again'}])
-
- def test_writes_below_min_offset_are_ignored(self):
- self.q.request_writes(offset=0, data='a')
- self.q.request_writes(offset=1, data='b')
- self.q.request_writes(offset=2, data='c')
-
- # At this point we're expecting offset=3, so if a write
- # comes in below 3, we ignore it.
- self.assertEqual(self.q.request_writes(offset=0, data='a'), [])
- self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
-
- self.assertEqual(
- self.q.request_writes(offset=3, data='d'),
- [{'offset': 3, 'data': 'd'}],
- )
-
- def test_duplicate_writes_are_ignored(self):
- self.q.request_writes(offset=2, data='c')
- self.q.request_writes(offset=1, data='b')
-
- # We're still waiting for offset=0, but if
- # a duplicate write comes in for offset=2/offset=1
- # it's ignored. This gives "first one wins" behavior.
- self.assertEqual(self.q.request_writes(offset=2, data='X'), [])
- self.assertEqual(self.q.request_writes(offset=1, data='Y'), [])
-
- self.assertEqual(
- self.q.request_writes(offset=0, data='a'),
- [
- {'offset': 0, 'data': 'a'},
-            # Note we're seeing 'b', 'c', and not 'X', 'Y'.
- {'offset': 1, 'data': 'b'},
- {'offset': 2, 'data': 'c'},
- ],
- )
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import copy
+import os
+import shutil
+import tempfile
+from io import BytesIO
+
+from s3transfer.bandwidth import BandwidthLimiter
+from s3transfer.compat import SOCKET_ERROR
+from s3transfer.download import (
+ CompleteDownloadNOOPTask,
+ DeferQueue,
+ DownloadChunkIterator,
+ DownloadFilenameOutputManager,
+ DownloadNonSeekableOutputManager,
+ DownloadSeekableOutputManager,
+ DownloadSpecialFilenameOutputManager,
+ DownloadSubmissionTask,
+ GetObjectTask,
+ ImmediatelyWriteIOGetObjectTask,
+ IOCloseTask,
+ IORenameFileTask,
+ IOStreamingWriteTask,
+ IOWriteTask,
+)
+from s3transfer.exceptions import RetriesExceededError
+from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG, BoundedExecutor
+from s3transfer.utils import CallArgs, OSUtils
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileCreator,
+ NonSeekableWriter,
+ RecordingExecutor,
+ StreamWithError,
+ mock,
+ unittest,
+)
+
+
+class DownloadException(Exception):
+ pass
+
+
+class WriteCollector:
+ """A utility to collect information about writes and seeks"""
+
+ def __init__(self):
+ self._pos = 0
+ self.writes = []
+
+ def seek(self, pos, whence=0):
+ self._pos = pos
+
+ def write(self, data):
+ self.writes.append((self._pos, data))
+ self._pos += len(data)
+
+
+class AlwaysIndicatesSpecialFileOSUtils(OSUtils):
+ """OSUtil that always returns True for is_special_file"""
+
+ def is_special_file(self, filename):
+ return True
+
+
+class CancelledStreamWrapper:
+ """A wrapper to trigger a cancellation while stream reading
+
+ Forces the transfer coordinator to cancel after a certain amount of reads
+ :param stream: The underlying stream to read from
+ :param transfer_coordinator: The coordinator for the transfer
+ :param num_reads: On which read to signal a cancellation. 0 is the first
+ read.
+ """
+
+ def __init__(self, stream, transfer_coordinator, num_reads=0):
+ self._stream = stream
+ self._transfer_coordinator = transfer_coordinator
+ self._num_reads = num_reads
+ self._count = 0
+
+ def read(self, *args, **kwargs):
+ if self._num_reads == self._count:
+ self._transfer_coordinator.cancel()
+ self._stream.read(*args, **kwargs)
+ self._count += 1
+
+
+class BaseDownloadOutputManagerTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = OSUtils()
+
+ # Create a file to write to
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ self.call_args = CallArgs(fileobj=self.filename)
+ self.future = self.get_transfer_future(self.call_args)
+ self.io_executor = BoundedExecutor(1000, 1)
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+
+class TestDownloadFilenameOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadFilenameOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.filename, self.osutil
+ )
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIsNone(self.download_output_manager.get_download_task_tag())
+
+ def test_get_fileobj_for_io_writes(self):
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+            # Ensure a file-like object is returned
+ self.assertTrue(hasattr(f, 'read'))
+ self.assertTrue(hasattr(f, 'seek'))
+ # Make sure the name of the file returned is not the same as the
+ # final filename as we should be writing to a temporary file.
+ self.assertNotEqual(f.name, self.filename)
+
+ def test_get_final_io_task(self):
+ ref_contents = b'my_contents'
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+ temp_filename = f.name
+ # Write some data to test that the data gets moved over to the
+ # final location.
+ f.write(ref_contents)
+ final_task = self.download_output_manager.get_final_io_task()
+ # Make sure it is the appropriate task.
+ self.assertIsInstance(final_task, IORenameFileTask)
+ final_task()
+ # Make sure the temp_file gets removed
+ self.assertFalse(os.path.exists(temp_filename))
+        # Make sure whatever was written to the temp file got moved to
+ # the final filename
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(f.read(), ref_contents)
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=3
+ )
+ self.assertIsInstance(io_write_task, IOWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(3, 'foo')])
+
+
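The temp-file handling asserted above follows a familiar pattern: every write lands in a temporary file, and only a final rename step, the analogue of IORenameFileTask, publishes the result under the real filename. A minimal standalone sketch of that pattern (not the s3transfer implementation; download_to_filename and its (offset, data) chunk format are made up for illustration):

    import os
    import tempfile

    def download_to_filename(chunks, final_filename):
        # Stage every (offset, data) chunk in a temporary file first.
        fd, temp_filename = tempfile.mkstemp(
            dir=os.path.dirname(final_filename) or '.'
        )
        with os.fdopen(fd, 'wb') as f:
            for offset, data in chunks:
                f.seek(offset)
                f.write(data)
        # Only once every write has succeeded is the file moved into place,
        # mirroring what the final IO task does for the output manager.
        os.rename(temp_filename, final_filename)

    download_to_filename([(0, b'foo'), (3, b'bar')], 'myfile')
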
+class TestDownloadSpecialFilenameOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = AlwaysIndicatesSpecialFileOSUtils()
+ self.download_output_manager = DownloadSpecialFilenameOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ def test_is_compatible_for_special_file(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.filename, AlwaysIndicatesSpecialFileOSUtils()
+ )
+ )
+
+ def test_is_not_compatible_for_non_special_file(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(
+ self.filename, OSUtils()
+ )
+ )
+
+ def test_get_fileobj_for_io_writes(self):
+ with self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ) as f:
+            # Ensure a file-like object is returned
+ self.assertTrue(hasattr(f, 'read'))
+ # Make sure the name of the file returned is the same as the
+ # final filename as we should not be writing to a temporary file.
+ self.assertEqual(f.name, self.filename)
+
+ def test_get_final_io_task(self):
+ self.assertIsInstance(
+ self.download_output_manager.get_final_io_task(), IOCloseTask
+ )
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+
+class TestDownloadSeekableOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=self.io_executor,
+ )
+
+ # Create a fileobj to write to
+ self.fileobj = open(self.filename, 'wb')
+
+ self.call_args = CallArgs(fileobj=self.fileobj)
+ self.future = self.get_transfer_future(self.call_args)
+
+ def tearDown(self):
+ self.fileobj.close()
+ super().tearDown()
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(
+ self.fileobj, self.osutil
+ )
+ )
+
+ def test_is_compatible_bytes_io(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(BytesIO(), self.osutil)
+ )
+
+ def test_not_compatible_for_non_filelike_obj(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(object(), self.osutil)
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIsNone(self.download_output_manager.get_download_task_tag())
+
+ def test_get_fileobj_for_io_writes(self):
+ self.assertIs(
+ self.download_output_manager.get_fileobj_for_io_writes(
+ self.future
+ ),
+ self.fileobj,
+ )
+
+ def test_get_final_io_task(self):
+ self.assertIsInstance(
+ self.download_output_manager.get_final_io_task(),
+ CompleteDownloadNOOPTask,
+ )
+
+ def test_can_queue_file_io_task(self):
+ fileobj = WriteCollector()
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='foo', offset=0
+ )
+ self.download_output_manager.queue_file_io_task(
+ fileobj=fileobj, data='bar', offset=3
+ )
+ self.io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=3
+ )
+ self.assertIsInstance(io_write_task, IOWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(3, 'foo')])
+
+
+class TestDownloadNonSeekableOutputManager(BaseDownloadOutputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.download_output_manager = DownloadNonSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, io_executor=None
+ )
+
+ def test_is_compatible_with_seekable_stream(self):
+ with open(self.filename, 'wb') as f:
+ self.assertTrue(
+ self.download_output_manager.is_compatible(f, self.osutil)
+ )
+
+ def test_not_compatible_with_filename(self):
+ self.assertFalse(
+ self.download_output_manager.is_compatible(
+ self.filename, self.osutil
+ )
+ )
+
+ def test_compatible_with_non_seekable_stream(self):
+ class NonSeekable:
+ def write(self, data):
+ pass
+
+ f = NonSeekable()
+ self.assertTrue(
+ self.download_output_manager.is_compatible(f, self.osutil)
+ )
+
+ def test_is_compatible_with_bytesio(self):
+ self.assertTrue(
+ self.download_output_manager.is_compatible(BytesIO(), self.osutil)
+ )
+
+ def test_get_download_task_tag(self):
+ self.assertIs(
+ self.download_output_manager.get_download_task_tag(),
+ IN_MEMORY_DOWNLOAD_TAG,
+ )
+
+ def test_submit_writes_from_internal_queue(self):
+ class FakeQueue:
+ def request_writes(self, offset, data):
+ return [
+ {'offset': 0, 'data': 'foo'},
+ {'offset': 3, 'data': 'bar'},
+ ]
+
+ q = FakeQueue()
+ io_executor = BoundedExecutor(1000, 1)
+ manager = DownloadNonSeekableOutputManager(
+ self.osutil,
+ self.transfer_coordinator,
+ io_executor=io_executor,
+ defer_queue=q,
+ )
+ fileobj = WriteCollector()
+ manager.queue_file_io_task(fileobj=fileobj, data='foo', offset=1)
+ io_executor.shutdown()
+ self.assertEqual(fileobj.writes, [(0, 'foo'), (3, 'bar')])
+
+ def test_get_file_io_write_task(self):
+ fileobj = WriteCollector()
+ io_write_task = self.download_output_manager.get_io_write_task(
+ fileobj=fileobj, data='foo', offset=1
+ )
+ self.assertIsInstance(io_write_task, IOStreamingWriteTask)
+
+ io_write_task()
+ self.assertEqual(fileobj.writes, [(0, 'foo')])
+
+
+class TestDownloadSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # Create a stream to read from
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+
+ self.call_args = self.get_call_args()
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.io_executor = BoundedExecutor(1000, 1)
+ self.submission_main_kwargs = {
+ 'client': self.client,
+ 'config': self.config,
+ 'osutil': self.osutil,
+ 'request_executor': self.executor,
+ 'io_executor': self.io_executor,
+ 'transfer_future': self.transfer_future,
+ }
+ self.submission_task = self.get_download_submission_task()
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def get_call_args(self, **kwargs):
+ default_call_args = {
+ 'fileobj': self.filename,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'subscribers': self.subscribers,
+ }
+ default_call_args.update(kwargs)
+ return CallArgs(**default_call_args)
+
+ def wrap_executor_in_recorder(self):
+ self.executor = RecordingExecutor(self.executor)
+ self.submission_main_kwargs['request_executor'] = self.executor
+
+ def use_fileobj_in_call_args(self, fileobj):
+ self.call_args = self.get_call_args(fileobj=fileobj)
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs['transfer_future'] = self.transfer_future
+
+ def assert_tag_for_get_object(self, tag_value):
+ submissions_to_compare = self.executor.submissions
+ if len(submissions_to_compare) > 1:
+ # If it was ranged get, make sure we do not include the join task.
+ submissions_to_compare = submissions_to_compare[:-1]
+ for submission in submissions_to_compare:
+ self.assertEqual(submission['tag'], tag_value)
+
+ def add_head_object_response(self):
+ self.stubber.add_response(
+ 'head_object', {'ContentLength': len(self.content)}
+ )
+
+ def add_get_responses(self):
+ chunksize = self.config.multipart_chunksize
+ for i in range(0, len(self.content), chunksize):
+ if i + chunksize > len(self.content):
+ stream = BytesIO(self.content[i:])
+ self.stubber.add_response('get_object', {'Body': stream})
+ else:
+ stream = BytesIO(self.content[i : i + chunksize])
+ self.stubber.add_response('get_object', {'Body': stream})
+
+ def configure_for_ranged_get(self):
+ self.config.multipart_threshold = 1
+ self.config.multipart_chunksize = 4
+
+ def get_download_submission_task(self):
+ return self.get_task(
+ DownloadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+
+ def wait_and_assert_completed_successfully(self, submission_task):
+ submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_submits_no_tag_for_get_object_filename(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure no tag to limit that task specifically was associated
+        # with that task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_ranged_get_filename(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure no tag to limit that task specifically was associated
+        # with that task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_get_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure no tag to limit that task specifically was associated
+        # with that task submission.
+ self.assert_tag_for_get_object(None)
+
+ def test_submits_no_tag_for_ranged_get_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure no tag to limit that task specifically was associated
+        # with that task submission.
+ self.assert_tag_for_get_object(None)
+
+ def tests_submits_tag_for_get_object_nonseekable_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(NonSeekableWriter(f))
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure the IN_MEMORY_DOWNLOAD_TAG was associated with that
+        # task submission since the fileobj is non-seekable.
+ self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
+ def tests_submits_tag_for_ranged_get_object_nonseekable_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.configure_for_ranged_get()
+ self.add_head_object_response()
+ self.add_get_responses()
+
+ with open(self.filename, 'wb') as f:
+ self.use_fileobj_in_call_args(NonSeekableWriter(f))
+ self.submission_task = self.get_download_submission_task()
+ self.wait_and_assert_completed_successfully(self.submission_task)
+
+        # Make sure the IN_MEMORY_DOWNLOAD_TAG was associated with that
+        # task submission since the fileobj is non-seekable.
+ self.assert_tag_for_get_object(IN_MEMORY_DOWNLOAD_TAG)
+
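The ranged-get helpers above (configure_for_ranged_get and add_get_responses) reflect how a download larger than multipart_threshold is issued as one GET per multipart_chunksize bytes. A rough standalone illustration of that splitting, not taken from s3transfer (split_into_ranges is a made-up helper):

    def split_into_ranges(content_length, chunksize):
        # One Range header value per chunk-sized slice of the object.
        ranges = []
        for start in range(0, content_length, chunksize):
            end = min(start + chunksize, content_length) - 1
            ranges.append(f'bytes={start}-{end}')
        return ranges

    # With the tests' 10-byte b'my content' and a chunksize of 4:
    print(split_into_ranges(10, 4))  # ['bytes=0-3', 'bytes=4-7', 'bytes=8-9']
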
+
+class TestGetObjectTask(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.callbacks = []
+ self.max_attempts = 5
+ self.io_executor = BoundedExecutor(1000, 1)
+ self.content = b'my content'
+ self.stream = BytesIO(self.content)
+ self.fileobj = WriteCollector()
+ self.osutil = OSUtils()
+ self.io_chunksize = 64 * (1024 ** 2)
+ self.task_cls = GetObjectTask
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, self.io_executor
+ )
+
+ def get_download_task(self, **kwargs):
+ default_kwargs = {
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'fileobj': self.fileobj,
+ 'extra_args': self.extra_args,
+ 'callbacks': self.callbacks,
+ 'max_attempts': self.max_attempts,
+ 'download_output_manager': self.download_output_manager,
+ 'io_chunksize': self.io_chunksize,
+ }
+ default_kwargs.update(kwargs)
+ self.transfer_coordinator.set_status_to_queued()
+ return self.get_task(self.task_cls, main_kwargs=default_kwargs)
+
+ def assert_io_writes(self, expected_writes):
+ # Let the io executor process all of the writes before checking
+ # what writes were sent to it.
+ self.io_executor.shutdown()
+ self.assertEqual(self.fileobj.writes, expected_writes)
+
+ def test_main(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_extra_args(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Range': 'bytes=0-',
+ },
+ )
+ self.extra_args['Range'] = 'bytes=0-'
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_control_chunk_size(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(io_chunksize=1)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ expected_contents = []
+ for i in range(len(self.content)):
+ expected_contents.append((i, bytes(self.content[i : i + 1])))
+
+ self.assert_io_writes(expected_contents)
+
+ def test_start_index(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(start_index=5)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(5, self.content)])
+
+ def test_uses_bandwidth_limiter(self):
+ bandwidth_limiter = mock.Mock(BandwidthLimiter)
+
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(bandwidth_limiter=bandwidth_limiter)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(
+ bandwidth_limiter.get_bandwith_limited_stream.call_args_list,
+ [mock.call(mock.ANY, self.transfer_coordinator)],
+ )
+
+ def test_retries_succeeds(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(self.stream, SOCKET_ERROR)
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+        # The retryable error should not have affected the bytes placed into
+        # the io queue.
+ self.stubber.assert_no_pending_responses()
+ self.assert_io_writes([(0, self.content)])
+
+ def test_retries_failure(self):
+ for _ in range(self.max_attempts):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(self.stream, SOCKET_ERROR)
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+
+ task = self.get_download_task()
+ task()
+ self.transfer_coordinator.announce_done()
+
+ # Should have failed out on a RetriesExceededError
+ with self.assertRaises(RetriesExceededError):
+ self.transfer_coordinator.result()
+ self.stubber.assert_no_pending_responses()
+
+ def test_retries_in_middle_of_streaming(self):
+ # After the first read a retryable error will be thrown
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': StreamWithError(
+ copy.deepcopy(self.stream), SOCKET_ERROR, 1
+ )
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.stubber.add_response(
+ 'get_object',
+ service_response={'Body': self.stream},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task(io_chunksize=1)
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ expected_contents = []
+ # This is the content initially read in before the retry hit on the
+ # second read()
+ expected_contents.append((0, bytes(self.content[0:1])))
+
+ # The rest of the content should be the entire set of data partitioned
+ # out based on the one byte stream chunk size. Note the second
+ # element in the list should be a copy of the first element since
+ # a retryable exception happened in between.
+ for i in range(len(self.content)):
+ expected_contents.append((i, bytes(self.content[i : i + 1])))
+ self.assert_io_writes(expected_contents)
+
+ def test_cancels_out_of_queueing(self):
+ self.stubber.add_response(
+ 'get_object',
+ service_response={
+ 'Body': CancelledStreamWrapper(
+ self.stream, self.transfer_coordinator
+ )
+ },
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ task = self.get_download_task()
+ task()
+
+ self.stubber.assert_no_pending_responses()
+ # Make sure that no contents were added to the queue because the task
+ # should have been canceled before trying to add the contents to the
+ # io queue.
+ self.assert_io_writes([])
+
+ def test_handles_callback_on_initial_error(self):
+        # We can't use the stubber for this because we need to raise
+        # one of the S3_RETRYABLE_DOWNLOAD_ERRORS, and the stubber only
+        # allows you to raise a ClientError.
+ self.client.get_object = mock.Mock(side_effect=SOCKET_ERROR())
+ task = self.get_download_task()
+ task()
+ self.transfer_coordinator.announce_done()
+ # Should have failed out on a RetriesExceededError because
+ # get_object keeps raising a socket error.
+ with self.assertRaises(RetriesExceededError):
+ self.transfer_coordinator.result()
+
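The retry tests above pin down the behavior GetObjectTask depends on: a retryable, socket-level error restarts the GET, and only after max_attempts failures does the transfer surface RetriesExceededError. A simplified sketch of such a loop, assuming OSError stands in for the retryable error set (this is not the task's actual code):

    def get_with_retries(do_get, max_attempts=5):
        last_error = None
        for _ in range(max_attempts):
            try:
                return do_get()  # e.g. call get_object and stream the body
            except OSError as error:  # stand-in for retryable download errors
                last_error = error
        # Exhausting max_attempts is what maps to RetriesExceededError above.
        raise RuntimeError('retries exceeded') from last_error
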
+
+class TestImmediatelyWriteIOGetObjectTask(TestGetObjectTask):
+ def setUp(self):
+ super().setUp()
+ self.task_cls = ImmediatelyWriteIOGetObjectTask
+        # When data is written out, it should not use the io executor at all.
+        # If it does use the io executor, that is a deviation from the
+        # expected behavior, as the data should be written immediately to the
+        # file object once downloaded.
+ self.io_executor = None
+ self.download_output_manager = DownloadSeekableOutputManager(
+ self.osutil, self.transfer_coordinator, self.io_executor
+ )
+
+ def assert_io_writes(self, expected_writes):
+ self.assertEqual(self.fileobj.writes, expected_writes)
+
+
+class BaseIOTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.files = FileCreator()
+ self.osutil = OSUtils()
+ self.temp_filename = os.path.join(self.files.rootdir, 'mytempfile')
+ self.final_filename = os.path.join(self.files.rootdir, 'myfile')
+
+ def tearDown(self):
+ super().tearDown()
+ self.files.remove_all()
+
+
+class TestIOStreamingWriteTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ task = self.get_task(
+ IOStreamingWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'foobar'},
+ )
+ task()
+ task2 = self.get_task(
+ IOStreamingWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'baz'},
+ )
+ task2()
+ with open(self.temp_filename, 'rb') as f:
+ # We should just have written to the file in the order
+ # the tasks were executed.
+ self.assertEqual(f.read(), b'foobarbaz')
+
+
+class TestIOWriteTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ # Write once to the file
+ task = self.get_task(
+ IOWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'foo', 'offset': 0},
+ )
+ task()
+
+ # Write again to the file
+ task = self.get_task(
+ IOWriteTask,
+ main_kwargs={'fileobj': f, 'data': b'bar', 'offset': 3},
+ )
+ task()
+
+ with open(self.temp_filename, 'rb') as f:
+ self.assertEqual(f.read(), b'foobar')
+
+
+class TestIORenameFileTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'wb') as f:
+ task = self.get_task(
+ IORenameFileTask,
+ main_kwargs={
+ 'fileobj': f,
+ 'final_filename': self.final_filename,
+ 'osutil': self.osutil,
+ },
+ )
+ task()
+ self.assertTrue(os.path.exists(self.final_filename))
+ self.assertFalse(os.path.exists(self.temp_filename))
+
+
+class TestIOCloseTask(BaseIOTaskTest):
+ def test_main(self):
+ with open(self.temp_filename, 'w') as f:
+ task = self.get_task(IOCloseTask, main_kwargs={'fileobj': f})
+ task()
+ self.assertTrue(f.closed)
+
+
+class TestDownloadChunkIterator(unittest.TestCase):
+ def test_iter(self):
+ content = b'my content'
+ body = BytesIO(content)
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, len(content)):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b'my content'])
+
+ def test_iter_chunksize(self):
+ content = b'1234'
+ body = BytesIO(content)
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, 3):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b'123', b'4'])
+
+ def test_empty_content(self):
+ body = BytesIO(b'')
+ ref_chunks = []
+ for chunk in DownloadChunkIterator(body, 3):
+ ref_chunks.append(chunk)
+ self.assertEqual(ref_chunks, [b''])
+
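For quick reference alongside the tests above, DownloadChunkIterator simply re-chunks a body stream; assuming s3transfer is importable, its observable behavior is:

    from io import BytesIO

    from s3transfer.download import DownloadChunkIterator

    # A ten-byte body read in three-byte chunks (see test_iter_chunksize).
    print(list(DownloadChunkIterator(BytesIO(b'my content'), 3)))
    # [b'my ', b'con', b'ten', b't']

    # An empty body still yields exactly one empty chunk (test_empty_content).
    print(list(DownloadChunkIterator(BytesIO(b''), 3)))
    # [b'']
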
+
+class TestDeferQueue(unittest.TestCase):
+ def setUp(self):
+ self.q = DeferQueue()
+
+ def test_no_writes_when_not_lowest_block(self):
+ writes = self.q.request_writes(offset=1, data='bar')
+ self.assertEqual(writes, [])
+
+ def test_writes_returned_in_order(self):
+ self.assertEqual(self.q.request_writes(offset=3, data='d'), [])
+ self.assertEqual(self.q.request_writes(offset=2, data='c'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ # Everything at this point has been deferred, but as soon as we
+ # send offset=0, that will unlock offsets 0-3.
+ writes = self.q.request_writes(offset=0, data='a')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'a'},
+ {'offset': 1, 'data': 'b'},
+ {'offset': 2, 'data': 'c'},
+ {'offset': 3, 'data': 'd'},
+ ],
+ )
+
+ def test_unlocks_partial_range(self):
+ self.assertEqual(self.q.request_writes(offset=5, data='f'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ # offset=0 unlocks 0-1, but offset=5 still needs to see 2-4 first.
+ writes = self.q.request_writes(offset=0, data='a')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'a'},
+ {'offset': 1, 'data': 'b'},
+ ],
+ )
+
+ def test_data_can_be_any_size(self):
+ self.q.request_writes(offset=5, data='hello world')
+ writes = self.q.request_writes(offset=0, data='abcde')
+ self.assertEqual(
+ writes,
+ [
+ {'offset': 0, 'data': 'abcde'},
+ {'offset': 5, 'data': 'hello world'},
+ ],
+ )
+
+ def test_data_queued_in_order(self):
+ # This immediately gets returned because offset=0 is the
+ # next range we're waiting on.
+ writes = self.q.request_writes(offset=0, data='hello world')
+ self.assertEqual(writes, [{'offset': 0, 'data': 'hello world'}])
+ # Same thing here but with offset
+ writes = self.q.request_writes(offset=11, data='hello again')
+ self.assertEqual(writes, [{'offset': 11, 'data': 'hello again'}])
+
+ def test_writes_below_min_offset_are_ignored(self):
+ self.q.request_writes(offset=0, data='a')
+ self.q.request_writes(offset=1, data='b')
+ self.q.request_writes(offset=2, data='c')
+
+ # At this point we're expecting offset=3, so if a write
+ # comes in below 3, we ignore it.
+ self.assertEqual(self.q.request_writes(offset=0, data='a'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='b'), [])
+
+ self.assertEqual(
+ self.q.request_writes(offset=3, data='d'),
+ [{'offset': 3, 'data': 'd'}],
+ )
+
+ def test_duplicate_writes_are_ignored(self):
+ self.q.request_writes(offset=2, data='c')
+ self.q.request_writes(offset=1, data='b')
+
+ # We're still waiting for offset=0, but if
+ # a duplicate write comes in for offset=2/offset=1
+ # it's ignored. This gives "first one wins" behavior.
+ self.assertEqual(self.q.request_writes(offset=2, data='X'), [])
+ self.assertEqual(self.q.request_writes(offset=1, data='Y'), [])
+
+ self.assertEqual(
+ self.q.request_writes(offset=0, data='a'),
+ [
+ {'offset': 0, 'data': 'a'},
+            # Note we're seeing 'b', 'c', and not 'X', 'Y'.
+ {'offset': 1, 'data': 'b'},
+ {'offset': 2, 'data': 'c'},
+ ],
+ )
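Before the diff moves on to test_futures.py, the DeferQueue contract the tests above verify can be summarized in one short sketch (assuming s3transfer is importable): out-of-order writes are held back and released strictly in offset order, while duplicate or already-passed offsets are dropped ("first one wins").

    from s3transfer.download import DeferQueue

    q = DeferQueue()
    print(q.request_writes(offset=3, data='d'))  # [] - deferred, offset 0 not seen
    print(q.request_writes(offset=1, data='b'))  # [] - still deferred
    print(q.request_writes(offset=1, data='X'))  # [] - duplicate offset, ignored
    print(q.request_writes(offset=0, data='a'))
    # [{'offset': 0, 'data': 'a'}, {'offset': 1, 'data': 'b'}]
    # offset=3 stays queued until offset=2 is requested as well.
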
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_futures.py b/contrib/python/s3transfer/py3/tests/unit/test_futures.py
index 2f4028effe..ca2888a654 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_futures.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_futures.py
@@ -1,696 +1,696 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import os
-import sys
-import time
-import traceback
-from concurrent.futures import ThreadPoolExecutor
-
-from s3transfer.exceptions import (
- CancelledError,
- FatalError,
- TransferNotDoneError,
-)
-from s3transfer.futures import (
- BaseExecutor,
- BoundedExecutor,
- ExecutorFuture,
- NonThreadedExecutor,
- NonThreadedExecutorFuture,
- TransferCoordinator,
- TransferFuture,
- TransferMeta,
-)
-from s3transfer.tasks import Task
-from s3transfer.utils import (
- FunctionContainer,
- NoResourcesAvailable,
- TaskSemaphore,
-)
-from __tests__ import (
- RecordingExecutor,
- TransferCoordinatorWithInterrupt,
- mock,
- unittest,
-)
-
-
-def return_call_args(*args, **kwargs):
- return args, kwargs
-
-
-def raise_exception(exception):
- raise exception
-
-
-def get_exc_info(exception):
- try:
- raise_exception(exception)
- except Exception:
- return sys.exc_info()
-
-
-class RecordingTransferCoordinator(TransferCoordinator):
- def __init__(self):
- self.all_transfer_futures_ever_associated = set()
- super().__init__()
-
- def add_associated_future(self, future):
- self.all_transfer_futures_ever_associated.add(future)
- super().add_associated_future(future)
-
-
-class ReturnFooTask(Task):
- def _main(self, **kwargs):
- return 'foo'
-
-
-class SleepTask(Task):
- def _main(self, sleep_time, **kwargs):
- time.sleep(sleep_time)
-
-
-class TestTransferFuture(unittest.TestCase):
- def setUp(self):
- self.meta = TransferMeta()
- self.coordinator = TransferCoordinator()
- self.future = self._get_transfer_future()
-
- def _get_transfer_future(self, **kwargs):
- components = {
- 'meta': self.meta,
- 'coordinator': self.coordinator,
- }
- for component_name, component in kwargs.items():
- components[component_name] = component
- return TransferFuture(**components)
-
- def test_meta(self):
- self.assertIs(self.future.meta, self.meta)
-
- def test_done(self):
- self.assertFalse(self.future.done())
- self.coordinator.set_result(None)
- self.assertTrue(self.future.done())
-
- def test_result(self):
- result = 'foo'
- self.coordinator.set_result(result)
- self.coordinator.announce_done()
- self.assertEqual(self.future.result(), result)
-
- def test_keyboard_interrupt_on_result_does_not_block(self):
- # This should raise a KeyboardInterrupt when result is called on it.
- self.coordinator = TransferCoordinatorWithInterrupt()
- self.future = self._get_transfer_future()
-
- # result() should not block and immediately raise the keyboard
- # interrupt exception.
- with self.assertRaises(KeyboardInterrupt):
- self.future.result()
-
- def test_cancel(self):
- self.future.cancel()
- self.assertTrue(self.future.done())
- self.assertEqual(self.coordinator.status, 'cancelled')
-
- def test_set_exception(self):
- # Set the result such that there is no exception
- self.coordinator.set_result('result')
- self.coordinator.announce_done()
- self.assertEqual(self.future.result(), 'result')
-
- self.future.set_exception(ValueError())
- with self.assertRaises(ValueError):
- self.future.result()
-
- def test_set_exception_only_after_done(self):
- with self.assertRaises(TransferNotDoneError):
- self.future.set_exception(ValueError())
-
- self.coordinator.set_result('result')
- self.coordinator.announce_done()
- self.future.set_exception(ValueError())
- with self.assertRaises(ValueError):
- self.future.result()
-
-
-class TestTransferMeta(unittest.TestCase):
- def setUp(self):
- self.transfer_meta = TransferMeta()
-
- def test_size(self):
- self.assertEqual(self.transfer_meta.size, None)
- self.transfer_meta.provide_transfer_size(5)
- self.assertEqual(self.transfer_meta.size, 5)
-
- def test_call_args(self):
- call_args = object()
- transfer_meta = TransferMeta(call_args)
-        # Assert that the call args provided are the same as those returned
- self.assertIs(transfer_meta.call_args, call_args)
-
- def test_transfer_id(self):
- transfer_meta = TransferMeta(transfer_id=1)
- self.assertEqual(transfer_meta.transfer_id, 1)
-
- def test_user_context(self):
- self.transfer_meta.user_context['foo'] = 'bar'
- self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'})
-
-
-class TestTransferCoordinator(unittest.TestCase):
- def setUp(self):
- self.transfer_coordinator = TransferCoordinator()
-
- def test_transfer_id(self):
- transfer_coordinator = TransferCoordinator(transfer_id=1)
- self.assertEqual(transfer_coordinator.transfer_id, 1)
-
- def test_repr(self):
- transfer_coordinator = TransferCoordinator(transfer_id=1)
- self.assertEqual(
- repr(transfer_coordinator), 'TransferCoordinator(transfer_id=1)'
- )
-
- def test_initial_status(self):
- # A TransferCoordinator with no progress should have the status
- # of not-started
- self.assertEqual(self.transfer_coordinator.status, 'not-started')
-
- def test_set_status_to_queued(self):
- self.transfer_coordinator.set_status_to_queued()
- self.assertEqual(self.transfer_coordinator.status, 'queued')
-
- def test_cannot_set_status_to_queued_from_done_state(self):
- self.transfer_coordinator.set_exception(RuntimeError)
- with self.assertRaises(RuntimeError):
- self.transfer_coordinator.set_status_to_queued()
-
- def test_status_running(self):
- self.transfer_coordinator.set_status_to_running()
- self.assertEqual(self.transfer_coordinator.status, 'running')
-
- def test_cannot_set_status_to_running_from_done_state(self):
- self.transfer_coordinator.set_exception(RuntimeError)
- with self.assertRaises(RuntimeError):
- self.transfer_coordinator.set_status_to_running()
-
- def test_set_result(self):
- success_result = 'foo'
- self.transfer_coordinator.set_result(success_result)
- self.transfer_coordinator.announce_done()
- # Setting result should result in a success state and the return value
- # that was set.
- self.assertEqual(self.transfer_coordinator.status, 'success')
- self.assertEqual(self.transfer_coordinator.result(), success_result)
-
- def test_set_exception(self):
- exception_result = RuntimeError
- self.transfer_coordinator.set_exception(exception_result)
- self.transfer_coordinator.announce_done()
- # Setting an exception should result in a failed state and the return
- # value should be the raised exception
- self.assertEqual(self.transfer_coordinator.status, 'failed')
- self.assertEqual(self.transfer_coordinator.exception, exception_result)
- with self.assertRaises(exception_result):
- self.transfer_coordinator.result()
-
- def test_exception_cannot_override_done_state(self):
- self.transfer_coordinator.set_result('foo')
- self.transfer_coordinator.set_exception(RuntimeError)
-        # Its status should be success even after the exception is set because
- # success is a done state.
- self.assertEqual(self.transfer_coordinator.status, 'success')
-
- def test_exception_can_override_done_state_with_override_flag(self):
- self.transfer_coordinator.set_result('foo')
- self.transfer_coordinator.set_exception(RuntimeError, override=True)
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- def test_cancel(self):
- self.assertEqual(self.transfer_coordinator.status, 'not-started')
- self.transfer_coordinator.cancel()
- # This should set the state to cancelled and raise the CancelledError
- # exception and should have also set the done event so that result()
- # is no longer set.
- self.assertEqual(self.transfer_coordinator.status, 'cancelled')
- with self.assertRaises(CancelledError):
- self.transfer_coordinator.result()
-
- def test_cancel_can_run_done_callbacks_that_uses_result(self):
- exceptions = []
-
- def capture_exception(transfer_coordinator, captured_exceptions):
- try:
- transfer_coordinator.result()
- except Exception as e:
- captured_exceptions.append(e)
-
- self.assertEqual(self.transfer_coordinator.status, 'not-started')
- self.transfer_coordinator.add_done_callback(
- capture_exception, self.transfer_coordinator, exceptions
- )
- self.transfer_coordinator.cancel()
-
- self.assertEqual(len(exceptions), 1)
- self.assertIsInstance(exceptions[0], CancelledError)
-
- def test_cancel_with_message(self):
- message = 'my message'
- self.transfer_coordinator.cancel(message)
- self.transfer_coordinator.announce_done()
- with self.assertRaisesRegex(CancelledError, message):
- self.transfer_coordinator.result()
-
- def test_cancel_with_provided_exception(self):
- message = 'my message'
- self.transfer_coordinator.cancel(message, exc_type=FatalError)
- self.transfer_coordinator.announce_done()
- with self.assertRaisesRegex(FatalError, message):
- self.transfer_coordinator.result()
-
- def test_cancel_cannot_override_done_state(self):
- self.transfer_coordinator.set_result('foo')
- self.transfer_coordinator.cancel()
-        # Its status should be success even after cancel is called because
- # success is a done state.
- self.assertEqual(self.transfer_coordinator.status, 'success')
-
- def test_set_result_can_override_cancel(self):
- self.transfer_coordinator.cancel()
- # Result setting should override any cancel or set exception as this
- # is always invoked by the final task.
- self.transfer_coordinator.set_result('foo')
- self.transfer_coordinator.announce_done()
- self.assertEqual(self.transfer_coordinator.status, 'success')
-
- def test_submit(self):
- # Submit a callable to the transfer coordinator. It should submit it
- # to the executor.
- executor = RecordingExecutor(
- BoundedExecutor(1, 1, {'my-tag': TaskSemaphore(1)})
- )
- task = ReturnFooTask(self.transfer_coordinator)
- future = self.transfer_coordinator.submit(executor, task, tag='my-tag')
- executor.shutdown()
-        # Make sure the future got submitted and executed as well by checking
-        # its
- # result value which should include the provided future tag.
- self.assertEqual(
- executor.submissions,
- [{'block': True, 'tag': 'my-tag', 'task': task}],
- )
- self.assertEqual(future.result(), 'foo')
-
- def test_association_and_disassociation_on_submit(self):
- self.transfer_coordinator = RecordingTransferCoordinator()
-
- # Submit a callable to the transfer coordinator.
- executor = BoundedExecutor(1, 1)
- task = ReturnFooTask(self.transfer_coordinator)
- future = self.transfer_coordinator.submit(executor, task)
- executor.shutdown()
-
- # Make sure the future that got submitted was associated to the
- # transfer future at some point.
- self.assertEqual(
- self.transfer_coordinator.all_transfer_futures_ever_associated,
- {future},
- )
-
- # Make sure the future got disassociated once the future is now done
- # by looking at the currently associated futures.
- self.assertEqual(self.transfer_coordinator.associated_futures, set())
-
- def test_done(self):
- # These should result in not done state:
- # queued
- self.assertFalse(self.transfer_coordinator.done())
- # running
- self.transfer_coordinator.set_status_to_running()
- self.assertFalse(self.transfer_coordinator.done())
-
- # These should result in done state:
- # failed
- self.transfer_coordinator.set_exception(Exception)
- self.assertTrue(self.transfer_coordinator.done())
-
- # success
- self.transfer_coordinator.set_result('foo')
- self.assertTrue(self.transfer_coordinator.done())
-
- # cancelled
- self.transfer_coordinator.cancel()
- self.assertTrue(self.transfer_coordinator.done())
-
- def test_result_waits_until_done(self):
- execution_order = []
-
- def sleep_then_set_result(transfer_coordinator, execution_order):
- time.sleep(0.05)
- execution_order.append('setting_result')
- transfer_coordinator.set_result(None)
- self.transfer_coordinator.announce_done()
-
- with ThreadPoolExecutor(max_workers=1) as executor:
- executor.submit(
- sleep_then_set_result,
- self.transfer_coordinator,
- execution_order,
- )
- self.transfer_coordinator.result()
- execution_order.append('after_result')
-
- # The result() call should have waited until the other thread set
- # the result after sleeping for 0.05 seconds.
- self.assertTrue(execution_order, ['setting_result', 'after_result'])
-
- def test_failure_cleanups(self):
- args = (1, 2)
- kwargs = {'foo': 'bar'}
-
- second_args = (2, 4)
- second_kwargs = {'biz': 'baz'}
-
- self.transfer_coordinator.add_failure_cleanup(
- return_call_args, *args, **kwargs
- )
- self.transfer_coordinator.add_failure_cleanup(
- return_call_args, *second_args, **second_kwargs
- )
-
- # Ensure the callbacks got added.
- self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 2)
-
- result_list = []
- # Ensure they will get called in the correct order.
- for cleanup in self.transfer_coordinator.failure_cleanups:
- result_list.append(cleanup())
- self.assertEqual(
- result_list, [(args, kwargs), (second_args, second_kwargs)]
- )
-
- def test_associated_futures(self):
- first_future = object()
- # Associate one future to the transfer
- self.transfer_coordinator.add_associated_future(first_future)
- associated_futures = self.transfer_coordinator.associated_futures
- # The first future should be in the returned list of futures.
- self.assertEqual(associated_futures, {first_future})
-
- second_future = object()
- # Associate another future to the transfer.
- self.transfer_coordinator.add_associated_future(second_future)
- # The association should not have mutated the returned list from
- # before.
- self.assertEqual(associated_futures, {first_future})
-
- # Both futures should be in the returned list.
- self.assertEqual(
- self.transfer_coordinator.associated_futures,
- {first_future, second_future},
- )
-
- def test_done_callbacks_on_done(self):
- done_callback_invocations = []
- callback = FunctionContainer(
- done_callback_invocations.append, 'done callback called'
- )
-
- # Add the done callback to the transfer.
- self.transfer_coordinator.add_done_callback(callback)
-
- # Announce that the transfer is done. This should invoke the done
- # callback.
- self.transfer_coordinator.announce_done()
- self.assertEqual(done_callback_invocations, ['done callback called'])
-
- # If done is announced again, we should not invoke the callback again
- # because done has already been announced and thus the callback has
-        # been run as well.
- self.transfer_coordinator.announce_done()
- self.assertEqual(done_callback_invocations, ['done callback called'])
-
- def test_failure_cleanups_on_done(self):
- cleanup_invocations = []
- callback = FunctionContainer(
- cleanup_invocations.append, 'cleanup called'
- )
-
- # Add the failure cleanup to the transfer.
- self.transfer_coordinator.add_failure_cleanup(callback)
-
- # Announce that the transfer is done. This should invoke the failure
- # cleanup.
- self.transfer_coordinator.announce_done()
- self.assertEqual(cleanup_invocations, ['cleanup called'])
-
- # If done is announced again, we should not invoke the cleanup again
- # because done has already been announced and thus the cleanup has
-        # been run as well.
- self.transfer_coordinator.announce_done()
- self.assertEqual(cleanup_invocations, ['cleanup called'])
-
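The coordinator tests above reduce to a small state machine: not-started, queued, running, then one of the done states (success, failed, cancelled), with results only observable through result() once done is announced. A happy-path usage sketch, assuming s3transfer is importable:

    from s3transfer.futures import TransferCoordinator

    coordinator = TransferCoordinator(transfer_id=1)
    print(coordinator.status)           # 'not-started'
    coordinator.set_status_to_queued()
    coordinator.set_status_to_running()
    coordinator.set_result('foo')       # success is a "done" state
    coordinator.announce_done()         # done callbacks and cleanups fire here
    print(coordinator.status)           # 'success'
    print(coordinator.result())         # 'foo'
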
-
-class TestBoundedExecutor(unittest.TestCase):
- def setUp(self):
- self.coordinator = TransferCoordinator()
- self.tag_semaphores = {}
- self.executor = self.get_executor()
-
- def get_executor(self, max_size=1, max_num_threads=1):
- return BoundedExecutor(max_size, max_num_threads, self.tag_semaphores)
-
- def get_task(self, task_cls, main_kwargs=None):
- return task_cls(self.coordinator, main_kwargs=main_kwargs)
-
- def get_sleep_task(self, sleep_time=0.01):
- return self.get_task(SleepTask, main_kwargs={'sleep_time': sleep_time})
-
- def add_semaphore(self, task_tag, count):
- self.tag_semaphores[task_tag] = TaskSemaphore(count)
-
- def assert_submit_would_block(self, task, tag=None):
- with self.assertRaises(NoResourcesAvailable):
- self.executor.submit(task, tag=tag, block=False)
-
- def assert_submit_would_not_block(self, task, tag=None, **kwargs):
- try:
- self.executor.submit(task, tag=tag, block=False)
- except NoResourcesAvailable:
- self.fail(
- 'Task {} should not have been blocked. Caused by:\n{}'.format(
- task, traceback.format_exc()
- )
- )
-
- def add_done_callback_to_future(self, future, fn, *args, **kwargs):
- callback_for_future = FunctionContainer(fn, *args, **kwargs)
- future.add_done_callback(callback_for_future)
-
- def test_submit_single_task(self):
- # Ensure we can submit a task to the executor
- task = self.get_task(ReturnFooTask)
- future = self.executor.submit(task)
-
- # Ensure what we get back is a Future
- self.assertIsInstance(future, ExecutorFuture)
- # Ensure the callable got executed.
- self.assertEqual(future.result(), 'foo')
-
- @unittest.skipIf(
- os.environ.get('USE_SERIAL_EXECUTOR'),
- "Not supported with serial executor tests",
- )
- def test_executor_blocks_on_full_capacity(self):
- first_task = self.get_sleep_task()
- second_task = self.get_sleep_task()
- self.executor.submit(first_task)
- # The first task should be sleeping for a substantial period of
- # time such that on the submission of the second task, it will
- # raise an error saying that it cannot be submitted as the max
- # capacity of the semaphore is one.
- self.assert_submit_would_block(second_task)
-
- def test_executor_clears_capacity_on_done_tasks(self):
- first_task = self.get_sleep_task()
- second_task = self.get_task(ReturnFooTask)
-
- # Submit a task.
- future = self.executor.submit(first_task)
-
- # Submit a new task when the first task finishes. This should not get
- # blocked because the first task should have finished clearing up
- # capacity.
- self.add_done_callback_to_future(
- future, self.assert_submit_would_not_block, second_task
- )
-
- # Wait for it to complete.
- self.executor.shutdown()
-
- @unittest.skipIf(
- os.environ.get('USE_SERIAL_EXECUTOR'),
- "Not supported with serial executor tests",
- )
- def test_would_not_block_when_full_capacity_in_other_semaphore(self):
- first_task = self.get_sleep_task()
-
- # Now let's create a new task with a tag and so it uses different
- # semaphore.
- task_tag = 'other'
- other_task = self.get_sleep_task()
- self.add_semaphore(task_tag, 1)
-
- # Submit the normal first task
- self.executor.submit(first_task)
-
-        # Even though the first task should be sleeping for a substantial
- # period of time, the submission of the second task should not
- # raise an error because it should use a different semaphore
- self.assert_submit_would_not_block(other_task, task_tag)
-
- # Another submission of the other task though should raise
- # an exception as the capacity is equal to one for that tag.
- self.assert_submit_would_block(other_task, task_tag)
-
- def test_shutdown(self):
- slow_task = self.get_sleep_task()
- future = self.executor.submit(slow_task)
- self.executor.shutdown()
- # Ensure that the shutdown waits until the task is done
- self.assertTrue(future.done())
-
- @unittest.skipIf(
- os.environ.get('USE_SERIAL_EXECUTOR'),
- "Not supported with serial executor tests",
- )
- def test_shutdown_no_wait(self):
- slow_task = self.get_sleep_task()
- future = self.executor.submit(slow_task)
- self.executor.shutdown(False)
- # Ensure that the shutdown returns immediately even if the task is
-        # not done, which it should not be because it is slow.
- self.assertFalse(future.done())
-
- def test_replace_underlying_executor(self):
- mocked_executor_cls = mock.Mock(BaseExecutor)
- executor = BoundedExecutor(10, 1, {}, mocked_executor_cls)
- executor.submit(self.get_task(ReturnFooTask))
- self.assertTrue(mocked_executor_cls.return_value.submit.called)
-
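The BoundedExecutor tests above hinge on per-tag semaphores capping how many tasks of a given kind are in flight, with a non-blocking submit raising NoResourcesAvailable when a tag is at capacity. A standalone sketch of that gating idea using only the standard library (this is not BoundedExecutor's implementation; the class and exception names are invented):

    import threading
    from concurrent.futures import ThreadPoolExecutor

    class NoCapacity(Exception):
        pass

    class TagGatedExecutor:
        def __init__(self, tag_limits, max_workers=4):
            self._executor = ThreadPoolExecutor(max_workers=max_workers)
            self._semaphores = {
                tag: threading.Semaphore(count)
                for tag, count in tag_limits.items()
            }

        def submit(self, fn, tag, block=True):
            semaphore = self._semaphores[tag]
            # A non-blocking submit fails fast when the tag's capacity is
            # used up, analogous to NoResourcesAvailable above.
            if not semaphore.acquire(blocking=block):
                raise NoCapacity(tag)
            future = self._executor.submit(fn)
            # Release the tag's slot once the task finishes.
            future.add_done_callback(lambda _finished: semaphore.release())
            return future
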
-
-class TestExecutorFuture(unittest.TestCase):
- def test_result(self):
- with ThreadPoolExecutor(max_workers=1) as executor:
- future = executor.submit(return_call_args, 'foo', biz='baz')
- wrapped_future = ExecutorFuture(future)
- self.assertEqual(wrapped_future.result(), (('foo',), {'biz': 'baz'}))
-
- def test_done(self):
- with ThreadPoolExecutor(max_workers=1) as executor:
- future = executor.submit(return_call_args, 'foo', biz='baz')
- wrapped_future = ExecutorFuture(future)
- self.assertTrue(wrapped_future.done())
-
- def test_add_done_callback(self):
- done_callbacks = []
- with ThreadPoolExecutor(max_workers=1) as executor:
- future = executor.submit(return_call_args, 'foo', biz='baz')
- wrapped_future = ExecutorFuture(future)
- wrapped_future.add_done_callback(
- FunctionContainer(done_callbacks.append, 'called')
- )
- self.assertEqual(done_callbacks, ['called'])
-
-
-class TestNonThreadedExecutor(unittest.TestCase):
- def test_submit(self):
- executor = NonThreadedExecutor()
- future = executor.submit(return_call_args, 1, 2, foo='bar')
- self.assertIsInstance(future, NonThreadedExecutorFuture)
- self.assertEqual(future.result(), ((1, 2), {'foo': 'bar'}))
-
- def test_submit_with_exception(self):
- executor = NonThreadedExecutor()
- future = executor.submit(raise_exception, RuntimeError())
- self.assertIsInstance(future, NonThreadedExecutorFuture)
- with self.assertRaises(RuntimeError):
- future.result()
-
- def test_submit_with_exception_and_captures_info(self):
- exception = ValueError('message')
- tb = get_exc_info(exception)[2]
- future = NonThreadedExecutor().submit(raise_exception, exception)
- try:
- future.result()
- # An exception should have been raised
- self.fail('Future should have raised a ValueError')
- except ValueError:
- actual_tb = sys.exc_info()[2]
- last_frame = traceback.extract_tb(actual_tb)[-1]
- last_expected_frame = traceback.extract_tb(tb)[-1]
- self.assertEqual(last_frame, last_expected_frame)
-
-
-class TestNonThreadedExecutorFuture(unittest.TestCase):
- def setUp(self):
- self.future = NonThreadedExecutorFuture()
-
- def test_done_starts_false(self):
- self.assertFalse(self.future.done())
-
- def test_done_after_setting_result(self):
- self.future.set_result('result')
- self.assertTrue(self.future.done())
-
- def test_done_after_setting_exception(self):
- self.future.set_exception_info(Exception(), None)
- self.assertTrue(self.future.done())
-
- def test_result(self):
- self.future.set_result('result')
- self.assertEqual(self.future.result(), 'result')
-
- def test_exception_result(self):
- exception = ValueError('message')
- self.future.set_exception_info(exception, None)
- with self.assertRaisesRegex(ValueError, 'message'):
- self.future.result()
-
- def test_exception_result_doesnt_modify_last_frame(self):
- exception = ValueError('message')
- tb = get_exc_info(exception)[2]
- self.future.set_exception_info(exception, tb)
- try:
- self.future.result()
- # An exception should have been raised
- self.fail()
- except ValueError:
- actual_tb = sys.exc_info()[2]
- last_frame = traceback.extract_tb(actual_tb)[-1]
- last_expected_frame = traceback.extract_tb(tb)[-1]
- self.assertEqual(last_frame, last_expected_frame)
-
- def test_done_callback(self):
- done_futures = []
- self.future.add_done_callback(done_futures.append)
- self.assertEqual(done_futures, [])
- self.future.set_result('result')
- self.assertEqual(done_futures, [self.future])
-
- def test_done_callback_after_done(self):
- self.future.set_result('result')
- done_futures = []
- self.future.add_done_callback(done_futures.append)
- self.assertEqual(done_futures, [self.future])
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import sys
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+
+from s3transfer.exceptions import (
+ CancelledError,
+ FatalError,
+ TransferNotDoneError,
+)
+from s3transfer.futures import (
+ BaseExecutor,
+ BoundedExecutor,
+ ExecutorFuture,
+ NonThreadedExecutor,
+ NonThreadedExecutorFuture,
+ TransferCoordinator,
+ TransferFuture,
+ TransferMeta,
+)
+from s3transfer.tasks import Task
+from s3transfer.utils import (
+ FunctionContainer,
+ NoResourcesAvailable,
+ TaskSemaphore,
+)
+from __tests__ import (
+ RecordingExecutor,
+ TransferCoordinatorWithInterrupt,
+ mock,
+ unittest,
+)
+
+
+def return_call_args(*args, **kwargs):
+ return args, kwargs
+
+
+def raise_exception(exception):
+ raise exception
+
+
+def get_exc_info(exception):
+ try:
+ raise_exception(exception)
+ except Exception:
+ return sys.exc_info()
+
+
+class RecordingTransferCoordinator(TransferCoordinator):
+ def __init__(self):
+ self.all_transfer_futures_ever_associated = set()
+ super().__init__()
+
+ def add_associated_future(self, future):
+ self.all_transfer_futures_ever_associated.add(future)
+ super().add_associated_future(future)
+
+
+class ReturnFooTask(Task):
+ def _main(self, **kwargs):
+ return 'foo'
+
+
+class SleepTask(Task):
+ def _main(self, sleep_time, **kwargs):
+ time.sleep(sleep_time)
+
+
+class TestTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.meta = TransferMeta()
+ self.coordinator = TransferCoordinator()
+ self.future = self._get_transfer_future()
+
+ def _get_transfer_future(self, **kwargs):
+ components = {
+ 'meta': self.meta,
+ 'coordinator': self.coordinator,
+ }
+ for component_name, component in kwargs.items():
+ components[component_name] = component
+ return TransferFuture(**components)
+
+ def test_meta(self):
+ self.assertIs(self.future.meta, self.meta)
+
+ def test_done(self):
+ self.assertFalse(self.future.done())
+ self.coordinator.set_result(None)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ result = 'foo'
+ self.coordinator.set_result(result)
+ self.coordinator.announce_done()
+ self.assertEqual(self.future.result(), result)
+
+ def test_keyboard_interrupt_on_result_does_not_block(self):
+ # This should raise a KeyboardInterrupt when result is called on it.
+ self.coordinator = TransferCoordinatorWithInterrupt()
+ self.future = self._get_transfer_future()
+
+ # result() should not block and immediately raise the keyboard
+ # interrupt exception.
+ with self.assertRaises(KeyboardInterrupt):
+ self.future.result()
+
+ def test_cancel(self):
+ self.future.cancel()
+ self.assertTrue(self.future.done())
+ self.assertEqual(self.coordinator.status, 'cancelled')
+
+ def test_set_exception(self):
+ # Set the result such that there is no exception
+ self.coordinator.set_result('result')
+ self.coordinator.announce_done()
+ self.assertEqual(self.future.result(), 'result')
+
+ self.future.set_exception(ValueError())
+ with self.assertRaises(ValueError):
+ self.future.result()
+
+ def test_set_exception_only_after_done(self):
+ with self.assertRaises(TransferNotDoneError):
+ self.future.set_exception(ValueError())
+
+ self.coordinator.set_result('result')
+ self.coordinator.announce_done()
+ self.future.set_exception(ValueError())
+ with self.assertRaises(ValueError):
+ self.future.result()
+
+
+class TestTransferMeta(unittest.TestCase):
+ def setUp(self):
+ self.transfer_meta = TransferMeta()
+
+ def test_size(self):
+ self.assertEqual(self.transfer_meta.size, None)
+ self.transfer_meta.provide_transfer_size(5)
+ self.assertEqual(self.transfer_meta.size, 5)
+
+ def test_call_args(self):
+ call_args = object()
+ transfer_meta = TransferMeta(call_args)
+        # Assert that the call args provided are the same as those returned
+ self.assertIs(transfer_meta.call_args, call_args)
+
+ def test_transfer_id(self):
+ transfer_meta = TransferMeta(transfer_id=1)
+ self.assertEqual(transfer_meta.transfer_id, 1)
+
+ def test_user_context(self):
+ self.transfer_meta.user_context['foo'] = 'bar'
+ self.assertEqual(self.transfer_meta.user_context, {'foo': 'bar'})
+
+
+class TestTransferCoordinator(unittest.TestCase):
+ def setUp(self):
+ self.transfer_coordinator = TransferCoordinator()
+
+ def test_transfer_id(self):
+ transfer_coordinator = TransferCoordinator(transfer_id=1)
+ self.assertEqual(transfer_coordinator.transfer_id, 1)
+
+ def test_repr(self):
+ transfer_coordinator = TransferCoordinator(transfer_id=1)
+ self.assertEqual(
+ repr(transfer_coordinator), 'TransferCoordinator(transfer_id=1)'
+ )
+
+ def test_initial_status(self):
+ # A TransferCoordinator with no progress should have the status
+ # of not-started
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+
+ def test_set_status_to_queued(self):
+ self.transfer_coordinator.set_status_to_queued()
+ self.assertEqual(self.transfer_coordinator.status, 'queued')
+
+ def test_cannot_set_status_to_queued_from_done_state(self):
+ self.transfer_coordinator.set_exception(RuntimeError)
+ with self.assertRaises(RuntimeError):
+ self.transfer_coordinator.set_status_to_queued()
+
+ def test_status_running(self):
+ self.transfer_coordinator.set_status_to_running()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ def test_cannot_set_status_to_running_from_done_state(self):
+ self.transfer_coordinator.set_exception(RuntimeError)
+ with self.assertRaises(RuntimeError):
+ self.transfer_coordinator.set_status_to_running()
+
+ def test_set_result(self):
+ success_result = 'foo'
+ self.transfer_coordinator.set_result(success_result)
+ self.transfer_coordinator.announce_done()
+        # Setting the result should put the transfer in a success state, and
+        # result() should return the value that was set.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+ self.assertEqual(self.transfer_coordinator.result(), success_result)
+
+ def test_set_exception(self):
+ exception_result = RuntimeError
+ self.transfer_coordinator.set_exception(exception_result)
+ self.transfer_coordinator.announce_done()
+        # Setting an exception should result in a failed state, and the
+        # exception that was set should be raised by result().
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ self.assertEqual(self.transfer_coordinator.exception, exception_result)
+ with self.assertRaises(exception_result):
+ self.transfer_coordinator.result()
+
+ def test_exception_cannot_override_done_state(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.set_exception(RuntimeError)
+        # Its status should be success even after the exception is set because
+ # success is a done state.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_exception_can_override_done_state_with_override_flag(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.set_exception(RuntimeError, override=True)
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ def test_cancel(self):
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+ self.transfer_coordinator.cancel()
+        # This should set the state to cancelled so that result() raises the
+        # CancelledError exception, and it should also set the done event so
+        # that result() no longer blocks.
+ self.assertEqual(self.transfer_coordinator.status, 'cancelled')
+ with self.assertRaises(CancelledError):
+ self.transfer_coordinator.result()
+
+ def test_cancel_can_run_done_callbacks_that_uses_result(self):
+ exceptions = []
+
+ def capture_exception(transfer_coordinator, captured_exceptions):
+ try:
+ transfer_coordinator.result()
+ except Exception as e:
+ captured_exceptions.append(e)
+
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+ self.transfer_coordinator.add_done_callback(
+ capture_exception, self.transfer_coordinator, exceptions
+ )
+ self.transfer_coordinator.cancel()
+
+ self.assertEqual(len(exceptions), 1)
+ self.assertIsInstance(exceptions[0], CancelledError)
+
+ def test_cancel_with_message(self):
+ message = 'my message'
+ self.transfer_coordinator.cancel(message)
+ self.transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(CancelledError, message):
+ self.transfer_coordinator.result()
+
+ def test_cancel_with_provided_exception(self):
+ message = 'my message'
+ self.transfer_coordinator.cancel(message, exc_type=FatalError)
+ self.transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(FatalError, message):
+ self.transfer_coordinator.result()
+
+ def test_cancel_cannot_override_done_state(self):
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.cancel()
+        # Its status should be success even after cancel is called because
+ # success is a done state.
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_set_result_can_override_cancel(self):
+ self.transfer_coordinator.cancel()
+        # Setting the result should override any cancel or set_exception call
+        # as this is always invoked by the final task.
+ self.transfer_coordinator.set_result('foo')
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_submit(self):
+ # Submit a callable to the transfer coordinator. It should submit it
+ # to the executor.
+ executor = RecordingExecutor(
+ BoundedExecutor(1, 1, {'my-tag': TaskSemaphore(1)})
+ )
+ task = ReturnFooTask(self.transfer_coordinator)
+ future = self.transfer_coordinator.submit(executor, task, tag='my-tag')
+ executor.shutdown()
+        # Make sure the future got submitted and executed by checking the
+        # recorded submission (including the provided tag) and its result value.
+ self.assertEqual(
+ executor.submissions,
+ [{'block': True, 'tag': 'my-tag', 'task': task}],
+ )
+ self.assertEqual(future.result(), 'foo')
+
+ def test_association_and_disassociation_on_submit(self):
+ self.transfer_coordinator = RecordingTransferCoordinator()
+
+ # Submit a callable to the transfer coordinator.
+ executor = BoundedExecutor(1, 1)
+ task = ReturnFooTask(self.transfer_coordinator)
+ future = self.transfer_coordinator.submit(executor, task)
+ executor.shutdown()
+
+        # Make sure the future that got submitted was associated with the
+        # transfer coordinator at some point.
+ self.assertEqual(
+ self.transfer_coordinator.all_transfer_futures_ever_associated,
+ {future},
+ )
+
+        # Make sure the future got disassociated once it was done by looking
+        # at the currently associated futures.
+ self.assertEqual(self.transfer_coordinator.associated_futures, set())
+
+ def test_done(self):
+ # These should result in not done state:
+ # queued
+ self.assertFalse(self.transfer_coordinator.done())
+ # running
+ self.transfer_coordinator.set_status_to_running()
+ self.assertFalse(self.transfer_coordinator.done())
+
+ # These should result in done state:
+ # failed
+ self.transfer_coordinator.set_exception(Exception)
+ self.assertTrue(self.transfer_coordinator.done())
+
+ # success
+ self.transfer_coordinator.set_result('foo')
+ self.assertTrue(self.transfer_coordinator.done())
+
+ # cancelled
+ self.transfer_coordinator.cancel()
+ self.assertTrue(self.transfer_coordinator.done())
+
+ def test_result_waits_until_done(self):
+ execution_order = []
+
+ def sleep_then_set_result(transfer_coordinator, execution_order):
+ time.sleep(0.05)
+ execution_order.append('setting_result')
+ transfer_coordinator.set_result(None)
+ self.transfer_coordinator.announce_done()
+
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ executor.submit(
+ sleep_then_set_result,
+ self.transfer_coordinator,
+ execution_order,
+ )
+ self.transfer_coordinator.result()
+ execution_order.append('after_result')
+
+ # The result() call should have waited until the other thread set
+ # the result after sleeping for 0.05 seconds.
+        self.assertEqual(execution_order, ['setting_result', 'after_result'])
+
+ def test_failure_cleanups(self):
+ args = (1, 2)
+ kwargs = {'foo': 'bar'}
+
+ second_args = (2, 4)
+ second_kwargs = {'biz': 'baz'}
+
+ self.transfer_coordinator.add_failure_cleanup(
+ return_call_args, *args, **kwargs
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ return_call_args, *second_args, **second_kwargs
+ )
+
+ # Ensure the callbacks got added.
+ self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 2)
+
+ result_list = []
+ # Ensure they will get called in the correct order.
+ for cleanup in self.transfer_coordinator.failure_cleanups:
+ result_list.append(cleanup())
+ self.assertEqual(
+ result_list, [(args, kwargs), (second_args, second_kwargs)]
+ )
+
+ def test_associated_futures(self):
+ first_future = object()
+ # Associate one future to the transfer
+ self.transfer_coordinator.add_associated_future(first_future)
+ associated_futures = self.transfer_coordinator.associated_futures
+        # The first future should be in the returned set of futures.
+ self.assertEqual(associated_futures, {first_future})
+
+ second_future = object()
+ # Associate another future to the transfer.
+ self.transfer_coordinator.add_associated_future(second_future)
+        # The association should not have mutated the previously returned
+        # set.
+ self.assertEqual(associated_futures, {first_future})
+
+        # Both futures should be in the returned set.
+ self.assertEqual(
+ self.transfer_coordinator.associated_futures,
+ {first_future, second_future},
+ )
+
+ def test_done_callbacks_on_done(self):
+ done_callback_invocations = []
+ callback = FunctionContainer(
+ done_callback_invocations.append, 'done callback called'
+ )
+
+ # Add the done callback to the transfer.
+ self.transfer_coordinator.add_done_callback(callback)
+
+ # Announce that the transfer is done. This should invoke the done
+ # callback.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(done_callback_invocations, ['done callback called'])
+
+ # If done is announced again, we should not invoke the callback again
+ # because done has already been announced and thus the callback has
+        # been run as well.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(done_callback_invocations, ['done callback called'])
+
+ def test_failure_cleanups_on_done(self):
+ cleanup_invocations = []
+ callback = FunctionContainer(
+ cleanup_invocations.append, 'cleanup called'
+ )
+
+ # Add the failure cleanup to the transfer.
+ self.transfer_coordinator.add_failure_cleanup(callback)
+
+ # Announce that the transfer is done. This should invoke the failure
+ # cleanup.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(cleanup_invocations, ['cleanup called'])
+
+ # If done is announced again, we should not invoke the cleanup again
+ # because done has already been announced and thus the cleanup has
+        # been run as well.
+ self.transfer_coordinator.announce_done()
+ self.assertEqual(cleanup_invocations, ['cleanup called'])
+
+
+class TestBoundedExecutor(unittest.TestCase):
+ def setUp(self):
+ self.coordinator = TransferCoordinator()
+ self.tag_semaphores = {}
+ self.executor = self.get_executor()
+
+ def get_executor(self, max_size=1, max_num_threads=1):
+ return BoundedExecutor(max_size, max_num_threads, self.tag_semaphores)
+
+ def get_task(self, task_cls, main_kwargs=None):
+ return task_cls(self.coordinator, main_kwargs=main_kwargs)
+
+ def get_sleep_task(self, sleep_time=0.01):
+ return self.get_task(SleepTask, main_kwargs={'sleep_time': sleep_time})
+
+ def add_semaphore(self, task_tag, count):
+ self.tag_semaphores[task_tag] = TaskSemaphore(count)
+
+ def assert_submit_would_block(self, task, tag=None):
+ with self.assertRaises(NoResourcesAvailable):
+ self.executor.submit(task, tag=tag, block=False)
+
+ def assert_submit_would_not_block(self, task, tag=None, **kwargs):
+ try:
+ self.executor.submit(task, tag=tag, block=False)
+ except NoResourcesAvailable:
+ self.fail(
+ 'Task {} should not have been blocked. Caused by:\n{}'.format(
+ task, traceback.format_exc()
+ )
+ )
+
+ def add_done_callback_to_future(self, future, fn, *args, **kwargs):
+ callback_for_future = FunctionContainer(fn, *args, **kwargs)
+ future.add_done_callback(callback_for_future)
+
+ def test_submit_single_task(self):
+ # Ensure we can submit a task to the executor
+ task = self.get_task(ReturnFooTask)
+ future = self.executor.submit(task)
+
+ # Ensure what we get back is a Future
+ self.assertIsInstance(future, ExecutorFuture)
+ # Ensure the callable got executed.
+ self.assertEqual(future.result(), 'foo')
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_executor_blocks_on_full_capacity(self):
+ first_task = self.get_sleep_task()
+ second_task = self.get_sleep_task()
+ self.executor.submit(first_task)
+ # The first task should be sleeping for a substantial period of
+ # time such that on the submission of the second task, it will
+ # raise an error saying that it cannot be submitted as the max
+ # capacity of the semaphore is one.
+ self.assert_submit_would_block(second_task)
+
+ def test_executor_clears_capacity_on_done_tasks(self):
+ first_task = self.get_sleep_task()
+ second_task = self.get_task(ReturnFooTask)
+
+ # Submit a task.
+ future = self.executor.submit(first_task)
+
+ # Submit a new task when the first task finishes. This should not get
+ # blocked because the first task should have finished clearing up
+ # capacity.
+ self.add_done_callback_to_future(
+ future, self.assert_submit_would_not_block, second_task
+ )
+
+ # Wait for it to complete.
+ self.executor.shutdown()
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_would_not_block_when_full_capacity_in_other_semaphore(self):
+ first_task = self.get_sleep_task()
+
+        # Now let's create a new task with a tag so that it uses a different
+        # semaphore.
+ task_tag = 'other'
+ other_task = self.get_sleep_task()
+ self.add_semaphore(task_tag, 1)
+
+ # Submit the normal first task
+ self.executor.submit(first_task)
+
+        # Even though the first task should be sleeping for a substantial
+        # period of time, the submission of the second task should not
+        # raise an error because it should use a different semaphore.
+ self.assert_submit_would_not_block(other_task, task_tag)
+
+        # Another submission of the other task, though, should raise
+        # an exception as the capacity for that tag is one.
+ self.assert_submit_would_block(other_task, task_tag)
+
+ def test_shutdown(self):
+ slow_task = self.get_sleep_task()
+ future = self.executor.submit(slow_task)
+ self.executor.shutdown()
+ # Ensure that the shutdown waits until the task is done
+ self.assertTrue(future.done())
+
+ @unittest.skipIf(
+ os.environ.get('USE_SERIAL_EXECUTOR'),
+ "Not supported with serial executor tests",
+ )
+ def test_shutdown_no_wait(self):
+ slow_task = self.get_sleep_task()
+ future = self.executor.submit(slow_task)
+ self.executor.shutdown(False)
+ # Ensure that the shutdown returns immediately even if the task is
+        # not done, which it should not be because it is slow.
+ self.assertFalse(future.done())
+
+ def test_replace_underlying_executor(self):
+ mocked_executor_cls = mock.Mock(BaseExecutor)
+ executor = BoundedExecutor(10, 1, {}, mocked_executor_cls)
+ executor.submit(self.get_task(ReturnFooTask))
+ self.assertTrue(mocked_executor_cls.return_value.submit.called)
+
+
+class TestExecutorFuture(unittest.TestCase):
+ def test_result(self):
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ self.assertEqual(wrapped_future.result(), (('foo',), {'biz': 'baz'}))
+
+ def test_done(self):
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ self.assertTrue(wrapped_future.done())
+
+ def test_add_done_callback(self):
+ done_callbacks = []
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ future = executor.submit(return_call_args, 'foo', biz='baz')
+ wrapped_future = ExecutorFuture(future)
+ wrapped_future.add_done_callback(
+ FunctionContainer(done_callbacks.append, 'called')
+ )
+ self.assertEqual(done_callbacks, ['called'])
+
+
+class TestNonThreadedExecutor(unittest.TestCase):
+ def test_submit(self):
+ executor = NonThreadedExecutor()
+ future = executor.submit(return_call_args, 1, 2, foo='bar')
+ self.assertIsInstance(future, NonThreadedExecutorFuture)
+ self.assertEqual(future.result(), ((1, 2), {'foo': 'bar'}))
+
+ def test_submit_with_exception(self):
+ executor = NonThreadedExecutor()
+ future = executor.submit(raise_exception, RuntimeError())
+ self.assertIsInstance(future, NonThreadedExecutorFuture)
+ with self.assertRaises(RuntimeError):
+ future.result()
+
+ def test_submit_with_exception_and_captures_info(self):
+ exception = ValueError('message')
+ tb = get_exc_info(exception)[2]
+ future = NonThreadedExecutor().submit(raise_exception, exception)
+ try:
+ future.result()
+ # An exception should have been raised
+ self.fail('Future should have raised a ValueError')
+ except ValueError:
+ actual_tb = sys.exc_info()[2]
+ last_frame = traceback.extract_tb(actual_tb)[-1]
+ last_expected_frame = traceback.extract_tb(tb)[-1]
+ self.assertEqual(last_frame, last_expected_frame)
+
+
+class TestNonThreadedExecutorFuture(unittest.TestCase):
+ def setUp(self):
+ self.future = NonThreadedExecutorFuture()
+
+ def test_done_starts_false(self):
+ self.assertFalse(self.future.done())
+
+ def test_done_after_setting_result(self):
+ self.future.set_result('result')
+ self.assertTrue(self.future.done())
+
+ def test_done_after_setting_exception(self):
+ self.future.set_exception_info(Exception(), None)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ self.future.set_result('result')
+ self.assertEqual(self.future.result(), 'result')
+
+ def test_exception_result(self):
+ exception = ValueError('message')
+ self.future.set_exception_info(exception, None)
+ with self.assertRaisesRegex(ValueError, 'message'):
+ self.future.result()
+
+ def test_exception_result_doesnt_modify_last_frame(self):
+ exception = ValueError('message')
+ tb = get_exc_info(exception)[2]
+ self.future.set_exception_info(exception, tb)
+ try:
+ self.future.result()
+ # An exception should have been raised
+ self.fail()
+ except ValueError:
+ actual_tb = sys.exc_info()[2]
+ last_frame = traceback.extract_tb(actual_tb)[-1]
+ last_expected_frame = traceback.extract_tb(tb)[-1]
+ self.assertEqual(last_frame, last_expected_frame)
+
+ def test_done_callback(self):
+ done_futures = []
+ self.future.add_done_callback(done_futures.append)
+ self.assertEqual(done_futures, [])
+ self.future.set_result('result')
+ self.assertEqual(done_futures, [self.future])
+
+ def test_done_callback_after_done(self):
+ self.future.set_result('result')
+ done_futures = []
+ self.future.add_done_callback(done_futures.append)
+ self.assertEqual(done_futures, [self.future])
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_manager.py b/contrib/python/s3transfer/py3/tests/unit/test_manager.py
index 2676357927..fc3caa843f 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_manager.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_manager.py
@@ -1,143 +1,143 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import time
-from concurrent.futures import ThreadPoolExecutor
-
-from s3transfer.exceptions import CancelledError, FatalError
-from s3transfer.futures import TransferCoordinator
-from s3transfer.manager import TransferConfig, TransferCoordinatorController
-from __tests__ import TransferCoordinatorWithInterrupt, unittest
-
-
-class FutureResultException(Exception):
- pass
-
-
-class TestTransferConfig(unittest.TestCase):
- def test_exception_on_zero_attr_value(self):
- with self.assertRaises(ValueError):
- TransferConfig(max_request_queue_size=0)
-
-
-class TestTransferCoordinatorController(unittest.TestCase):
- def setUp(self):
- self.coordinator_controller = TransferCoordinatorController()
-
- def sleep_then_announce_done(self, transfer_coordinator, sleep_time):
- time.sleep(sleep_time)
- transfer_coordinator.set_result('done')
- transfer_coordinator.announce_done()
-
- def assert_coordinator_is_cancelled(self, transfer_coordinator):
- self.assertEqual(transfer_coordinator.status, 'cancelled')
-
- def test_add_transfer_coordinator(self):
- transfer_coordinator = TransferCoordinator()
- # Add the transfer coordinator
- self.coordinator_controller.add_transfer_coordinator(
- transfer_coordinator
- )
-        # Ensure that it is tracked.
- self.assertEqual(
- self.coordinator_controller.tracked_transfer_coordinators,
- {transfer_coordinator},
- )
-
- def test_remove_transfer_coordinator(self):
- transfer_coordinator = TransferCoordinator()
- # Add the coordinator
- self.coordinator_controller.add_transfer_coordinator(
- transfer_coordinator
- )
- # Now remove the coordinator
- self.coordinator_controller.remove_transfer_coordinator(
- transfer_coordinator
- )
- # Make sure that it is no longer getting tracked.
- self.assertEqual(
- self.coordinator_controller.tracked_transfer_coordinators, set()
- )
-
- def test_cancel(self):
- transfer_coordinator = TransferCoordinator()
- # Add the transfer coordinator
- self.coordinator_controller.add_transfer_coordinator(
- transfer_coordinator
- )
- # Cancel with the canceler
- self.coordinator_controller.cancel()
- # Check that coordinator got canceled
- self.assert_coordinator_is_cancelled(transfer_coordinator)
-
- def test_cancel_with_message(self):
- message = 'my cancel message'
- transfer_coordinator = TransferCoordinator()
- self.coordinator_controller.add_transfer_coordinator(
- transfer_coordinator
- )
- self.coordinator_controller.cancel(message)
- transfer_coordinator.announce_done()
- with self.assertRaisesRegex(CancelledError, message):
- transfer_coordinator.result()
-
- def test_cancel_with_provided_exception(self):
- message = 'my cancel message'
- transfer_coordinator = TransferCoordinator()
- self.coordinator_controller.add_transfer_coordinator(
- transfer_coordinator
- )
- self.coordinator_controller.cancel(message, exc_type=FatalError)
- transfer_coordinator.announce_done()
- with self.assertRaisesRegex(FatalError, message):
- transfer_coordinator.result()
-
- def test_wait_for_done_transfer_coordinators(self):
- # Create a coordinator and add it to the canceler
- transfer_coordinator = TransferCoordinator()
- self.coordinator_controller.add_transfer_coordinator(
- transfer_coordinator
- )
-
- sleep_time = 0.02
- with ThreadPoolExecutor(max_workers=1) as executor:
- # In a separate thread sleep and then set the transfer coordinator
- # to done after sleeping.
- start_time = time.time()
- executor.submit(
- self.sleep_then_announce_done, transfer_coordinator, sleep_time
- )
- # Now call wait to wait for the transfer coordinator to be done.
- self.coordinator_controller.wait()
- end_time = time.time()
- wait_time = end_time - start_time
- # The time waited should not be less than the time it took to sleep in
- # the separate thread because the wait ending should be dependent on
- # the sleeping thread announcing that the transfer coordinator is done.
- self.assertTrue(sleep_time <= wait_time)
-
- def test_wait_does_not_propogate_exceptions_from_result(self):
- transfer_coordinator = TransferCoordinator()
- transfer_coordinator.set_exception(FutureResultException())
- transfer_coordinator.announce_done()
- try:
- self.coordinator_controller.wait()
- except FutureResultException as e:
- self.fail('%s should not have been raised.' % e)
-
- def test_wait_can_be_interrupted(self):
- inject_interrupt_coordinator = TransferCoordinatorWithInterrupt()
- self.coordinator_controller.add_transfer_coordinator(
- inject_interrupt_coordinator
- )
- with self.assertRaises(KeyboardInterrupt):
- self.coordinator_controller.wait()
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+from s3transfer.exceptions import CancelledError, FatalError
+from s3transfer.futures import TransferCoordinator
+from s3transfer.manager import TransferConfig, TransferCoordinatorController
+from __tests__ import TransferCoordinatorWithInterrupt, unittest
+
+
+class FutureResultException(Exception):
+ pass
+
+
+class TestTransferConfig(unittest.TestCase):
+ def test_exception_on_zero_attr_value(self):
+ with self.assertRaises(ValueError):
+ TransferConfig(max_request_queue_size=0)
+
+
+class TestTransferCoordinatorController(unittest.TestCase):
+ def setUp(self):
+ self.coordinator_controller = TransferCoordinatorController()
+
+ def sleep_then_announce_done(self, transfer_coordinator, sleep_time):
+ time.sleep(sleep_time)
+ transfer_coordinator.set_result('done')
+ transfer_coordinator.announce_done()
+
+ def assert_coordinator_is_cancelled(self, transfer_coordinator):
+ self.assertEqual(transfer_coordinator.status, 'cancelled')
+
+ def test_add_transfer_coordinator(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the transfer coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+        # Ensure that it is tracked.
+ self.assertEqual(
+ self.coordinator_controller.tracked_transfer_coordinators,
+ {transfer_coordinator},
+ )
+
+ def test_remove_transfer_coordinator(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Now remove the coordinator
+ self.coordinator_controller.remove_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Make sure that it is no longer getting tracked.
+ self.assertEqual(
+ self.coordinator_controller.tracked_transfer_coordinators, set()
+ )
+
+ def test_cancel(self):
+ transfer_coordinator = TransferCoordinator()
+ # Add the transfer coordinator
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ # Cancel with the canceler
+ self.coordinator_controller.cancel()
+ # Check that coordinator got canceled
+ self.assert_coordinator_is_cancelled(transfer_coordinator)
+
+ def test_cancel_with_message(self):
+ message = 'my cancel message'
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ self.coordinator_controller.cancel(message)
+ transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(CancelledError, message):
+ transfer_coordinator.result()
+
+ def test_cancel_with_provided_exception(self):
+ message = 'my cancel message'
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+ self.coordinator_controller.cancel(message, exc_type=FatalError)
+ transfer_coordinator.announce_done()
+ with self.assertRaisesRegex(FatalError, message):
+ transfer_coordinator.result()
+
+ def test_wait_for_done_transfer_coordinators(self):
+ # Create a coordinator and add it to the canceler
+ transfer_coordinator = TransferCoordinator()
+ self.coordinator_controller.add_transfer_coordinator(
+ transfer_coordinator
+ )
+
+ sleep_time = 0.02
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ # In a separate thread sleep and then set the transfer coordinator
+ # to done after sleeping.
+ start_time = time.time()
+ executor.submit(
+ self.sleep_then_announce_done, transfer_coordinator, sleep_time
+ )
+ # Now call wait to wait for the transfer coordinator to be done.
+ self.coordinator_controller.wait()
+ end_time = time.time()
+ wait_time = end_time - start_time
+ # The time waited should not be less than the time it took to sleep in
+ # the separate thread because the wait ending should be dependent on
+ # the sleeping thread announcing that the transfer coordinator is done.
+ self.assertTrue(sleep_time <= wait_time)
+
+ def test_wait_does_not_propogate_exceptions_from_result(self):
+ transfer_coordinator = TransferCoordinator()
+ transfer_coordinator.set_exception(FutureResultException())
+ transfer_coordinator.announce_done()
+ try:
+ self.coordinator_controller.wait()
+ except FutureResultException as e:
+ self.fail('%s should not have been raised.' % e)
+
+ def test_wait_can_be_interrupted(self):
+ inject_interrupt_coordinator = TransferCoordinatorWithInterrupt()
+ self.coordinator_controller.add_transfer_coordinator(
+ inject_interrupt_coordinator
+ )
+ with self.assertRaises(KeyboardInterrupt):
+ self.coordinator_controller.wait()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_processpool.py b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py
index 9a6f63b638..d77b5e0240 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_processpool.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_processpool.py
@@ -1,728 +1,728 @@
-# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import os
-import queue
-import signal
-import threading
-import time
-from io import BytesIO
-
-from botocore.client import BaseClient
-from botocore.config import Config
-from botocore.exceptions import ClientError, ReadTimeoutError
-
-from s3transfer.constants import PROCESS_USER_AGENT
-from s3transfer.exceptions import CancelledError, RetriesExceededError
-from s3transfer.processpool import (
- SHUTDOWN_SIGNAL,
- ClientFactory,
- DownloadFileRequest,
- GetObjectJob,
- GetObjectSubmitter,
- GetObjectWorker,
- ProcessPoolDownloader,
- ProcessPoolTransferFuture,
- ProcessPoolTransferMeta,
- ProcessTransferConfig,
- TransferMonitor,
- TransferState,
- ignore_ctrl_c,
-)
-from s3transfer.utils import CallArgs, OSUtils
-from __tests__ import (
- FileCreator,
- StreamWithError,
- StubbedClientTest,
- mock,
- skip_if_windows,
- unittest,
-)
-
-
-class RenameFailingOSUtils(OSUtils):
- def __init__(self, exception):
- self.exception = exception
-
- def rename_file(self, current_filename, new_filename):
- raise self.exception
-
-
-class TestIgnoreCtrlC(unittest.TestCase):
- @skip_if_windows('os.kill() with SIGINT not supported on Windows')
- def test_ignore_ctrl_c(self):
- with ignore_ctrl_c():
- try:
- os.kill(os.getpid(), signal.SIGINT)
- except KeyboardInterrupt:
- self.fail(
- 'The ignore_ctrl_c context manager should have '
- 'ignored the KeyboardInterrupt exception'
- )
-
-
-class TestProcessPoolDownloader(unittest.TestCase):
- def test_uses_client_kwargs(self):
- with mock.patch('s3transfer.processpool.ClientFactory') as factory:
- ProcessPoolDownloader(client_kwargs={'region_name': 'myregion'})
- self.assertEqual(
- factory.call_args[0][0], {'region_name': 'myregion'}
- )
-
-
-class TestProcessPoolTransferFuture(unittest.TestCase):
- def setUp(self):
- self.monitor = TransferMonitor()
- self.transfer_id = self.monitor.notify_new_transfer()
- self.meta = ProcessPoolTransferMeta(
- transfer_id=self.transfer_id, call_args=CallArgs()
- )
- self.future = ProcessPoolTransferFuture(
- monitor=self.monitor, meta=self.meta
- )
-
- def test_meta(self):
- self.assertEqual(self.future.meta, self.meta)
-
- def test_done(self):
- self.assertFalse(self.future.done())
- self.monitor.notify_done(self.transfer_id)
- self.assertTrue(self.future.done())
-
- def test_result(self):
- self.monitor.notify_done(self.transfer_id)
- self.assertIsNone(self.future.result())
-
- def test_result_with_exception(self):
- self.monitor.notify_exception(self.transfer_id, RuntimeError())
- self.monitor.notify_done(self.transfer_id)
- with self.assertRaises(RuntimeError):
- self.future.result()
-
- def test_result_with_keyboard_interrupt(self):
- mock_monitor = mock.Mock(TransferMonitor)
- mock_monitor._connect = mock.Mock()
- mock_monitor.poll_for_result.side_effect = KeyboardInterrupt()
- future = ProcessPoolTransferFuture(
- monitor=mock_monitor, meta=self.meta
- )
- with self.assertRaises(KeyboardInterrupt):
- future.result()
- self.assertTrue(mock_monitor._connect.called)
- self.assertTrue(mock_monitor.notify_exception.called)
- call_args = mock_monitor.notify_exception.call_args[0]
- self.assertEqual(call_args[0], self.transfer_id)
- self.assertIsInstance(call_args[1], CancelledError)
-
- def test_cancel(self):
- self.future.cancel()
- self.monitor.notify_done(self.transfer_id)
- with self.assertRaises(CancelledError):
- self.future.result()
-
-
-class TestProcessPoolTransferMeta(unittest.TestCase):
- def test_transfer_id(self):
- meta = ProcessPoolTransferMeta(1, CallArgs())
- self.assertEqual(meta.transfer_id, 1)
-
- def test_call_args(self):
- call_args = CallArgs()
- meta = ProcessPoolTransferMeta(1, call_args)
- self.assertEqual(meta.call_args, call_args)
-
- def test_user_context(self):
- meta = ProcessPoolTransferMeta(1, CallArgs())
- self.assertEqual(meta.user_context, {})
- meta.user_context['mykey'] = 'myvalue'
- self.assertEqual(meta.user_context, {'mykey': 'myvalue'})
-
-
-class TestClientFactory(unittest.TestCase):
- def test_create_client(self):
- client = ClientFactory().create_client()
- self.assertIsInstance(client, BaseClient)
- self.assertEqual(client.meta.service_model.service_name, 's3')
- self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
-
- def test_create_client_with_client_kwargs(self):
- client = ClientFactory({'region_name': 'myregion'}).create_client()
- self.assertEqual(client.meta.region_name, 'myregion')
-
- def test_user_agent_with_config(self):
- client = ClientFactory({'config': Config()}).create_client()
- self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
-
- def test_user_agent_with_existing_user_agent_extra(self):
- config = Config(user_agent_extra='foo/1.0')
- client = ClientFactory({'config': config}).create_client()
- self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
-
- def test_user_agent_with_existing_user_agent(self):
- config = Config(user_agent='foo/1.0')
- client = ClientFactory({'config': config}).create_client()
- self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
-
-
-class TestTransferMonitor(unittest.TestCase):
- def setUp(self):
- self.monitor = TransferMonitor()
- self.transfer_id = self.monitor.notify_new_transfer()
-
- def test_notify_new_transfer_creates_new_state(self):
- monitor = TransferMonitor()
- transfer_id = monitor.notify_new_transfer()
- self.assertFalse(monitor.is_done(transfer_id))
- self.assertIsNone(monitor.get_exception(transfer_id))
-
- def test_notify_new_transfer_increments_transfer_id(self):
- monitor = TransferMonitor()
- self.assertEqual(monitor.notify_new_transfer(), 0)
- self.assertEqual(monitor.notify_new_transfer(), 1)
-
- def test_notify_get_exception(self):
- exception = Exception()
- self.monitor.notify_exception(self.transfer_id, exception)
- self.assertEqual(
- self.monitor.get_exception(self.transfer_id), exception
- )
-
- def test_get_no_exception(self):
- self.assertIsNone(self.monitor.get_exception(self.transfer_id))
-
- def test_notify_jobs(self):
- self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2)
- self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1)
- self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 0)
-
- def test_notify_jobs_for_multiple_transfers(self):
- self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2)
- other_transfer_id = self.monitor.notify_new_transfer()
- self.monitor.notify_expected_jobs_to_complete(other_transfer_id, 2)
- self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1)
- self.assertEqual(
- self.monitor.notify_job_complete(other_transfer_id), 1
- )
-
- def test_done(self):
- self.assertFalse(self.monitor.is_done(self.transfer_id))
- self.monitor.notify_done(self.transfer_id)
- self.assertTrue(self.monitor.is_done(self.transfer_id))
-
- def test_poll_for_result(self):
- self.monitor.notify_done(self.transfer_id)
- self.assertIsNone(self.monitor.poll_for_result(self.transfer_id))
-
- def test_poll_for_result_raises_error(self):
- self.monitor.notify_exception(self.transfer_id, RuntimeError())
- self.monitor.notify_done(self.transfer_id)
- with self.assertRaises(RuntimeError):
- self.monitor.poll_for_result(self.transfer_id)
-
- def test_poll_for_result_waits_till_done(self):
- event_order = []
-
- def sleep_then_notify_done():
- time.sleep(0.05)
- event_order.append('notify_done')
- self.monitor.notify_done(self.transfer_id)
-
- t = threading.Thread(target=sleep_then_notify_done)
- t.start()
-
- self.monitor.poll_for_result(self.transfer_id)
- event_order.append('done_polling')
- self.assertEqual(event_order, ['notify_done', 'done_polling'])
-
- def test_notify_cancel_all_in_progress(self):
- monitor = TransferMonitor()
- transfer_ids = []
- for _ in range(10):
- transfer_ids.append(monitor.notify_new_transfer())
- monitor.notify_cancel_all_in_progress()
- for transfer_id in transfer_ids:
- self.assertIsInstance(
- monitor.get_exception(transfer_id), CancelledError
- )
- # Cancelling a transfer does not mean it is done as there may
- # be cleanup work left to do.
- self.assertFalse(monitor.is_done(transfer_id))
-
- def test_notify_cancel_does_not_affect_done_transfers(self):
- self.monitor.notify_done(self.transfer_id)
- self.monitor.notify_cancel_all_in_progress()
- self.assertTrue(self.monitor.is_done(self.transfer_id))
- self.assertIsNone(self.monitor.get_exception(self.transfer_id))
-
-
-class TestTransferState(unittest.TestCase):
- def setUp(self):
- self.state = TransferState()
-
- def test_done(self):
- self.assertFalse(self.state.done)
- self.state.set_done()
- self.assertTrue(self.state.done)
-
- def test_waits_till_done_is_set(self):
- event_order = []
-
- def sleep_then_set_done():
- time.sleep(0.05)
- event_order.append('set_done')
- self.state.set_done()
-
- t = threading.Thread(target=sleep_then_set_done)
- t.start()
-
- self.state.wait_till_done()
- event_order.append('done_waiting')
- self.assertEqual(event_order, ['set_done', 'done_waiting'])
-
- def test_exception(self):
- exception = RuntimeError()
- self.state.exception = exception
- self.assertEqual(self.state.exception, exception)
-
- def test_jobs_to_complete(self):
- self.state.jobs_to_complete = 5
- self.assertEqual(self.state.jobs_to_complete, 5)
-
- def test_decrement_jobs_to_complete(self):
- self.state.jobs_to_complete = 5
- self.assertEqual(self.state.decrement_jobs_to_complete(), 4)
-
-
-class TestGetObjectSubmitter(StubbedClientTest):
- def setUp(self):
- super().setUp()
- self.transfer_config = ProcessTransferConfig()
- self.client_factory = mock.Mock(ClientFactory)
- self.client_factory.create_client.return_value = self.client
- self.transfer_monitor = TransferMonitor()
- self.osutil = mock.Mock(OSUtils)
- self.download_request_queue = queue.Queue()
- self.worker_queue = queue.Queue()
- self.submitter = GetObjectSubmitter(
- transfer_config=self.transfer_config,
- client_factory=self.client_factory,
- transfer_monitor=self.transfer_monitor,
- osutil=self.osutil,
- download_request_queue=self.download_request_queue,
- worker_queue=self.worker_queue,
- )
- self.transfer_id = self.transfer_monitor.notify_new_transfer()
- self.bucket = 'bucket'
- self.key = 'key'
- self.filename = 'myfile'
- self.temp_filename = 'myfile.temp'
- self.osutil.get_temp_filename.return_value = self.temp_filename
- self.extra_args = {}
- self.expected_size = None
-
- def add_download_file_request(self, **override_kwargs):
- kwargs = {
- 'transfer_id': self.transfer_id,
- 'bucket': self.bucket,
- 'key': self.key,
- 'filename': self.filename,
- 'extra_args': self.extra_args,
- 'expected_size': self.expected_size,
- }
- kwargs.update(override_kwargs)
- self.download_request_queue.put(DownloadFileRequest(**kwargs))
-
- def add_shutdown(self):
- self.download_request_queue.put(SHUTDOWN_SIGNAL)
-
- def assert_submitted_get_object_jobs(self, expected_jobs):
- actual_jobs = []
- while not self.worker_queue.empty():
- actual_jobs.append(self.worker_queue.get())
- self.assertEqual(actual_jobs, expected_jobs)
-
- def test_run_for_non_ranged_download(self):
- self.add_download_file_request(expected_size=1)
- self.add_shutdown()
- self.submitter.run()
- self.osutil.allocate.assert_called_with(self.temp_filename, 1)
- self.assert_submitted_get_object_jobs(
- [
- GetObjectJob(
- transfer_id=self.transfer_id,
- bucket=self.bucket,
- key=self.key,
- temp_filename=self.temp_filename,
- offset=0,
- extra_args={},
- filename=self.filename,
- )
- ]
- )
-
- def test_run_for_ranged_download(self):
- self.transfer_config.multipart_chunksize = 2
- self.transfer_config.multipart_threshold = 4
- self.add_download_file_request(expected_size=4)
- self.add_shutdown()
- self.submitter.run()
- self.osutil.allocate.assert_called_with(self.temp_filename, 4)
- self.assert_submitted_get_object_jobs(
- [
- GetObjectJob(
- transfer_id=self.transfer_id,
- bucket=self.bucket,
- key=self.key,
- temp_filename=self.temp_filename,
- offset=0,
- extra_args={'Range': 'bytes=0-1'},
- filename=self.filename,
- ),
- GetObjectJob(
- transfer_id=self.transfer_id,
- bucket=self.bucket,
- key=self.key,
- temp_filename=self.temp_filename,
- offset=2,
- extra_args={'Range': 'bytes=2-'},
- filename=self.filename,
- ),
- ]
- )
-
- def test_run_when_expected_size_not_provided(self):
- self.stubber.add_response(
- 'head_object',
- {'ContentLength': 1},
- expected_params={'Bucket': self.bucket, 'Key': self.key},
- )
- self.add_download_file_request(expected_size=None)
- self.add_shutdown()
- self.submitter.run()
- self.stubber.assert_no_pending_responses()
- self.osutil.allocate.assert_called_with(self.temp_filename, 1)
- self.assert_submitted_get_object_jobs(
- [
- GetObjectJob(
- transfer_id=self.transfer_id,
- bucket=self.bucket,
- key=self.key,
- temp_filename=self.temp_filename,
- offset=0,
- extra_args={},
- filename=self.filename,
- )
- ]
- )
-
- def test_run_with_extra_args(self):
- self.stubber.add_response(
- 'head_object',
- {'ContentLength': 1},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'VersionId': 'versionid',
- },
- )
- self.add_download_file_request(
- extra_args={'VersionId': 'versionid'}, expected_size=None
- )
- self.add_shutdown()
- self.submitter.run()
- self.stubber.assert_no_pending_responses()
- self.osutil.allocate.assert_called_with(self.temp_filename, 1)
- self.assert_submitted_get_object_jobs(
- [
- GetObjectJob(
- transfer_id=self.transfer_id,
- bucket=self.bucket,
- key=self.key,
- temp_filename=self.temp_filename,
- offset=0,
- extra_args={'VersionId': 'versionid'},
- filename=self.filename,
- )
- ]
- )
-
- def test_run_with_exception(self):
- self.stubber.add_client_error('head_object', 'NoSuchKey', 404)
- self.add_download_file_request(expected_size=None)
- self.add_shutdown()
- self.submitter.run()
- self.stubber.assert_no_pending_responses()
- self.assert_submitted_get_object_jobs([])
- self.assertIsInstance(
- self.transfer_monitor.get_exception(self.transfer_id), ClientError
- )
-
- def test_run_with_error_in_allocating_temp_file(self):
- self.osutil.allocate.side_effect = OSError()
- self.add_download_file_request(expected_size=1)
- self.add_shutdown()
- self.submitter.run()
- self.assert_submitted_get_object_jobs([])
- self.assertIsInstance(
- self.transfer_monitor.get_exception(self.transfer_id), OSError
- )
-
- @skip_if_windows('os.kill() with SIGINT not supported on Windows')
- def test_submitter_cannot_be_killed(self):
- self.add_download_file_request(expected_size=None)
- self.add_shutdown()
-
- def raise_ctrl_c(**kwargs):
- os.kill(os.getpid(), signal.SIGINT)
-
- mock_client = mock.Mock()
- mock_client.head_object = raise_ctrl_c
- self.client_factory.create_client.return_value = mock_client
-
- try:
- self.submitter.run()
- except KeyboardInterrupt:
- self.fail(
- 'The submitter should have not been killed by the '
- 'KeyboardInterrupt'
- )
-
-
-class TestGetObjectWorker(StubbedClientTest):
- def setUp(self):
- super().setUp()
- self.files = FileCreator()
- self.queue = queue.Queue()
- self.client_factory = mock.Mock(ClientFactory)
- self.client_factory.create_client.return_value = self.client
- self.transfer_monitor = TransferMonitor()
- self.osutil = OSUtils()
- self.worker = GetObjectWorker(
- queue=self.queue,
- client_factory=self.client_factory,
- transfer_monitor=self.transfer_monitor,
- osutil=self.osutil,
- )
- self.transfer_id = self.transfer_monitor.notify_new_transfer()
- self.bucket = 'bucket'
- self.key = 'key'
- self.remote_contents = b'my content'
- self.temp_filename = self.files.create_file('tempfile', '')
- self.extra_args = {}
- self.offset = 0
- self.final_filename = self.files.full_path('final_filename')
- self.stream = BytesIO(self.remote_contents)
- self.transfer_monitor.notify_expected_jobs_to_complete(
- self.transfer_id, 1000
- )
-
- def tearDown(self):
- super().tearDown()
- self.files.remove_all()
-
- def add_get_object_job(self, **override_kwargs):
- kwargs = {
- 'transfer_id': self.transfer_id,
- 'bucket': self.bucket,
- 'key': self.key,
- 'temp_filename': self.temp_filename,
- 'extra_args': self.extra_args,
- 'offset': self.offset,
- 'filename': self.final_filename,
- }
- kwargs.update(override_kwargs)
- self.queue.put(GetObjectJob(**kwargs))
-
- def add_shutdown(self):
- self.queue.put(SHUTDOWN_SIGNAL)
-
- def add_stubbed_get_object_response(self, body=None, expected_params=None):
- if body is None:
- body = self.stream
- get_object_response = {'Body': body}
-
- if expected_params is None:
- expected_params = {'Bucket': self.bucket, 'Key': self.key}
-
- self.stubber.add_response(
- 'get_object', get_object_response, expected_params
- )
-
- def assert_contents(self, filename, contents):
- self.assertTrue(os.path.exists(filename))
- with open(filename, 'rb') as f:
- self.assertEqual(f.read(), contents)
-
- def assert_does_not_exist(self, filename):
- self.assertFalse(os.path.exists(filename))
-
- def test_run_is_final_job(self):
- self.add_get_object_job()
- self.add_shutdown()
- self.add_stubbed_get_object_response()
- self.transfer_monitor.notify_expected_jobs_to_complete(
- self.transfer_id, 1
- )
-
- self.worker.run()
- self.stubber.assert_no_pending_responses()
- self.assert_does_not_exist(self.temp_filename)
- self.assert_contents(self.final_filename, self.remote_contents)
-
- def test_run_jobs_is_not_final_job(self):
- self.add_get_object_job()
- self.add_shutdown()
- self.add_stubbed_get_object_response()
- self.transfer_monitor.notify_expected_jobs_to_complete(
- self.transfer_id, 1000
- )
-
- self.worker.run()
- self.stubber.assert_no_pending_responses()
- self.assert_contents(self.temp_filename, self.remote_contents)
- self.assert_does_not_exist(self.final_filename)
-
- def test_run_with_extra_args(self):
- self.add_get_object_job(extra_args={'VersionId': 'versionid'})
- self.add_shutdown()
- self.add_stubbed_get_object_response(
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'VersionId': 'versionid',
- }
- )
-
- self.worker.run()
- self.stubber.assert_no_pending_responses()
-
- def test_run_with_offset(self):
- offset = 1
- self.add_get_object_job(offset=offset)
- self.add_shutdown()
- self.add_stubbed_get_object_response()
-
- self.worker.run()
- with open(self.temp_filename, 'rb') as f:
- f.seek(offset)
- self.assertEqual(f.read(), self.remote_contents)
-
- def test_run_error_in_get_object(self):
- self.add_get_object_job()
- self.add_shutdown()
- self.stubber.add_client_error('get_object', 'NoSuchKey', 404)
- self.add_stubbed_get_object_response()
-
- self.worker.run()
- self.assertIsInstance(
- self.transfer_monitor.get_exception(self.transfer_id), ClientError
- )
-
- def test_run_does_retries_for_get_object(self):
- self.add_get_object_job()
- self.add_shutdown()
- self.add_stubbed_get_object_response(
- body=StreamWithError(
- self.stream, ReadTimeoutError(endpoint_url='')
- )
- )
- self.add_stubbed_get_object_response()
-
- self.worker.run()
- self.stubber.assert_no_pending_responses()
- self.assert_contents(self.temp_filename, self.remote_contents)
-
- def test_run_can_exhaust_retries_for_get_object(self):
- self.add_get_object_job()
- self.add_shutdown()
- # 5 is the current setting for max number of GetObject attempts
- for _ in range(5):
- self.add_stubbed_get_object_response(
- body=StreamWithError(
- self.stream, ReadTimeoutError(endpoint_url='')
- )
- )
-
- self.worker.run()
- self.stubber.assert_no_pending_responses()
- self.assertIsInstance(
- self.transfer_monitor.get_exception(self.transfer_id),
- RetriesExceededError,
- )
-
- def test_run_skips_get_object_on_previous_exception(self):
- self.add_get_object_job()
- self.add_shutdown()
- self.transfer_monitor.notify_exception(self.transfer_id, Exception())
-
- self.worker.run()
- # Note we did not add a stubbed response for get_object
- self.stubber.assert_no_pending_responses()
-
- def test_run_final_job_removes_file_on_previous_exception(self):
- self.add_get_object_job()
- self.add_shutdown()
- self.transfer_monitor.notify_exception(self.transfer_id, Exception())
- self.transfer_monitor.notify_expected_jobs_to_complete(
- self.transfer_id, 1
- )
-
- self.worker.run()
- self.stubber.assert_no_pending_responses()
- self.assert_does_not_exist(self.temp_filename)
- self.assert_does_not_exist(self.final_filename)
-
- def test_run_fails_to_rename_file(self):
- exception = OSError()
- osutil = RenameFailingOSUtils(exception)
- self.worker = GetObjectWorker(
- queue=self.queue,
- client_factory=self.client_factory,
- transfer_monitor=self.transfer_monitor,
- osutil=osutil,
- )
- self.add_get_object_job()
- self.add_shutdown()
- self.add_stubbed_get_object_response()
- self.transfer_monitor.notify_expected_jobs_to_complete(
- self.transfer_id, 1
- )
-
- self.worker.run()
- self.assertEqual(
- self.transfer_monitor.get_exception(self.transfer_id), exception
- )
- self.assert_does_not_exist(self.temp_filename)
- self.assert_does_not_exist(self.final_filename)
-
- @skip_if_windows('os.kill() with SIGINT not supported on Windows')
- def test_worker_cannot_be_killed(self):
- self.add_get_object_job()
- self.add_shutdown()
- self.transfer_monitor.notify_expected_jobs_to_complete(
- self.transfer_id, 1
- )
-
- def raise_ctrl_c(**kwargs):
- os.kill(os.getpid(), signal.SIGINT)
-
- mock_client = mock.Mock()
- mock_client.get_object = raise_ctrl_c
- self.client_factory.create_client.return_value = mock_client
-
- try:
- self.worker.run()
- except KeyboardInterrupt:
- self.fail(
- 'The worker should not have been killed by the '
- 'KeyboardInterrupt'
- )
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import queue
+import signal
+import threading
+import time
+from io import BytesIO
+
+from botocore.client import BaseClient
+from botocore.config import Config
+from botocore.exceptions import ClientError, ReadTimeoutError
+
+from s3transfer.constants import PROCESS_USER_AGENT
+from s3transfer.exceptions import CancelledError, RetriesExceededError
+from s3transfer.processpool import (
+ SHUTDOWN_SIGNAL,
+ ClientFactory,
+ DownloadFileRequest,
+ GetObjectJob,
+ GetObjectSubmitter,
+ GetObjectWorker,
+ ProcessPoolDownloader,
+ ProcessPoolTransferFuture,
+ ProcessPoolTransferMeta,
+ ProcessTransferConfig,
+ TransferMonitor,
+ TransferState,
+ ignore_ctrl_c,
+)
+from s3transfer.utils import CallArgs, OSUtils
+from __tests__ import (
+ FileCreator,
+ StreamWithError,
+ StubbedClientTest,
+ mock,
+ skip_if_windows,
+ unittest,
+)
+
+
+class RenameFailingOSUtils(OSUtils):
+ def __init__(self, exception):
+ self.exception = exception
+
+ def rename_file(self, current_filename, new_filename):
+ raise self.exception
+
+
+class TestIgnoreCtrlC(unittest.TestCase):
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_ignore_ctrl_c(self):
+ with ignore_ctrl_c():
+ try:
+ os.kill(os.getpid(), signal.SIGINT)
+ except KeyboardInterrupt:
+ self.fail(
+ 'The ignore_ctrl_c context manager should have '
+ 'ignored the KeyboardInterrupt exception'
+ )
+
+
+class TestProcessPoolDownloader(unittest.TestCase):
+ def test_uses_client_kwargs(self):
+ with mock.patch('s3transfer.processpool.ClientFactory') as factory:
+ ProcessPoolDownloader(client_kwargs={'region_name': 'myregion'})
+ self.assertEqual(
+ factory.call_args[0][0], {'region_name': 'myregion'}
+ )
+
+
+class TestProcessPoolTransferFuture(unittest.TestCase):
+ def setUp(self):
+ self.monitor = TransferMonitor()
+ self.transfer_id = self.monitor.notify_new_transfer()
+ self.meta = ProcessPoolTransferMeta(
+ transfer_id=self.transfer_id, call_args=CallArgs()
+ )
+ self.future = ProcessPoolTransferFuture(
+ monitor=self.monitor, meta=self.meta
+ )
+
+ def test_meta(self):
+ self.assertEqual(self.future.meta, self.meta)
+
+ def test_done(self):
+ self.assertFalse(self.future.done())
+ self.monitor.notify_done(self.transfer_id)
+ self.assertTrue(self.future.done())
+
+ def test_result(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.assertIsNone(self.future.result())
+
+ def test_result_with_exception(self):
+ self.monitor.notify_exception(self.transfer_id, RuntimeError())
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(RuntimeError):
+ self.future.result()
+
+ def test_result_with_keyboard_interrupt(self):
+ mock_monitor = mock.Mock(TransferMonitor)
+ mock_monitor._connect = mock.Mock()
+ mock_monitor.poll_for_result.side_effect = KeyboardInterrupt()
+ future = ProcessPoolTransferFuture(
+ monitor=mock_monitor, meta=self.meta
+ )
+ with self.assertRaises(KeyboardInterrupt):
+ future.result()
+ self.assertTrue(mock_monitor._connect.called)
+ self.assertTrue(mock_monitor.notify_exception.called)
+ call_args = mock_monitor.notify_exception.call_args[0]
+ self.assertEqual(call_args[0], self.transfer_id)
+ self.assertIsInstance(call_args[1], CancelledError)
+
+ def test_cancel(self):
+ self.future.cancel()
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(CancelledError):
+ self.future.result()
+
+
+class TestProcessPoolTransferMeta(unittest.TestCase):
+ def test_transfer_id(self):
+ meta = ProcessPoolTransferMeta(1, CallArgs())
+ self.assertEqual(meta.transfer_id, 1)
+
+ def test_call_args(self):
+ call_args = CallArgs()
+ meta = ProcessPoolTransferMeta(1, call_args)
+ self.assertEqual(meta.call_args, call_args)
+
+ def test_user_context(self):
+ meta = ProcessPoolTransferMeta(1, CallArgs())
+ self.assertEqual(meta.user_context, {})
+ meta.user_context['mykey'] = 'myvalue'
+ self.assertEqual(meta.user_context, {'mykey': 'myvalue'})
+
+
+class TestClientFactory(unittest.TestCase):
+ def test_create_client(self):
+ client = ClientFactory().create_client()
+ self.assertIsInstance(client, BaseClient)
+ self.assertEqual(client.meta.service_model.service_name, 's3')
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_create_client_with_client_kwargs(self):
+ client = ClientFactory({'region_name': 'myregion'}).create_client()
+ self.assertEqual(client.meta.region_name, 'myregion')
+
+ def test_user_agent_with_config(self):
+ client = ClientFactory({'config': Config()}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_user_agent_with_existing_user_agent_extra(self):
+ config = Config(user_agent_extra='foo/1.0')
+ client = ClientFactory({'config': config}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+ def test_user_agent_with_existing_user_agent(self):
+ config = Config(user_agent='foo/1.0')
+ client = ClientFactory({'config': config}).create_client()
+ self.assertIn(PROCESS_USER_AGENT, client.meta.config.user_agent)
+
+
+class TestTransferMonitor(unittest.TestCase):
+ def setUp(self):
+ self.monitor = TransferMonitor()
+ self.transfer_id = self.monitor.notify_new_transfer()
+
+ def test_notify_new_transfer_creates_new_state(self):
+ monitor = TransferMonitor()
+ transfer_id = monitor.notify_new_transfer()
+ self.assertFalse(monitor.is_done(transfer_id))
+ self.assertIsNone(monitor.get_exception(transfer_id))
+
+ def test_notify_new_transfer_increments_transfer_id(self):
+ monitor = TransferMonitor()
+ self.assertEqual(monitor.notify_new_transfer(), 0)
+ self.assertEqual(monitor.notify_new_transfer(), 1)
+
+ def test_notify_get_exception(self):
+ exception = Exception()
+ self.monitor.notify_exception(self.transfer_id, exception)
+ self.assertEqual(
+ self.monitor.get_exception(self.transfer_id), exception
+ )
+
+ def test_get_no_exception(self):
+ self.assertIsNone(self.monitor.get_exception(self.transfer_id))
+
+ def test_notify_jobs(self):
+ self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 0)
+
+ def test_notify_jobs_for_multiple_transfers(self):
+ self.monitor.notify_expected_jobs_to_complete(self.transfer_id, 2)
+ other_transfer_id = self.monitor.notify_new_transfer()
+ self.monitor.notify_expected_jobs_to_complete(other_transfer_id, 2)
+ self.assertEqual(self.monitor.notify_job_complete(self.transfer_id), 1)
+ self.assertEqual(
+ self.monitor.notify_job_complete(other_transfer_id), 1
+ )
+
+ def test_done(self):
+ self.assertFalse(self.monitor.is_done(self.transfer_id))
+ self.monitor.notify_done(self.transfer_id)
+ self.assertTrue(self.monitor.is_done(self.transfer_id))
+
+ def test_poll_for_result(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.assertIsNone(self.monitor.poll_for_result(self.transfer_id))
+
+ def test_poll_for_result_raises_error(self):
+ self.monitor.notify_exception(self.transfer_id, RuntimeError())
+ self.monitor.notify_done(self.transfer_id)
+ with self.assertRaises(RuntimeError):
+ self.monitor.poll_for_result(self.transfer_id)
+
+ def test_poll_for_result_waits_till_done(self):
+ event_order = []
+
+ def sleep_then_notify_done():
+ time.sleep(0.05)
+ event_order.append('notify_done')
+ self.monitor.notify_done(self.transfer_id)
+
+ t = threading.Thread(target=sleep_then_notify_done)
+ t.start()
+
+ self.monitor.poll_for_result(self.transfer_id)
+ event_order.append('done_polling')
+ self.assertEqual(event_order, ['notify_done', 'done_polling'])
+
+ def test_notify_cancel_all_in_progress(self):
+ monitor = TransferMonitor()
+ transfer_ids = []
+ for _ in range(10):
+ transfer_ids.append(monitor.notify_new_transfer())
+ monitor.notify_cancel_all_in_progress()
+ for transfer_id in transfer_ids:
+ self.assertIsInstance(
+ monitor.get_exception(transfer_id), CancelledError
+ )
+ # Cancelling a transfer does not mean it is done as there may
+ # be cleanup work left to do.
+ self.assertFalse(monitor.is_done(transfer_id))
+
+ def test_notify_cancel_does_not_affect_done_transfers(self):
+ self.monitor.notify_done(self.transfer_id)
+ self.monitor.notify_cancel_all_in_progress()
+ self.assertTrue(self.monitor.is_done(self.transfer_id))
+ self.assertIsNone(self.monitor.get_exception(self.transfer_id))
+
+
+class TestTransferState(unittest.TestCase):
+ def setUp(self):
+ self.state = TransferState()
+
+ def test_done(self):
+ self.assertFalse(self.state.done)
+ self.state.set_done()
+ self.assertTrue(self.state.done)
+
+ def test_waits_till_done_is_set(self):
+ event_order = []
+
+ def sleep_then_set_done():
+ time.sleep(0.05)
+ event_order.append('set_done')
+ self.state.set_done()
+
+ t = threading.Thread(target=sleep_then_set_done)
+ t.start()
+
+ self.state.wait_till_done()
+ event_order.append('done_waiting')
+ self.assertEqual(event_order, ['set_done', 'done_waiting'])
+
+ def test_exception(self):
+ exception = RuntimeError()
+ self.state.exception = exception
+ self.assertEqual(self.state.exception, exception)
+
+ def test_jobs_to_complete(self):
+ self.state.jobs_to_complete = 5
+ self.assertEqual(self.state.jobs_to_complete, 5)
+
+ def test_decrement_jobs_to_complete(self):
+ self.state.jobs_to_complete = 5
+ self.assertEqual(self.state.decrement_jobs_to_complete(), 4)
+
+
+class TestGetObjectSubmitter(StubbedClientTest):
+ def setUp(self):
+ super().setUp()
+ self.transfer_config = ProcessTransferConfig()
+ self.client_factory = mock.Mock(ClientFactory)
+ self.client_factory.create_client.return_value = self.client
+ self.transfer_monitor = TransferMonitor()
+ self.osutil = mock.Mock(OSUtils)
+ self.download_request_queue = queue.Queue()
+ self.worker_queue = queue.Queue()
+ self.submitter = GetObjectSubmitter(
+ transfer_config=self.transfer_config,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=self.osutil,
+ download_request_queue=self.download_request_queue,
+ worker_queue=self.worker_queue,
+ )
+ self.transfer_id = self.transfer_monitor.notify_new_transfer()
+ self.bucket = 'bucket'
+ self.key = 'key'
+ self.filename = 'myfile'
+ self.temp_filename = 'myfile.temp'
+ self.osutil.get_temp_filename.return_value = self.temp_filename
+ self.extra_args = {}
+ self.expected_size = None
+
+ def add_download_file_request(self, **override_kwargs):
+ kwargs = {
+ 'transfer_id': self.transfer_id,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'filename': self.filename,
+ 'extra_args': self.extra_args,
+ 'expected_size': self.expected_size,
+ }
+ kwargs.update(override_kwargs)
+ self.download_request_queue.put(DownloadFileRequest(**kwargs))
+
+ def add_shutdown(self):
+ self.download_request_queue.put(SHUTDOWN_SIGNAL)
+
+ def assert_submitted_get_object_jobs(self, expected_jobs):
+ actual_jobs = []
+ while not self.worker_queue.empty():
+ actual_jobs.append(self.worker_queue.get())
+ self.assertEqual(actual_jobs, expected_jobs)
+
+ def test_run_for_non_ranged_download(self):
+ self.add_download_file_request(expected_size=1)
+ self.add_shutdown()
+ self.submitter.run()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_for_ranged_download(self):
+ self.transfer_config.multipart_chunksize = 2
+ self.transfer_config.multipart_threshold = 4
+ self.add_download_file_request(expected_size=4)
+ self.add_shutdown()
+ self.submitter.run()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 4)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={'Range': 'bytes=0-1'},
+ filename=self.filename,
+ ),
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=2,
+ extra_args={'Range': 'bytes=2-'},
+ filename=self.filename,
+ ),
+ ]
+ )
+
+ def test_run_when_expected_size_not_provided(self):
+ self.stubber.add_response(
+ 'head_object',
+ {'ContentLength': 1},
+ expected_params={'Bucket': self.bucket, 'Key': self.key},
+ )
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_with_extra_args(self):
+ self.stubber.add_response(
+ 'head_object',
+ {'ContentLength': 1},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ },
+ )
+ self.add_download_file_request(
+ extra_args={'VersionId': 'versionid'}, expected_size=None
+ )
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.osutil.allocate.assert_called_with(self.temp_filename, 1)
+ self.assert_submitted_get_object_jobs(
+ [
+ GetObjectJob(
+ transfer_id=self.transfer_id,
+ bucket=self.bucket,
+ key=self.key,
+ temp_filename=self.temp_filename,
+ offset=0,
+ extra_args={'VersionId': 'versionid'},
+ filename=self.filename,
+ )
+ ]
+ )
+
+ def test_run_with_exception(self):
+ self.stubber.add_client_error('head_object', 'NoSuchKey', 404)
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+ self.submitter.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_submitted_get_object_jobs([])
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), ClientError
+ )
+
+ def test_run_with_error_in_allocating_temp_file(self):
+ self.osutil.allocate.side_effect = OSError()
+ self.add_download_file_request(expected_size=1)
+ self.add_shutdown()
+ self.submitter.run()
+ self.assert_submitted_get_object_jobs([])
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), OSError
+ )
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_submitter_cannot_be_killed(self):
+ self.add_download_file_request(expected_size=None)
+ self.add_shutdown()
+
+ def raise_ctrl_c(**kwargs):
+ os.kill(os.getpid(), signal.SIGINT)
+
+ mock_client = mock.Mock()
+ mock_client.head_object = raise_ctrl_c
+ self.client_factory.create_client.return_value = mock_client
+
+ try:
+ self.submitter.run()
+ except KeyboardInterrupt:
+ self.fail(
+ 'The submitter should not have been killed by the '
+ 'KeyboardInterrupt'
+ )
+
+
+class TestGetObjectWorker(StubbedClientTest):
+ def setUp(self):
+ super().setUp()
+ self.files = FileCreator()
+ self.queue = queue.Queue()
+ self.client_factory = mock.Mock(ClientFactory)
+ self.client_factory.create_client.return_value = self.client
+ self.transfer_monitor = TransferMonitor()
+ self.osutil = OSUtils()
+ self.worker = GetObjectWorker(
+ queue=self.queue,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=self.osutil,
+ )
+ self.transfer_id = self.transfer_monitor.notify_new_transfer()
+ self.bucket = 'bucket'
+ self.key = 'key'
+ self.remote_contents = b'my content'
+ self.temp_filename = self.files.create_file('tempfile', '')
+ self.extra_args = {}
+ self.offset = 0
+ self.final_filename = self.files.full_path('final_filename')
+ self.stream = BytesIO(self.remote_contents)
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1000
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ self.files.remove_all()
+
+ def add_get_object_job(self, **override_kwargs):
+ kwargs = {
+ 'transfer_id': self.transfer_id,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'temp_filename': self.temp_filename,
+ 'extra_args': self.extra_args,
+ 'offset': self.offset,
+ 'filename': self.final_filename,
+ }
+ kwargs.update(override_kwargs)
+ self.queue.put(GetObjectJob(**kwargs))
+
+ def add_shutdown(self):
+ self.queue.put(SHUTDOWN_SIGNAL)
+
+ def add_stubbed_get_object_response(self, body=None, expected_params=None):
+ if body is None:
+ body = self.stream
+ get_object_response = {'Body': body}
+
+ if expected_params is None:
+ expected_params = {'Bucket': self.bucket, 'Key': self.key}
+
+ self.stubber.add_response(
+ 'get_object', get_object_response, expected_params
+ )
+
+ def assert_contents(self, filename, contents):
+ self.assertTrue(os.path.exists(filename))
+ with open(filename, 'rb') as f:
+ self.assertEqual(f.read(), contents)
+
+ def assert_does_not_exist(self, filename):
+ self.assertFalse(os.path.exists(filename))
+
+ def test_run_is_final_job(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_contents(self.final_filename, self.remote_contents)
+
+ def test_run_jobs_is_not_final_job(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1000
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_contents(self.temp_filename, self.remote_contents)
+ self.assert_does_not_exist(self.final_filename)
+
+ def test_run_with_extra_args(self):
+ self.add_get_object_job(extra_args={'VersionId': 'versionid'})
+ self.add_shutdown()
+ self.add_stubbed_get_object_response(
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'VersionId': 'versionid',
+ }
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+
+ def test_run_with_offset(self):
+ offset = 1
+ self.add_get_object_job(offset=offset)
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
+ with open(self.temp_filename, 'rb') as f:
+ f.seek(offset)
+ self.assertEqual(f.read(), self.remote_contents)
+
+ def test_run_error_in_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.stubber.add_client_error('get_object', 'NoSuchKey', 404)
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id), ClientError
+ )
+
+ def test_run_does_retries_for_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response(
+ body=StreamWithError(
+ self.stream, ReadTimeoutError(endpoint_url='')
+ )
+ )
+ self.add_stubbed_get_object_response()
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_contents(self.temp_filename, self.remote_contents)
+
+ def test_run_can_exhaust_retries_for_get_object(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ # 5 is the current setting for max number of GetObject attempts
+ for _ in range(5):
+ self.add_stubbed_get_object_response(
+ body=StreamWithError(
+ self.stream, ReadTimeoutError(endpoint_url='')
+ )
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assertIsInstance(
+ self.transfer_monitor.get_exception(self.transfer_id),
+ RetriesExceededError,
+ )
+
+ def test_run_skips_get_object_on_previous_exception(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_exception(self.transfer_id, Exception())
+
+ self.worker.run()
+ # Note we did not add a stubbed response for get_object
+ self.stubber.assert_no_pending_responses()
+
+ def test_run_final_job_removes_file_on_previous_exception(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_exception(self.transfer_id, Exception())
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.stubber.assert_no_pending_responses()
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_does_not_exist(self.final_filename)
+
+ def test_run_fails_to_rename_file(self):
+ exception = OSError()
+ osutil = RenameFailingOSUtils(exception)
+ self.worker = GetObjectWorker(
+ queue=self.queue,
+ client_factory=self.client_factory,
+ transfer_monitor=self.transfer_monitor,
+ osutil=osutil,
+ )
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.add_stubbed_get_object_response()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ self.worker.run()
+ self.assertEqual(
+ self.transfer_monitor.get_exception(self.transfer_id), exception
+ )
+ self.assert_does_not_exist(self.temp_filename)
+ self.assert_does_not_exist(self.final_filename)
+
+ @skip_if_windows('os.kill() with SIGINT not supported on Windows')
+ def test_worker_cannot_be_killed(self):
+ self.add_get_object_job()
+ self.add_shutdown()
+ self.transfer_monitor.notify_expected_jobs_to_complete(
+ self.transfer_id, 1
+ )
+
+ def raise_ctrl_c(**kwargs):
+ os.kill(os.getpid(), signal.SIGINT)
+
+ mock_client = mock.Mock()
+ mock_client.get_object = raise_ctrl_c
+ self.client_factory.create_client.return_value = mock_client
+
+ try:
+ self.worker.run()
+ except KeyboardInterrupt:
+ self.fail(
+ 'The worker should not have been killed by the '
+ 'KeyboardInterrupt'
+ )
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py
index 1aababd760..35cf4a22dd 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_s3transfer.py
@@ -1,780 +1,780 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License'). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the 'license' file accompanying this file. This file is
-# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import os
-import shutil
-import socket
-import tempfile
-from concurrent import futures
-from contextlib import closing
-from io import BytesIO, StringIO
-
-from s3transfer import (
- MultipartDownloader,
- MultipartUploader,
- OSUtils,
- QueueShutdownError,
- ReadFileChunk,
- S3Transfer,
- ShutdownQueue,
- StreamReaderProgress,
- TransferConfig,
- disable_upload_callbacks,
- enable_upload_callbacks,
- random_file_extension,
-)
-from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
-from __tests__ import mock, unittest
-
-
-class InMemoryOSLayer(OSUtils):
- def __init__(self, filemap):
- self.filemap = filemap
-
- def get_file_size(self, filename):
- return len(self.filemap[filename])
-
- def open_file_chunk_reader(self, filename, start_byte, size, callback):
- return closing(BytesIO(self.filemap[filename]))
-
- def open(self, filename, mode):
- if 'wb' in mode:
- fileobj = BytesIO()
- self.filemap[filename] = fileobj
- return closing(fileobj)
- else:
- return closing(self.filemap[filename])
-
- def remove_file(self, filename):
- if filename in self.filemap:
- del self.filemap[filename]
-
- def rename_file(self, current_filename, new_filename):
- if current_filename in self.filemap:
- self.filemap[new_filename] = self.filemap.pop(current_filename)
-
-
-class SequentialExecutor:
- def __init__(self, max_workers):
- pass
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args, **kwargs):
- pass
-
- # The real map() interface actually takes *args, but we specifically do
- # _not_ use this interface.
- def map(self, function, args):
- results = []
- for arg in args:
- results.append(function(arg))
- return results
-
- def submit(self, function):
- future = futures.Future()
- future.set_result(function())
- return future
-
-
-class TestOSUtils(unittest.TestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
-
- def test_get_file_size(self):
- with mock.patch('os.path.getsize') as m:
- OSUtils().get_file_size('myfile')
- m.assert_called_with('myfile')
-
- def test_open_file_chunk_reader(self):
- with mock.patch('s3transfer.ReadFileChunk') as m:
- OSUtils().open_file_chunk_reader('myfile', 0, 100, None)
- m.from_filename.assert_called_with(
- 'myfile', 0, 100, None, enable_callback=False
- )
-
- def test_open_file(self):
- fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
- self.assertTrue(hasattr(fileobj, 'write'))
-
- def test_remove_file_ignores_errors(self):
- with mock.patch('os.remove') as remove:
- remove.side_effect = OSError('fake error')
- OSUtils().remove_file('foo')
- remove.assert_called_with('foo')
-
- def test_remove_file_proxies_remove_file(self):
- with mock.patch('os.remove') as remove:
- OSUtils().remove_file('foo')
- remove.assert_called_with('foo')
-
- def test_rename_file(self):
- with mock.patch('s3transfer.compat.rename_file') as rename_file:
- OSUtils().rename_file('foo', 'newfoo')
- rename_file.assert_called_with('foo', 'newfoo')
-
-
-class TestReadFileChunk(unittest.TestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
-
- def test_read_entire_chunk(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=0, chunk_size=3
- )
- self.assertEqual(chunk.read(), b'one')
- self.assertEqual(chunk.read(), b'')
-
- def test_read_with_amount_size(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=11, chunk_size=4
- )
- self.assertEqual(chunk.read(1), b'f')
- self.assertEqual(chunk.read(1), b'o')
- self.assertEqual(chunk.read(1), b'u')
- self.assertEqual(chunk.read(1), b'r')
- self.assertEqual(chunk.read(1), b'')
-
- def test_reset_stream_emulation(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=11, chunk_size=4
- )
- self.assertEqual(chunk.read(), b'four')
- chunk.seek(0)
- self.assertEqual(chunk.read(), b'four')
-
- def test_read_past_end_of_file(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=36, chunk_size=100000
- )
- self.assertEqual(chunk.read(), b'ten')
- self.assertEqual(chunk.read(), b'')
- self.assertEqual(len(chunk), 3)
-
- def test_tell_and_seek(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=36, chunk_size=100000
- )
- self.assertEqual(chunk.tell(), 0)
- self.assertEqual(chunk.read(), b'ten')
- self.assertEqual(chunk.tell(), 3)
- chunk.seek(0)
- self.assertEqual(chunk.tell(), 0)
-
- def test_file_chunk_supports_context_manager(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'abc')
- with ReadFileChunk.from_filename(
- filename, start_byte=0, chunk_size=2
- ) as chunk:
- val = chunk.read()
- self.assertEqual(val, b'ab')
-
- def test_iter_is_always_empty(self):
- # This tests the workaround for the httplib bug (see
- # the source for more info).
- filename = os.path.join(self.tempdir, 'foo')
- open(filename, 'wb').close()
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=0, chunk_size=10
- )
- self.assertEqual(list(chunk), [])
-
-
-class TestReadFileChunkWithCallback(TestReadFileChunk):
- def setUp(self):
- super().setUp()
- self.filename = os.path.join(self.tempdir, 'foo')
- with open(self.filename, 'wb') as f:
- f.write(b'abc')
- self.amounts_seen = []
-
- def callback(self, amount):
- self.amounts_seen.append(amount)
-
- def test_callback_is_invoked_on_read(self):
- chunk = ReadFileChunk.from_filename(
- self.filename, start_byte=0, chunk_size=3, callback=self.callback
- )
- chunk.read(1)
- chunk.read(1)
- chunk.read(1)
- self.assertEqual(self.amounts_seen, [1, 1, 1])
-
- def test_callback_can_be_disabled(self):
- chunk = ReadFileChunk.from_filename(
- self.filename, start_byte=0, chunk_size=3, callback=self.callback
- )
- chunk.disable_callback()
- # Now reading from the ReadFileChunk should not invoke
- # the callback.
- chunk.read()
- self.assertEqual(self.amounts_seen, [])
-
- def test_callback_will_also_be_triggered_by_seek(self):
- chunk = ReadFileChunk.from_filename(
- self.filename, start_byte=0, chunk_size=3, callback=self.callback
- )
- chunk.read(2)
- chunk.seek(0)
- chunk.read(2)
- chunk.seek(1)
- chunk.read(2)
- self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
-
-
-class TestStreamReaderProgress(unittest.TestCase):
- def test_proxies_to_wrapped_stream(self):
- original_stream = StringIO('foobarbaz')
- wrapped = StreamReaderProgress(original_stream)
- self.assertEqual(wrapped.read(), 'foobarbaz')
-
- def test_callback_invoked(self):
- amounts_seen = []
-
- def callback(amount):
- amounts_seen.append(amount)
-
- original_stream = StringIO('foobarbaz')
- wrapped = StreamReaderProgress(original_stream, callback)
- self.assertEqual(wrapped.read(), 'foobarbaz')
- self.assertEqual(amounts_seen, [9])
-
-
-class TestMultipartUploader(unittest.TestCase):
- def test_multipart_upload_uses_correct_client_calls(self):
- client = mock.Mock()
- uploader = MultipartUploader(
- client,
- TransferConfig(),
- InMemoryOSLayer({'filename': b'foobar'}),
- SequentialExecutor,
- )
- client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
- client.upload_part.return_value = {'ETag': 'first'}
-
- uploader.upload_file('filename', 'bucket', 'key', None, {})
-
- # We need to check both the sequence of calls (create/upload/complete)
- # as well as the params passed between the calls, including
- # 1. The upload_id was plumbed through
- # 2. The collected etags were added to the complete call.
- client.create_multipart_upload.assert_called_with(
- Bucket='bucket', Key='key'
- )
- # Should be two parts.
- client.upload_part.assert_called_with(
- Body=mock.ANY,
- Bucket='bucket',
- UploadId='upload_id',
- Key='key',
- PartNumber=1,
- )
- client.complete_multipart_upload.assert_called_with(
- MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
- Bucket='bucket',
- UploadId='upload_id',
- Key='key',
- )
-
- def test_multipart_upload_injects_proper_kwargs(self):
- client = mock.Mock()
- uploader = MultipartUploader(
- client,
- TransferConfig(),
- InMemoryOSLayer({'filename': b'foobar'}),
- SequentialExecutor,
- )
- client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
- client.upload_part.return_value = {'ETag': 'first'}
-
- extra_args = {
- 'SSECustomerKey': 'fakekey',
- 'SSECustomerAlgorithm': 'AES256',
- 'StorageClass': 'REDUCED_REDUNDANCY',
- }
- uploader.upload_file('filename', 'bucket', 'key', None, extra_args)
-
- client.create_multipart_upload.assert_called_with(
- Bucket='bucket',
- Key='key',
- # The initial call should inject all the storage class params.
- SSECustomerKey='fakekey',
- SSECustomerAlgorithm='AES256',
- StorageClass='REDUCED_REDUNDANCY',
- )
- client.upload_part.assert_called_with(
- Body=mock.ANY,
- Bucket='bucket',
- UploadId='upload_id',
- Key='key',
- PartNumber=1,
- # We only have to forward certain **extra_args in subsequent
- # UploadPart calls.
- SSECustomerKey='fakekey',
- SSECustomerAlgorithm='AES256',
- )
- client.complete_multipart_upload.assert_called_with(
- MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
- Bucket='bucket',
- UploadId='upload_id',
- Key='key',
- )
-
- def test_multipart_upload_is_aborted_on_error(self):
- # If the create_multipart_upload succeeds and any upload_part
- # fails, then abort_multipart_upload will be called.
- client = mock.Mock()
- uploader = MultipartUploader(
- client,
- TransferConfig(),
- InMemoryOSLayer({'filename': b'foobar'}),
- SequentialExecutor,
- )
- client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
- client.upload_part.side_effect = Exception(
- "Some kind of error occurred."
- )
-
- with self.assertRaises(S3UploadFailedError):
- uploader.upload_file('filename', 'bucket', 'key', None, {})
-
- client.abort_multipart_upload.assert_called_with(
- Bucket='bucket', Key='key', UploadId='upload_id'
- )
-
-
-class TestMultipartDownloader(unittest.TestCase):
-
- maxDiff = None
-
- def test_multipart_download_uses_correct_client_calls(self):
- client = mock.Mock()
- response_body = b'foobarbaz'
- client.get_object.return_value = {'Body': BytesIO(response_body)}
-
- downloader = MultipartDownloader(
- client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor
- )
- downloader.download_file(
- 'bucket', 'key', 'filename', len(response_body), {}
- )
-
- client.get_object.assert_called_with(
- Range='bytes=0-', Bucket='bucket', Key='key'
- )
-
- def test_multipart_download_with_multiple_parts(self):
- client = mock.Mock()
- response_body = b'foobarbaz'
- client.get_object.return_value = {'Body': BytesIO(response_body)}
- # For testing purposes, we're testing with a multipart threshold
- # of 4 bytes and a chunksize of 4 bytes. Given b'foobarbaz',
- # this should result in 3 calls. In python slices this would be:
- # r[0:4], r[4:8], r[8:9]. But the Range param will be slightly
- # different because they use inclusive ranges.
- config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
-
- downloader = MultipartDownloader(
- client, config, InMemoryOSLayer({}), SequentialExecutor
- )
- downloader.download_file(
- 'bucket', 'key', 'filename', len(response_body), {}
- )
-
- # We're storing these in **extra because the assertEqual
- # below is really about verifying we have the correct value
- # for the Range param.
- extra = {'Bucket': 'bucket', 'Key': 'key'}
- self.assertEqual(
- client.get_object.call_args_list,
- # Note these are inclusive ranges.
- [
- mock.call(Range='bytes=0-3', **extra),
- mock.call(Range='bytes=4-7', **extra),
- mock.call(Range='bytes=8-', **extra),
- ],
- )
-
- def test_retry_on_failures_from_stream_reads(self):
- # If we get an exception during a call to the response body's .read()
- # method, we should retry the request.
- client = mock.Mock()
- response_body = b'foobarbaz'
- stream_with_errors = mock.Mock()
- stream_with_errors.read.side_effect = [
- socket.error("fake error"),
- response_body,
- ]
- client.get_object.return_value = {'Body': stream_with_errors}
- config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
-
- downloader = MultipartDownloader(
- client, config, InMemoryOSLayer({}), SequentialExecutor
- )
- downloader.download_file(
- 'bucket', 'key', 'filename', len(response_body), {}
- )
-
- # We're storing these in **extra because the assertEqual
- # below is really about verifying we have the correct value
- # for the Range param.
- extra = {'Bucket': 'bucket', 'Key': 'key'}
- self.assertEqual(
- client.get_object.call_args_list,
- # The first call to range=0-3 fails because of the
- # side_effect above where we make the .read() raise a
- # socket.error.
- # The second call to range=0-3 then succeeds.
- [
- mock.call(Range='bytes=0-3', **extra),
- mock.call(Range='bytes=0-3', **extra),
- mock.call(Range='bytes=4-7', **extra),
- mock.call(Range='bytes=8-', **extra),
- ],
- )
-
- def test_exception_raised_on_exceeded_retries(self):
- client = mock.Mock()
- response_body = b'foobarbaz'
- stream_with_errors = mock.Mock()
- stream_with_errors.read.side_effect = socket.error("fake error")
- client.get_object.return_value = {'Body': stream_with_errors}
- config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
-
- downloader = MultipartDownloader(
- client, config, InMemoryOSLayer({}), SequentialExecutor
- )
- with self.assertRaises(RetriesExceededError):
- downloader.download_file(
- 'bucket', 'key', 'filename', len(response_body), {}
- )
-
- def test_io_thread_failure_triggers_shutdown(self):
- client = mock.Mock()
- response_body = b'foobarbaz'
- client.get_object.return_value = {'Body': BytesIO(response_body)}
- os_layer = mock.Mock()
- mock_fileobj = mock.MagicMock()
- mock_fileobj.__enter__.return_value = mock_fileobj
- mock_fileobj.write.side_effect = Exception("fake IO error")
- os_layer.open.return_value = mock_fileobj
-
- downloader = MultipartDownloader(
- client, TransferConfig(), os_layer, SequentialExecutor
- )
- # We're verifying that the exception raised from the IO future
- # propagates back up via download_file().
- with self.assertRaisesRegex(Exception, "fake IO error"):
- downloader.download_file(
- 'bucket', 'key', 'filename', len(response_body), {}
- )
-
- def test_download_futures_fail_triggers_shutdown(self):
- class FailedDownloadParts(SequentialExecutor):
- def __init__(self, max_workers):
- self.is_first = True
-
- def submit(self, function):
- future = futures.Future()
- if self.is_first:
- # This is the download_parts_thread.
- future.set_exception(
- Exception("fake download parts error")
- )
- self.is_first = False
- return future
-
- client = mock.Mock()
- response_body = b'foobarbaz'
- client.get_object.return_value = {'Body': BytesIO(response_body)}
-
- downloader = MultipartDownloader(
- client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts
- )
- with self.assertRaisesRegex(Exception, "fake download parts error"):
- downloader.download_file(
- 'bucket', 'key', 'filename', len(response_body), {}
- )
-
-
-class TestS3Transfer(unittest.TestCase):
- def setUp(self):
- self.client = mock.Mock()
- self.random_file_patch = mock.patch('s3transfer.random_file_extension')
- self.random_file = self.random_file_patch.start()
- self.random_file.return_value = 'RANDOM'
-
- def tearDown(self):
- self.random_file_patch.stop()
-
- def test_callback_handlers_register_on_put_item(self):
- osutil = InMemoryOSLayer({'smallfile': b'foobar'})
- transfer = S3Transfer(self.client, osutil=osutil)
- transfer.upload_file('smallfile', 'bucket', 'key')
- events = self.client.meta.events
- events.register_first.assert_called_with(
- 'request-created.s3',
- disable_upload_callbacks,
- unique_id='s3upload-callback-disable',
- )
- events.register_last.assert_called_with(
- 'request-created.s3',
- enable_upload_callbacks,
- unique_id='s3upload-callback-enable',
- )
-
- def test_upload_below_multipart_threshold_uses_put_object(self):
- fake_files = {
- 'smallfile': b'foobar',
- }
- osutil = InMemoryOSLayer(fake_files)
- transfer = S3Transfer(self.client, osutil=osutil)
- transfer.upload_file('smallfile', 'bucket', 'key')
- self.client.put_object.assert_called_with(
- Bucket='bucket', Key='key', Body=mock.ANY
- )
-
- def test_extra_args_on_uploaded_passed_to_api_call(self):
- extra_args = {'ACL': 'public-read'}
- fake_files = {'smallfile': b'hello world'}
- osutil = InMemoryOSLayer(fake_files)
- transfer = S3Transfer(self.client, osutil=osutil)
- transfer.upload_file(
- 'smallfile', 'bucket', 'key', extra_args=extra_args
- )
- self.client.put_object.assert_called_with(
- Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read'
- )
-
- def test_uses_multipart_upload_when_over_threshold(self):
- with mock.patch('s3transfer.MultipartUploader') as uploader:
- fake_files = {
- 'smallfile': b'foobar',
- }
- osutil = InMemoryOSLayer(fake_files)
- config = TransferConfig(
- multipart_threshold=2, multipart_chunksize=2
- )
- transfer = S3Transfer(self.client, osutil=osutil, config=config)
- transfer.upload_file('smallfile', 'bucket', 'key')
-
- uploader.return_value.upload_file.assert_called_with(
- 'smallfile', 'bucket', 'key', None, {}
- )
-
- def test_uses_multipart_download_when_over_threshold(self):
- with mock.patch('s3transfer.MultipartDownloader') as downloader:
- osutil = InMemoryOSLayer({})
- over_multipart_threshold = 100 * 1024 * 1024
- transfer = S3Transfer(self.client, osutil=osutil)
- callback = mock.sentinel.CALLBACK
- self.client.head_object.return_value = {
- 'ContentLength': over_multipart_threshold,
- }
- transfer.download_file(
- 'bucket', 'key', 'filename', callback=callback
- )
-
- downloader.return_value.download_file.assert_called_with(
- # Note how we're downloading to a temporary random file.
- 'bucket',
- 'key',
- 'filename.RANDOM',
- over_multipart_threshold,
- {},
- callback,
- )
-
- def test_download_file_with_invalid_extra_args(self):
- below_threshold = 20
- osutil = InMemoryOSLayer({})
- transfer = S3Transfer(self.client, osutil=osutil)
- self.client.head_object.return_value = {
- 'ContentLength': below_threshold
- }
- with self.assertRaises(ValueError):
- transfer.download_file(
- 'bucket',
- 'key',
- '/tmp/smallfile',
- extra_args={'BadValue': 'foo'},
- )
-
- def test_upload_file_with_invalid_extra_args(self):
- osutil = InMemoryOSLayer({})
- transfer = S3Transfer(self.client, osutil=osutil)
- bad_args = {"WebsiteRedirectLocation": "/foo"}
- with self.assertRaises(ValueError):
- transfer.upload_file(
- 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args
- )
-
- def test_download_file_fowards_extra_args(self):
- extra_args = {
- 'SSECustomerKey': 'foo',
- 'SSECustomerAlgorithm': 'AES256',
- }
- below_threshold = 20
- osutil = InMemoryOSLayer({'smallfile': b'hello world'})
- transfer = S3Transfer(self.client, osutil=osutil)
- self.client.head_object.return_value = {
- 'ContentLength': below_threshold
- }
- self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
- transfer.download_file(
- 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args
- )
-
- # Note that we need to invoke the HeadObject call
- # and the GetObject call with the extra_args.
- # This is necessary. Trying to HeadObject an SSE object
- # will return a 400 if you don't provide the required
- # params.
- self.client.get_object.assert_called_with(
- Bucket='bucket',
- Key='key',
- SSECustomerAlgorithm='AES256',
- SSECustomerKey='foo',
- )
-
- def test_get_object_stream_is_retried_and_succeeds(self):
- below_threshold = 20
- osutil = InMemoryOSLayer({'smallfile': b'hello world'})
- transfer = S3Transfer(self.client, osutil=osutil)
- self.client.head_object.return_value = {
- 'ContentLength': below_threshold
- }
- self.client.get_object.side_effect = [
- # First request fails.
- socket.error("fake error"),
- # Second succeeds.
- {'Body': BytesIO(b'foobar')},
- ]
- transfer.download_file('bucket', 'key', '/tmp/smallfile')
-
- self.assertEqual(self.client.get_object.call_count, 2)
-
- def test_get_object_stream_uses_all_retries_and_errors_out(self):
- below_threshold = 20
- osutil = InMemoryOSLayer({})
- transfer = S3Transfer(self.client, osutil=osutil)
- self.client.head_object.return_value = {
- 'ContentLength': below_threshold
- }
- # Here we're raising an exception every single time, which
- # will exhaust our retry count and propagate a
- # RetriesExceededError.
- self.client.get_object.side_effect = socket.error("fake error")
- with self.assertRaises(RetriesExceededError):
- transfer.download_file('bucket', 'key', 'smallfile')
-
- self.assertEqual(self.client.get_object.call_count, 5)
- # We should have also cleaned up the in progress file
- # we were downloading to.
- self.assertEqual(osutil.filemap, {})
-
- def test_download_below_multipart_threshold(self):
- below_threshold = 20
- osutil = InMemoryOSLayer({'smallfile': b'hello world'})
- transfer = S3Transfer(self.client, osutil=osutil)
- self.client.head_object.return_value = {
- 'ContentLength': below_threshold
- }
- self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
- transfer.download_file('bucket', 'key', 'smallfile')
-
- self.client.get_object.assert_called_with(Bucket='bucket', Key='key')
-
- def test_can_create_with_just_client(self):
- transfer = S3Transfer(client=mock.Mock())
- self.assertIsInstance(transfer, S3Transfer)
-
-
-class TestShutdownQueue(unittest.TestCase):
- def test_handles_normal_put_get_requests(self):
- q = ShutdownQueue()
- q.put('foo')
- self.assertEqual(q.get(), 'foo')
-
- def test_put_raises_error_on_shutdown(self):
- q = ShutdownQueue()
- q.trigger_shutdown()
- with self.assertRaises(QueueShutdownError):
- q.put('foo')
-
-
-class TestRandomFileExtension(unittest.TestCase):
- def test_has_proper_length(self):
- self.assertEqual(len(random_file_extension(num_digits=4)), 4)
-
-
-class TestCallbackHandlers(unittest.TestCase):
- def setUp(self):
- self.request = mock.Mock()
-
- def test_disable_request_on_put_object(self):
- disable_upload_callbacks(self.request, 'PutObject')
- self.request.body.disable_callback.assert_called_with()
-
- def test_disable_request_on_upload_part(self):
- disable_upload_callbacks(self.request, 'UploadPart')
- self.request.body.disable_callback.assert_called_with()
-
- def test_enable_object_on_put_object(self):
- enable_upload_callbacks(self.request, 'PutObject')
- self.request.body.enable_callback.assert_called_with()
-
- def test_enable_object_on_upload_part(self):
- enable_upload_callbacks(self.request, 'UploadPart')
- self.request.body.enable_callback.assert_called_with()
-
- def test_dont_disable_if_missing_interface(self):
- del self.request.body.disable_callback
- disable_upload_callbacks(self.request, 'PutObject')
- self.assertEqual(self.request.body.method_calls, [])
-
- def test_dont_enable_if_missing_interface(self):
- del self.request.body.enable_callback
- enable_upload_callbacks(self.request, 'PutObject')
- self.assertEqual(self.request.body.method_calls, [])
-
- def test_dont_disable_if_wrong_operation(self):
- disable_upload_callbacks(self.request, 'OtherOperation')
- self.assertFalse(self.request.body.disable_callback.called)
-
- def test_dont_enable_if_wrong_operation(self):
- enable_upload_callbacks(self.request, 'OtherOperation')
- self.assertFalse(self.request.body.enable_callback.called)
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import os
+import shutil
+import socket
+import tempfile
+from concurrent import futures
+from contextlib import closing
+from io import BytesIO, StringIO
+
+from s3transfer import (
+ MultipartDownloader,
+ MultipartUploader,
+ OSUtils,
+ QueueShutdownError,
+ ReadFileChunk,
+ S3Transfer,
+ ShutdownQueue,
+ StreamReaderProgress,
+ TransferConfig,
+ disable_upload_callbacks,
+ enable_upload_callbacks,
+ random_file_extension,
+)
+from s3transfer.exceptions import RetriesExceededError, S3UploadFailedError
+from __tests__ import mock, unittest
+
+
+class InMemoryOSLayer(OSUtils):
+ def __init__(self, filemap):
+ self.filemap = filemap
+
+ def get_file_size(self, filename):
+ return len(self.filemap[filename])
+
+ def open_file_chunk_reader(self, filename, start_byte, size, callback):
+ return closing(BytesIO(self.filemap[filename]))
+
+ def open(self, filename, mode):
+ if 'wb' in mode:
+ fileobj = BytesIO()
+ self.filemap[filename] = fileobj
+ return closing(fileobj)
+ else:
+ return closing(self.filemap[filename])
+
+ def remove_file(self, filename):
+ if filename in self.filemap:
+ del self.filemap[filename]
+
+ def rename_file(self, current_filename, new_filename):
+ if current_filename in self.filemap:
+ self.filemap[new_filename] = self.filemap.pop(current_filename)
+
+
+class SequentialExecutor:
+ def __init__(self, max_workers):
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ pass
+
+ # The real map() interface actually takes *args, but we specifically do
+ # _not_ use this interface.
+ def map(self, function, args):
+ results = []
+ for arg in args:
+ results.append(function(arg))
+ return results
+
+ def submit(self, function):
+ future = futures.Future()
+ future.set_result(function())
+ return future
+
+
+class TestOSUtils(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_get_file_size(self):
+ with mock.patch('os.path.getsize') as m:
+ OSUtils().get_file_size('myfile')
+ m.assert_called_with('myfile')
+
+ def test_open_file_chunk_reader(self):
+ with mock.patch('s3transfer.ReadFileChunk') as m:
+ OSUtils().open_file_chunk_reader('myfile', 0, 100, None)
+ m.from_filename.assert_called_with(
+ 'myfile', 0, 100, None, enable_callback=False
+ )
+
+ def test_open_file(self):
+ fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
+ self.assertTrue(hasattr(fileobj, 'write'))
+
+ def test_remove_file_ignores_errors(self):
+ with mock.patch('os.remove') as remove:
+ remove.side_effect = OSError('fake error')
+ OSUtils().remove_file('foo')
+ remove.assert_called_with('foo')
+
+ def test_remove_file_proxies_remove_file(self):
+ with mock.patch('os.remove') as remove:
+ OSUtils().remove_file('foo')
+ remove.assert_called_with('foo')
+
+ def test_rename_file(self):
+ with mock.patch('s3transfer.compat.rename_file') as rename_file:
+ OSUtils().rename_file('foo', 'newfoo')
+ rename_file.assert_called_with('foo', 'newfoo')
+
+
+class TestReadFileChunk(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def test_read_entire_chunk(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=3
+ )
+ self.assertEqual(chunk.read(), b'one')
+ self.assertEqual(chunk.read(), b'')
+
+ def test_read_with_amount_size(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(1), b'f')
+ self.assertEqual(chunk.read(1), b'o')
+ self.assertEqual(chunk.read(1), b'u')
+ self.assertEqual(chunk.read(1), b'r')
+ self.assertEqual(chunk.read(1), b'')
+
+ def test_reset_stream_emulation(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(), b'four')
+ chunk.seek(0)
+ self.assertEqual(chunk.read(), b'four')
+
+ def test_read_past_end_of_file(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.read(), b'')
+ self.assertEqual(len(chunk), 3)
+
+ def test_tell_and_seek(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.tell(), 0)
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.tell(), 3)
+ chunk.seek(0)
+ self.assertEqual(chunk.tell(), 0)
+
+ def test_file_chunk_supports_context_manager(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'abc')
+ with ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=2
+ ) as chunk:
+ val = chunk.read()
+ self.assertEqual(val, b'ab')
+
+ def test_iter_is_always_empty(self):
+ # This tests the workaround for the httplib bug (see
+ # the source for more info).
+ filename = os.path.join(self.tempdir, 'foo')
+ open(filename, 'wb').close()
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=10
+ )
+ self.assertEqual(list(chunk), [])
+
+
+class TestReadFileChunkWithCallback(TestReadFileChunk):
+ def setUp(self):
+ super().setUp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ with open(self.filename, 'wb') as f:
+ f.write(b'abc')
+ self.amounts_seen = []
+
+ def callback(self, amount):
+ self.amounts_seen.append(amount)
+
+ def test_callback_is_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1, 1, 1])
+
+ def test_callback_can_be_disabled(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.disable_callback()
+ # Now reading from the ReadFileChunk should not invoke
+ # the callback.
+ chunk.read()
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_callback_will_also_be_triggered_by_seek(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename, start_byte=0, chunk_size=3, callback=self.callback
+ )
+ chunk.read(2)
+ chunk.seek(0)
+ chunk.read(2)
+ chunk.seek(1)
+ chunk.read(2)
+ self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
+
+
+class TestStreamReaderProgress(unittest.TestCase):
+ def test_proxies_to_wrapped_stream(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+
+ def test_callback_invoked(self):
+ amounts_seen = []
+
+ def callback(amount):
+ amounts_seen.append(amount)
+
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream, callback)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+ self.assertEqual(amounts_seen, [9])
+
+
+class TestMultipartUploader(unittest.TestCase):
+ def test_multipart_upload_uses_correct_client_calls(self):
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.return_value = {'ETag': 'first'}
+
+ uploader.upload_file('filename', 'bucket', 'key', None, {})
+
+ # We need to check both the sequence of calls (create/upload/complete)
+ # as well as the params passed between the calls, including
+ # 1. The upload_id was plumbed through
+ # 2. The collected etags were added to the complete call.
+ client.create_multipart_upload.assert_called_with(
+ Bucket='bucket', Key='key'
+ )
+ # Should be two parts.
+ client.upload_part.assert_called_with(
+ Body=mock.ANY,
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ PartNumber=1,
+ )
+ client.complete_multipart_upload.assert_called_with(
+ MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ )
+
+ def test_multipart_upload_injects_proper_kwargs(self):
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.return_value = {'ETag': 'first'}
+
+ extra_args = {
+ 'SSECustomerKey': 'fakekey',
+ 'SSECustomerAlgorithm': 'AES256',
+ 'StorageClass': 'REDUCED_REDUNDANCY',
+ }
+ uploader.upload_file('filename', 'bucket', 'key', None, extra_args)
+
+ client.create_multipart_upload.assert_called_with(
+ Bucket='bucket',
+ Key='key',
+            # The initial call should include all of the extra_args.
+ SSECustomerKey='fakekey',
+ SSECustomerAlgorithm='AES256',
+ StorageClass='REDUCED_REDUNDANCY',
+ )
+ client.upload_part.assert_called_with(
+ Body=mock.ANY,
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ PartNumber=1,
+ # We only have to forward certain **extra_args in subsequent
+ # UploadPart calls.
+ SSECustomerKey='fakekey',
+ SSECustomerAlgorithm='AES256',
+ )
+ client.complete_multipart_upload.assert_called_with(
+ MultipartUpload={'Parts': [{'PartNumber': 1, 'ETag': 'first'}]},
+ Bucket='bucket',
+ UploadId='upload_id',
+ Key='key',
+ )
+
+ def test_multipart_upload_is_aborted_on_error(self):
+ # If the create_multipart_upload succeeds and any upload_part
+ # fails, then abort_multipart_upload will be called.
+ client = mock.Mock()
+ uploader = MultipartUploader(
+ client,
+ TransferConfig(),
+ InMemoryOSLayer({'filename': b'foobar'}),
+ SequentialExecutor,
+ )
+ client.create_multipart_upload.return_value = {'UploadId': 'upload_id'}
+ client.upload_part.side_effect = Exception(
+ "Some kind of error occurred."
+ )
+
+ with self.assertRaises(S3UploadFailedError):
+ uploader.upload_file('filename', 'bucket', 'key', None, {})
+
+ client.abort_multipart_upload.assert_called_with(
+ Bucket='bucket', Key='key', UploadId='upload_id'
+ )
+
+
+class TestMultipartDownloader(unittest.TestCase):
+
+ maxDiff = None
+
+ def test_multipart_download_uses_correct_client_calls(self):
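+        # The whole object fits within one default-sized chunk, so a
+        # single open-ended ranged GetObject is expected.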
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ client.get_object.assert_called_with(
+ Range='bytes=0-', Bucket='bucket', Key='key'
+ )
+
+ def test_multipart_download_with_multiple_parts(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+        # For this test we use a multipart threshold of 4 bytes and a
+        # chunksize of 4 bytes. Given b'foobarbaz', this should result
+        # in 3 calls. In Python slices this would be: r[0:4], r[4:8],
+        # r[8:9]. But the Range params will differ slightly because
+        # HTTP Range headers use inclusive ranges.
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ # We're storing these in **extra because the assertEqual
+ # below is really about verifying we have the correct value
+ # for the Range param.
+ extra = {'Bucket': 'bucket', 'Key': 'key'}
+ self.assertEqual(
+ client.get_object.call_args_list,
+ # Note these are inclusive ranges.
+ [
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=4-7', **extra),
+ mock.call(Range='bytes=8-', **extra),
+ ],
+ )
+
+ def test_retry_on_failures_from_stream_reads(self):
+ # If we get an exception during a call to the response body's .read()
+ # method, we should retry the request.
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ stream_with_errors = mock.Mock()
+ stream_with_errors.read.side_effect = [
+ socket.error("fake error"),
+ response_body,
+ ]
+ client.get_object.return_value = {'Body': stream_with_errors}
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ # We're storing these in **extra because the assertEqual
+ # below is really about verifying we have the correct value
+ # for the Range param.
+ extra = {'Bucket': 'bucket', 'Key': 'key'}
+ self.assertEqual(
+ client.get_object.call_args_list,
+ # The first call to range=0-3 fails because of the
+ # side_effect above where we make the .read() raise a
+ # socket.error.
+ # The second call to range=0-3 then succeeds.
+ [
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=0-3', **extra),
+ mock.call(Range='bytes=4-7', **extra),
+ mock.call(Range='bytes=8-', **extra),
+ ],
+ )
+
+ def test_exception_raised_on_exceeded_retries(self):
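+        # Every read fails, so the retry budget should be exhausted and
+        # RetriesExceededError raised.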
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ stream_with_errors = mock.Mock()
+ stream_with_errors.read.side_effect = socket.error("fake error")
+ client.get_object.return_value = {'Body': stream_with_errors}
+ config = TransferConfig(multipart_threshold=4, multipart_chunksize=4)
+
+ downloader = MultipartDownloader(
+ client, config, InMemoryOSLayer({}), SequentialExecutor
+ )
+ with self.assertRaises(RetriesExceededError):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ def test_io_thread_failure_triggers_shutdown(self):
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+ os_layer = mock.Mock()
+ mock_fileobj = mock.MagicMock()
+ mock_fileobj.__enter__.return_value = mock_fileobj
+ mock_fileobj.write.side_effect = Exception("fake IO error")
+ os_layer.open.return_value = mock_fileobj
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), os_layer, SequentialExecutor
+ )
+ # We're verifying that the exception raised from the IO future
+ # propagates back up via download_file().
+ with self.assertRaisesRegex(Exception, "fake IO error"):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+ def test_download_futures_fail_triggers_shutdown(self):
+ class FailedDownloadParts(SequentialExecutor):
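+            # Fails the first submitted future (the download parts thread)
+            # to simulate an error in that thread.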
+ def __init__(self, max_workers):
+ self.is_first = True
+
+ def submit(self, function):
+ future = futures.Future()
+ if self.is_first:
+ # This is the download_parts_thread.
+ future.set_exception(
+ Exception("fake download parts error")
+ )
+ self.is_first = False
+ return future
+
+ client = mock.Mock()
+ response_body = b'foobarbaz'
+ client.get_object.return_value = {'Body': BytesIO(response_body)}
+
+ downloader = MultipartDownloader(
+ client, TransferConfig(), InMemoryOSLayer({}), FailedDownloadParts
+ )
+ with self.assertRaisesRegex(Exception, "fake download parts error"):
+ downloader.download_file(
+ 'bucket', 'key', 'filename', len(response_body), {}
+ )
+
+
+class TestS3Transfer(unittest.TestCase):
+ def setUp(self):
+ self.client = mock.Mock()
+ self.random_file_patch = mock.patch('s3transfer.random_file_extension')
+ self.random_file = self.random_file_patch.start()
+ self.random_file.return_value = 'RANDOM'
+
+ def tearDown(self):
+ self.random_file_patch.stop()
+
+ def test_callback_handlers_register_on_put_item(self):
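+        # Uploading should register the disable/enable upload callback
+        # handlers on the client's event system.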
+ osutil = InMemoryOSLayer({'smallfile': b'foobar'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+ events = self.client.meta.events
+ events.register_first.assert_called_with(
+ 'request-created.s3',
+ disable_upload_callbacks,
+ unique_id='s3upload-callback-disable',
+ )
+ events.register_last.assert_called_with(
+ 'request-created.s3',
+ enable_upload_callbacks,
+ unique_id='s3upload-callback-enable',
+ )
+
+ def test_upload_below_multipart_threshold_uses_put_object(self):
+ fake_files = {
+ 'smallfile': b'foobar',
+ }
+ osutil = InMemoryOSLayer(fake_files)
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+ self.client.put_object.assert_called_with(
+ Bucket='bucket', Key='key', Body=mock.ANY
+ )
+
+ def test_extra_args_on_uploaded_passed_to_api_call(self):
+ extra_args = {'ACL': 'public-read'}
+ fake_files = {'smallfile': b'hello world'}
+ osutil = InMemoryOSLayer(fake_files)
+ transfer = S3Transfer(self.client, osutil=osutil)
+ transfer.upload_file(
+ 'smallfile', 'bucket', 'key', extra_args=extra_args
+ )
+ self.client.put_object.assert_called_with(
+ Bucket='bucket', Key='key', Body=mock.ANY, ACL='public-read'
+ )
+
+ def test_uses_multipart_upload_when_over_threshold(self):
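+        # A 2-byte threshold forces the 6-byte file through the
+        # MultipartUploader.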
+ with mock.patch('s3transfer.MultipartUploader') as uploader:
+ fake_files = {
+ 'smallfile': b'foobar',
+ }
+ osutil = InMemoryOSLayer(fake_files)
+ config = TransferConfig(
+ multipart_threshold=2, multipart_chunksize=2
+ )
+ transfer = S3Transfer(self.client, osutil=osutil, config=config)
+ transfer.upload_file('smallfile', 'bucket', 'key')
+
+ uploader.return_value.upload_file.assert_called_with(
+ 'smallfile', 'bucket', 'key', None, {}
+ )
+
+ def test_uses_multipart_download_when_over_threshold(self):
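+        # A 100 MB ContentLength is well over the default multipart
+        # threshold, so the MultipartDownloader should be used.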
+ with mock.patch('s3transfer.MultipartDownloader') as downloader:
+ osutil = InMemoryOSLayer({})
+ over_multipart_threshold = 100 * 1024 * 1024
+ transfer = S3Transfer(self.client, osutil=osutil)
+ callback = mock.sentinel.CALLBACK
+ self.client.head_object.return_value = {
+ 'ContentLength': over_multipart_threshold,
+ }
+ transfer.download_file(
+ 'bucket', 'key', 'filename', callback=callback
+ )
+
+ downloader.return_value.download_file.assert_called_with(
+ # Note how we're downloading to a temporary random file.
+ 'bucket',
+ 'key',
+ 'filename.RANDOM',
+ over_multipart_threshold,
+ {},
+ callback,
+ )
+
+ def test_download_file_with_invalid_extra_args(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ with self.assertRaises(ValueError):
+ transfer.download_file(
+ 'bucket',
+ 'key',
+ '/tmp/smallfile',
+ extra_args={'BadValue': 'foo'},
+ )
+
+ def test_upload_file_with_invalid_extra_args(self):
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ bad_args = {"WebsiteRedirectLocation": "/foo"}
+ with self.assertRaises(ValueError):
+ transfer.upload_file(
+ 'bucket', 'key', '/tmp/smallfile', extra_args=bad_args
+ )
+
+ def test_download_file_fowards_extra_args(self):
+ extra_args = {
+ 'SSECustomerKey': 'foo',
+ 'SSECustomerAlgorithm': 'AES256',
+ }
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
+ transfer.download_file(
+ 'bucket', 'key', '/tmp/smallfile', extra_args=extra_args
+ )
+
+        # Note that we need to invoke both the HeadObject call
+        # and the GetObject call with the extra_args. This is
+        # necessary because trying to HeadObject an SSE object
+        # will return a 400 if you don't provide the required
+        # params.
+ self.client.get_object.assert_called_with(
+ Bucket='bucket',
+ Key='key',
+ SSECustomerAlgorithm='AES256',
+ SSECustomerKey='foo',
+ )
+
+ def test_get_object_stream_is_retried_and_succeeds(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.side_effect = [
+ # First request fails.
+ socket.error("fake error"),
+ # Second succeeds.
+ {'Body': BytesIO(b'foobar')},
+ ]
+ transfer.download_file('bucket', 'key', '/tmp/smallfile')
+
+ self.assertEqual(self.client.get_object.call_count, 2)
+
+ def test_get_object_stream_uses_all_retries_and_errors_out(self):
+ below_threshold = 20
+ osutil = InMemoryOSLayer({})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ # Here we're raising an exception every single time, which
+ # will exhaust our retry count and propagate a
+ # RetriesExceededError.
+ self.client.get_object.side_effect = socket.error("fake error")
+ with self.assertRaises(RetriesExceededError):
+ transfer.download_file('bucket', 'key', 'smallfile')
+
+ self.assertEqual(self.client.get_object.call_count, 5)
+        # We should have also cleaned up the in-progress file
+        # we were downloading to.
+ self.assertEqual(osutil.filemap, {})
+
+ def test_download_below_multipart_threshold(self):
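+        # Below the multipart threshold a single plain GetObject
+        # (no Range) should be used.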
+ below_threshold = 20
+ osutil = InMemoryOSLayer({'smallfile': b'hello world'})
+ transfer = S3Transfer(self.client, osutil=osutil)
+ self.client.head_object.return_value = {
+ 'ContentLength': below_threshold
+ }
+ self.client.get_object.return_value = {'Body': BytesIO(b'foobar')}
+ transfer.download_file('bucket', 'key', 'smallfile')
+
+ self.client.get_object.assert_called_with(Bucket='bucket', Key='key')
+
+ def test_can_create_with_just_client(self):
+ transfer = S3Transfer(client=mock.Mock())
+ self.assertIsInstance(transfer, S3Transfer)
+
+
+class TestShutdownQueue(unittest.TestCase):
+ def test_handles_normal_put_get_requests(self):
+ q = ShutdownQueue()
+ q.put('foo')
+ self.assertEqual(q.get(), 'foo')
+
+ def test_put_raises_error_on_shutdown(self):
+ q = ShutdownQueue()
+ q.trigger_shutdown()
+ with self.assertRaises(QueueShutdownError):
+ q.put('foo')
+
+
+class TestRandomFileExtension(unittest.TestCase):
+ def test_has_proper_length(self):
+ self.assertEqual(len(random_file_extension(num_digits=4)), 4)
+
+
+class TestCallbackHandlers(unittest.TestCase):
+ def setUp(self):
+ self.request = mock.Mock()
+
+ def test_disable_request_on_put_object(self):
+ disable_upload_callbacks(self.request, 'PutObject')
+ self.request.body.disable_callback.assert_called_with()
+
+ def test_disable_request_on_upload_part(self):
+ disable_upload_callbacks(self.request, 'UploadPart')
+ self.request.body.disable_callback.assert_called_with()
+
+ def test_enable_object_on_put_object(self):
+ enable_upload_callbacks(self.request, 'PutObject')
+ self.request.body.enable_callback.assert_called_with()
+
+ def test_enable_object_on_upload_part(self):
+ enable_upload_callbacks(self.request, 'UploadPart')
+ self.request.body.enable_callback.assert_called_with()
+
+ def test_dont_disable_if_missing_interface(self):
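+        # If the body has no disable_callback method, the handler
+        # should leave it untouched.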
+ del self.request.body.disable_callback
+ disable_upload_callbacks(self.request, 'PutObject')
+ self.assertEqual(self.request.body.method_calls, [])
+
+ def test_dont_enable_if_missing_interface(self):
+ del self.request.body.enable_callback
+ enable_upload_callbacks(self.request, 'PutObject')
+ self.assertEqual(self.request.body.method_calls, [])
+
+ def test_dont_disable_if_wrong_operation(self):
+ disable_upload_callbacks(self.request, 'OtherOperation')
+ self.assertFalse(self.request.body.disable_callback.called)
+
+ def test_dont_enable_if_wrong_operation(self):
+ enable_upload_callbacks(self.request, 'OtherOperation')
+ self.assertFalse(self.request.body.enable_callback.called)
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py
index fd503e212e..a26d3a548c 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_subscribers.py
@@ -1,91 +1,91 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License'). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the 'license' file accompanying this file. This file is
-# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-from s3transfer.exceptions import InvalidSubscriberMethodError
-from s3transfer.subscribers import BaseSubscriber
-from __tests__ import unittest
-
-
-class ExtraMethodsSubscriber(BaseSubscriber):
- def extra_method(self):
- return 'called extra method'
-
-
-class NotCallableSubscriber(BaseSubscriber):
- on_done = 'foo'
-
-
-class NoKwargsSubscriber(BaseSubscriber):
- def on_done(self):
- pass
-
-
-class OverrideMethodSubscriber(BaseSubscriber):
- def on_queued(self, **kwargs):
- return kwargs
-
-
-class OverrideConstructorSubscriber(BaseSubscriber):
- def __init__(self, arg1, arg2):
- self.arg1 = arg1
- self.arg2 = arg2
-
-
-class TestSubscribers(unittest.TestCase):
- def test_can_instantiate_base_subscriber(self):
- try:
- BaseSubscriber()
- except InvalidSubscriberMethodError:
- self.fail('BaseSubscriber should be instantiable')
-
- def test_can_call_base_subscriber_method(self):
- subscriber = BaseSubscriber()
- try:
- subscriber.on_done(future=None)
- except Exception as e:
- self.fail(
- 'Should be able to call base class subscriber method. '
- 'instead got: %s' % e
- )
-
- def test_subclass_can_have_and_call_additional_methods(self):
- subscriber = ExtraMethodsSubscriber()
- self.assertEqual(subscriber.extra_method(), 'called extra method')
-
- def test_can_subclass_and_override_method_from_base_subscriber(self):
- subscriber = OverrideMethodSubscriber()
- # Make sure that the overridden method is called
- self.assertEqual(subscriber.on_queued(foo='bar'), {'foo': 'bar'})
-
- def test_can_subclass_and_override_constructor_from_base_class(self):
- subscriber = OverrideConstructorSubscriber('foo', arg2='bar')
- # Make sure you can create a custom constructor.
- self.assertEqual(subscriber.arg1, 'foo')
- self.assertEqual(subscriber.arg2, 'bar')
-
- def test_invalid_arguments_in_constructor_of_subclass_subscriber(self):
- # The override constructor should still have validation of
- # constructor args.
- with self.assertRaises(TypeError):
- OverrideConstructorSubscriber()
-
- def test_not_callable_in_subclass_subscriber_method(self):
- with self.assertRaisesRegex(
- InvalidSubscriberMethodError, 'must be callable'
- ):
- NotCallableSubscriber()
-
- def test_no_kwargs_in_subclass_subscriber_method(self):
- with self.assertRaisesRegex(
- InvalidSubscriberMethodError, 'must accept keyword'
- ):
- NoKwargsSubscriber()
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from s3transfer.exceptions import InvalidSubscriberMethodError
+from s3transfer.subscribers import BaseSubscriber
+from __tests__ import unittest
+
+
+class ExtraMethodsSubscriber(BaseSubscriber):
+ def extra_method(self):
+ return 'called extra method'
+
+
+class NotCallableSubscriber(BaseSubscriber):
+ on_done = 'foo'
+
+
+class NoKwargsSubscriber(BaseSubscriber):
+ def on_done(self):
+ pass
+
+
+class OverrideMethodSubscriber(BaseSubscriber):
+ def on_queued(self, **kwargs):
+ return kwargs
+
+
+class OverrideConstructorSubscriber(BaseSubscriber):
+ def __init__(self, arg1, arg2):
+ self.arg1 = arg1
+ self.arg2 = arg2
+
+
+class TestSubscribers(unittest.TestCase):
+ def test_can_instantiate_base_subscriber(self):
+ try:
+ BaseSubscriber()
+ except InvalidSubscriberMethodError:
+ self.fail('BaseSubscriber should be instantiable')
+
+ def test_can_call_base_subscriber_method(self):
+ subscriber = BaseSubscriber()
+ try:
+ subscriber.on_done(future=None)
+ except Exception as e:
+ self.fail(
+ 'Should be able to call base class subscriber method. '
+ 'instead got: %s' % e
+ )
+
+ def test_subclass_can_have_and_call_additional_methods(self):
+ subscriber = ExtraMethodsSubscriber()
+ self.assertEqual(subscriber.extra_method(), 'called extra method')
+
+ def test_can_subclass_and_override_method_from_base_subscriber(self):
+ subscriber = OverrideMethodSubscriber()
+ # Make sure that the overridden method is called
+ self.assertEqual(subscriber.on_queued(foo='bar'), {'foo': 'bar'})
+
+ def test_can_subclass_and_override_constructor_from_base_class(self):
+ subscriber = OverrideConstructorSubscriber('foo', arg2='bar')
+ # Make sure you can create a custom constructor.
+ self.assertEqual(subscriber.arg1, 'foo')
+ self.assertEqual(subscriber.arg2, 'bar')
+
+ def test_invalid_arguments_in_constructor_of_subclass_subscriber(self):
+ # The override constructor should still have validation of
+ # constructor args.
+ with self.assertRaises(TypeError):
+ OverrideConstructorSubscriber()
+
+ def test_not_callable_in_subclass_subscriber_method(self):
+ with self.assertRaisesRegex(
+ InvalidSubscriberMethodError, 'must be callable'
+ ):
+ NotCallableSubscriber()
+
+ def test_no_kwargs_in_subclass_subscriber_method(self):
+ with self.assertRaisesRegex(
+ InvalidSubscriberMethodError, 'must accept keyword'
+ ):
+ NoKwargsSubscriber()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_tasks.py b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py
index 8806598fd7..4f0bc4d1cc 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_tasks.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_tasks.py
@@ -1,833 +1,833 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-from concurrent import futures
-from functools import partial
-from threading import Event
-
-from s3transfer.futures import BoundedExecutor, TransferCoordinator
-from s3transfer.subscribers import BaseSubscriber
-from s3transfer.tasks import (
- CompleteMultipartUploadTask,
- CreateMultipartUploadTask,
- SubmissionTask,
- Task,
-)
-from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks
-from __tests__ import (
- BaseSubmissionTaskTest,
- BaseTaskTest,
- RecordingSubscriber,
- unittest,
-)
-
-
-class TaskFailureException(Exception):
- pass
-
-
-class SuccessTask(Task):
- def _main(
- self, return_value='success', callbacks=None, failure_cleanups=None
- ):
- if callbacks:
- for callback in callbacks:
- callback()
- if failure_cleanups:
- for failure_cleanup in failure_cleanups:
- self._transfer_coordinator.add_failure_cleanup(failure_cleanup)
- return return_value
-
-
-class FailureTask(Task):
- def _main(self, exception=TaskFailureException):
- raise exception()
-
-
-class ReturnKwargsTask(Task):
- def _main(self, **kwargs):
- return kwargs
-
-
-class SubmitMoreTasksTask(Task):
- def _main(self, executor, tasks_to_submit):
- for task_to_submit in tasks_to_submit:
- self._transfer_coordinator.submit(executor, task_to_submit)
-
-
-class NOOPSubmissionTask(SubmissionTask):
- def _submit(self, transfer_future, **kwargs):
- pass
-
-
-class ExceptionSubmissionTask(SubmissionTask):
- def _submit(
- self,
- transfer_future,
- executor=None,
- tasks_to_submit=None,
- additional_callbacks=None,
- exception=TaskFailureException,
- ):
- if executor and tasks_to_submit:
- for task_to_submit in tasks_to_submit:
- self._transfer_coordinator.submit(executor, task_to_submit)
- if additional_callbacks:
- for callback in additional_callbacks:
- callback()
- raise exception()
-
-
-class StatusRecordingTransferCoordinator(TransferCoordinator):
- def __init__(self, transfer_id=None):
- super().__init__(transfer_id)
- self.status_changes = [self._status]
-
- def set_status_to_queued(self):
- super().set_status_to_queued()
- self._record_status_change()
-
- def set_status_to_running(self):
- super().set_status_to_running()
- self._record_status_change()
-
- def _record_status_change(self):
- self.status_changes.append(self._status)
-
-
-class RecordingStateSubscriber(BaseSubscriber):
- def __init__(self, transfer_coordinator):
- self._transfer_coordinator = transfer_coordinator
- self.status_during_on_queued = None
-
- def on_queued(self, **kwargs):
- self.status_during_on_queued = self._transfer_coordinator.status
-
-
-class TestSubmissionTask(BaseSubmissionTaskTest):
- def setUp(self):
- super().setUp()
- self.executor = BoundedExecutor(1000, 5)
- self.call_args = CallArgs(subscribers=[])
- self.transfer_future = self.get_transfer_future(self.call_args)
- self.main_kwargs = {'transfer_future': self.transfer_future}
-
- def test_transitions_from_not_started_to_queued_to_running(self):
- self.transfer_coordinator = StatusRecordingTransferCoordinator()
- submission_task = self.get_task(
- NOOPSubmissionTask, main_kwargs=self.main_kwargs
- )
-        # Status should stay 'not-started' until the submission task runs.
- self.assertEqual(self.transfer_coordinator.status, 'not-started')
-
- submission_task()
-        # Once the submission task has run, the status should be running.
- self.assertEqual(self.transfer_coordinator.status, 'running')
-
- # Ensure the transitions were as expected as well.
- self.assertEqual(
- self.transfer_coordinator.status_changes,
- ['not-started', 'queued', 'running'],
- )
-
- def test_on_queued_callbacks(self):
- submission_task = self.get_task(
- NOOPSubmissionTask, main_kwargs=self.main_kwargs
- )
-
- subscriber = RecordingSubscriber()
- self.call_args.subscribers.append(subscriber)
- submission_task()
- # Make sure the on_queued callback of the subscriber is called.
- self.assertEqual(
- subscriber.on_queued_calls, [{'future': self.transfer_future}]
- )
-
- def test_on_queued_status_in_callbacks(self):
- submission_task = self.get_task(
- NOOPSubmissionTask, main_kwargs=self.main_kwargs
- )
-
- subscriber = RecordingStateSubscriber(self.transfer_coordinator)
- self.call_args.subscribers.append(subscriber)
- submission_task()
- # Make sure the status was queued during on_queued callback.
- self.assertEqual(subscriber.status_during_on_queued, 'queued')
-
- def test_sets_exception_from_submit(self):
- submission_task = self.get_task(
- ExceptionSubmissionTask, main_kwargs=self.main_kwargs
- )
- submission_task()
-
- # Make sure the status of the future is failed
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- # Make sure the future propagates the exception encountered in the
- # submission task.
- with self.assertRaises(TaskFailureException):
- self.transfer_future.result()
-
- def test_catches_and_sets_keyboard_interrupt_exception_from_submit(self):
- self.main_kwargs['exception'] = KeyboardInterrupt
- submission_task = self.get_task(
- ExceptionSubmissionTask, main_kwargs=self.main_kwargs
- )
- submission_task()
-
- self.assertEqual(self.transfer_coordinator.status, 'failed')
- with self.assertRaises(KeyboardInterrupt):
- self.transfer_future.result()
-
- def test_calls_done_callbacks_on_exception(self):
- submission_task = self.get_task(
- ExceptionSubmissionTask, main_kwargs=self.main_kwargs
- )
-
- subscriber = RecordingSubscriber()
- self.call_args.subscribers.append(subscriber)
-
- # Add the done callback to the callbacks to be invoked when the
- # transfer is done.
- done_callbacks = get_callbacks(self.transfer_future, 'done')
- for done_callback in done_callbacks:
- self.transfer_coordinator.add_done_callback(done_callback)
- submission_task()
-
- # Make sure the task failed to start
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- # Make sure the on_done callback of the subscriber is called.
- self.assertEqual(
- subscriber.on_done_calls, [{'future': self.transfer_future}]
- )
-
- def test_calls_failure_cleanups_on_exception(self):
- submission_task = self.get_task(
- ExceptionSubmissionTask, main_kwargs=self.main_kwargs
- )
-
- # Add the callback to the callbacks to be invoked when the
- # transfer fails.
- invocations_of_cleanup = []
- cleanup_callback = FunctionContainer(
- invocations_of_cleanup.append, 'cleanup happened'
- )
- self.transfer_coordinator.add_failure_cleanup(cleanup_callback)
- submission_task()
-
- # Make sure the task failed to start
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- # Make sure the cleanup was called.
- self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
-
- def test_cleanups_only_ran_once_on_exception(self):
- # We want to be able to handle the case where the final task completes
-        # and announces done but there is an error in the submission task
-        # which will cause it to need to announce done as well. In this case,
-        # we do not want the done callbacks to be invoked more than once.
-
- final_task = self.get_task(FailureTask, is_final=True)
- self.main_kwargs['executor'] = self.executor
- self.main_kwargs['tasks_to_submit'] = [final_task]
-
- submission_task = self.get_task(
- ExceptionSubmissionTask, main_kwargs=self.main_kwargs
- )
-
- subscriber = RecordingSubscriber()
- self.call_args.subscribers.append(subscriber)
-
- # Add the done callback to the callbacks to be invoked when the
- # transfer is done.
- done_callbacks = get_callbacks(self.transfer_future, 'done')
- for done_callback in done_callbacks:
- self.transfer_coordinator.add_done_callback(done_callback)
-
- submission_task()
-
- # Make sure the task failed to start
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- # Make sure the on_done callback of the subscriber is called only once.
- self.assertEqual(
- subscriber.on_done_calls, [{'future': self.transfer_future}]
- )
-
- def test_done_callbacks_only_ran_once_on_exception(self):
- # We want to be able to handle the case where the final task completes
-        # and announces done but there is an error in the submission task
- # which will cause it to need to announce done as well. In this case,
- # we do not want the failure cleanups to be invoked more than once.
-
- final_task = self.get_task(FailureTask, is_final=True)
- self.main_kwargs['executor'] = self.executor
- self.main_kwargs['tasks_to_submit'] = [final_task]
-
- submission_task = self.get_task(
- ExceptionSubmissionTask, main_kwargs=self.main_kwargs
- )
-
- # Add the callback to the callbacks to be invoked when the
- # transfer fails.
- invocations_of_cleanup = []
- cleanup_callback = FunctionContainer(
- invocations_of_cleanup.append, 'cleanup happened'
- )
- self.transfer_coordinator.add_failure_cleanup(cleanup_callback)
- submission_task()
-
- # Make sure the task failed to start
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- # Make sure the cleanup was called only once.
- self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
-
- def test_handles_cleanups_submitted_in_other_tasks(self):
- invocations_of_cleanup = []
- event = Event()
- cleanup_callback = FunctionContainer(
- invocations_of_cleanup.append, 'cleanup happened'
- )
- # We want the cleanup to be added in the execution of the task and
- # still be executed by the submission task when it fails.
- task = self.get_task(
- SuccessTask,
- main_kwargs={
- 'callbacks': [event.set],
- 'failure_cleanups': [cleanup_callback],
- },
- )
-
- self.main_kwargs['executor'] = self.executor
- self.main_kwargs['tasks_to_submit'] = [task]
- self.main_kwargs['additional_callbacks'] = [event.wait]
-
- submission_task = self.get_task(
- ExceptionSubmissionTask, main_kwargs=self.main_kwargs
- )
-
- submission_task()
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- # Make sure the cleanup was called even though the callback got
- # added in a completely different task.
- self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
-
- def test_waits_for_tasks_submitted_by_other_tasks_on_exception(self):
- # In this test, we want to make sure that any tasks that may be
- # submitted in another task complete before we start performing
- # cleanups.
- #
- # This is tested by doing the following:
- #
-        # ExceptionSubmissionTask
- # |
- # +--submits-->SubmitMoreTasksTask
- # |
- # +--submits-->SuccessTask
- # |
- # +-->sleeps-->adds failure cleanup
- #
-        # In the end, the failure cleanup of the SuccessTask should be run
-        # when the ExceptionSubmissionTask fails. If the
-        # ExceptionSubmissionTask did not run the failure cleanup it is most
-        # likely that it did not wait for the SuccessTask to complete, which
-        # it needs to because the ExceptionSubmissionTask does not know
- # what failure cleanups it needs to run until all spawned tasks have
- # completed.
- invocations_of_cleanup = []
- event = Event()
- cleanup_callback = FunctionContainer(
- invocations_of_cleanup.append, 'cleanup happened'
- )
-
- cleanup_task = self.get_task(
- SuccessTask,
- main_kwargs={
- 'callbacks': [event.set],
- 'failure_cleanups': [cleanup_callback],
- },
- )
- task_for_submitting_cleanup_task = self.get_task(
- SubmitMoreTasksTask,
- main_kwargs={
- 'executor': self.executor,
- 'tasks_to_submit': [cleanup_task],
- },
- )
-
- self.main_kwargs['executor'] = self.executor
- self.main_kwargs['tasks_to_submit'] = [
- task_for_submitting_cleanup_task
- ]
- self.main_kwargs['additional_callbacks'] = [event.wait]
-
- submission_task = self.get_task(
- ExceptionSubmissionTask, main_kwargs=self.main_kwargs
- )
-
- submission_task()
- self.assertEqual(self.transfer_coordinator.status, 'failed')
- self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
-
- def test_submission_task_announces_done_if_cancelled_before_main(self):
- invocations_of_done = []
- done_callback = FunctionContainer(
- invocations_of_done.append, 'done announced'
- )
- self.transfer_coordinator.add_done_callback(done_callback)
-
- self.transfer_coordinator.cancel()
- submission_task = self.get_task(
- NOOPSubmissionTask, main_kwargs=self.main_kwargs
- )
- submission_task()
-
-        # Because the submission task was cancelled before being run,
-        # it did not submit any extra tasks, so as a result it is
-        # responsible for making sure it announces done as nothing else will.
- self.assertEqual(invocations_of_done, ['done announced'])
-
-
-class TestTask(unittest.TestCase):
- def setUp(self):
- self.transfer_id = 1
- self.transfer_coordinator = TransferCoordinator(
- transfer_id=self.transfer_id
- )
-
- def test_repr(self):
- main_kwargs = {'bucket': 'mybucket', 'param_to_not_include': 'foo'}
- task = ReturnKwargsTask(
- self.transfer_coordinator, main_kwargs=main_kwargs
- )
- # The repr should not include the other parameter because it is not
- # a desired parameter to include.
- self.assertEqual(
- repr(task),
- 'ReturnKwargsTask(transfer_id={}, {})'.format(
- self.transfer_id, {'bucket': 'mybucket'}
- ),
- )
-
- def test_transfer_id(self):
- task = SuccessTask(self.transfer_coordinator)
-        # Make sure that the task's transfer_id matches the id associated
-        # with the transfer coordinator.
- self.assertEqual(task.transfer_id, self.transfer_id)
-
- def test_context_status_transitioning_success(self):
- # The status should be set to running.
- self.transfer_coordinator.set_status_to_running()
- self.assertEqual(self.transfer_coordinator.status, 'running')
-
- # If a task is called, the status still should be running.
- SuccessTask(self.transfer_coordinator)()
- self.assertEqual(self.transfer_coordinator.status, 'running')
-
- # Once the final task is called, the status should be set to success.
- SuccessTask(self.transfer_coordinator, is_final=True)()
- self.assertEqual(self.transfer_coordinator.status, 'success')
-
- def test_context_status_transitioning_failed(self):
- self.transfer_coordinator.set_status_to_running()
-
- SuccessTask(self.transfer_coordinator)()
- self.assertEqual(self.transfer_coordinator.status, 'running')
-
- # A failure task should result in the failed status
- FailureTask(self.transfer_coordinator)()
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- # Even if the final task comes in and succeeds, it should stay failed.
- SuccessTask(self.transfer_coordinator, is_final=True)()
- self.assertEqual(self.transfer_coordinator.status, 'failed')
-
- def test_result_setting_for_success(self):
- override_return = 'foo'
- SuccessTask(self.transfer_coordinator)()
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={'return_value': override_return},
- is_final=True,
- )()
-
- # The return value for the transfer future should be of the final
- # task.
- self.assertEqual(self.transfer_coordinator.result(), override_return)
-
- def test_result_setting_for_error(self):
- FailureTask(self.transfer_coordinator)()
-
- # If another failure comes in, the result should still throw the
- # original exception when result() is eventually called.
- FailureTask(
- self.transfer_coordinator, main_kwargs={'exception': Exception}
- )()
-
- # Even if a success task comes along, the result of the future
- # should be the original exception
- SuccessTask(self.transfer_coordinator, is_final=True)()
- with self.assertRaises(TaskFailureException):
- self.transfer_coordinator.result()
-
- def test_done_callbacks_success(self):
- callback_results = []
- SuccessTask(
- self.transfer_coordinator,
- done_callbacks=[
- partial(callback_results.append, 'first'),
- partial(callback_results.append, 'second'),
- ],
- )()
- # For successful tasks, the done callbacks should get called.
- self.assertEqual(callback_results, ['first', 'second'])
-
- def test_done_callbacks_failure(self):
- callback_results = []
- FailureTask(
- self.transfer_coordinator,
- done_callbacks=[
- partial(callback_results.append, 'first'),
- partial(callback_results.append, 'second'),
- ],
- )()
- # For even failed tasks, the done callbacks should get called.
- self.assertEqual(callback_results, ['first', 'second'])
-
- # Callbacks should continue to be called even after a related failure
- SuccessTask(
- self.transfer_coordinator,
- done_callbacks=[
- partial(callback_results.append, 'third'),
- partial(callback_results.append, 'fourth'),
- ],
- )()
- self.assertEqual(
- callback_results, ['first', 'second', 'third', 'fourth']
- )
-
- def test_failure_cleanups_on_failure(self):
- callback_results = []
- self.transfer_coordinator.add_failure_cleanup(
- callback_results.append, 'first'
- )
- self.transfer_coordinator.add_failure_cleanup(
- callback_results.append, 'second'
- )
- FailureTask(self.transfer_coordinator)()
- # The failure callbacks should have not been called yet because it
- # is not the last task
- self.assertEqual(callback_results, [])
-
- # Now the failure callbacks should get called.
- SuccessTask(self.transfer_coordinator, is_final=True)()
- self.assertEqual(callback_results, ['first', 'second'])
-
- def test_no_failure_cleanups_on_success(self):
- callback_results = []
- self.transfer_coordinator.add_failure_cleanup(
- callback_results.append, 'first'
- )
- self.transfer_coordinator.add_failure_cleanup(
- callback_results.append, 'second'
- )
- SuccessTask(self.transfer_coordinator, is_final=True)()
- # The failure cleanups should not have been called because no task
- # failed for the transfer context.
- self.assertEqual(callback_results, [])
-
- def test_passing_main_kwargs(self):
- main_kwargs = {'foo': 'bar', 'baz': 'biz'}
- ReturnKwargsTask(
- self.transfer_coordinator, main_kwargs=main_kwargs, is_final=True
- )()
- # The kwargs should have been passed to the main()
- self.assertEqual(self.transfer_coordinator.result(), main_kwargs)
-
- def test_passing_pending_kwargs_single_futures(self):
- pending_kwargs = {}
- ref_main_kwargs = {'foo': 'bar', 'baz': 'biz'}
-
- # Pass some tasks to an executor
- with futures.ThreadPoolExecutor(1) as executor:
- pending_kwargs['foo'] = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={'return_value': ref_main_kwargs['foo']},
- )
- )
- pending_kwargs['baz'] = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={'return_value': ref_main_kwargs['baz']},
- )
- )
-
- # Create a task that depends on the tasks passed to the executor
- ReturnKwargsTask(
- self.transfer_coordinator,
- pending_main_kwargs=pending_kwargs,
- is_final=True,
- )()
- # The result should have the pending keyword arg values flushed
- # out.
- self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
-
- def test_passing_pending_kwargs_list_of_futures(self):
- pending_kwargs = {}
- ref_main_kwargs = {'foo': ['first', 'second']}
-
- # Pass some tasks to an executor
- with futures.ThreadPoolExecutor(1) as executor:
- first_future = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={'return_value': ref_main_kwargs['foo'][0]},
- )
- )
- second_future = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={'return_value': ref_main_kwargs['foo'][1]},
- )
- )
- # Make the pending keyword arg value a list
- pending_kwargs['foo'] = [first_future, second_future]
-
- # Create a task that depends on the tasks passed to the executor
- ReturnKwargsTask(
- self.transfer_coordinator,
- pending_main_kwargs=pending_kwargs,
- is_final=True,
- )()
- # The result should have the pending keyword arg values flushed
- # out in the expected order.
- self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
-
- def test_passing_pending_and_non_pending_kwargs(self):
- main_kwargs = {'nonpending_value': 'foo'}
- pending_kwargs = {}
- ref_main_kwargs = {
- 'nonpending_value': 'foo',
- 'pending_value': 'bar',
- 'pending_list': ['first', 'second'],
- }
-
- # Create the pending tasks
- with futures.ThreadPoolExecutor(1) as executor:
- pending_kwargs['pending_value'] = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={
- 'return_value': ref_main_kwargs['pending_value']
- },
- )
- )
-
- first_future = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={
- 'return_value': ref_main_kwargs['pending_list'][0]
- },
- )
- )
- second_future = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={
- 'return_value': ref_main_kwargs['pending_list'][1]
- },
- )
- )
- # Make the pending keyword arg value a list
- pending_kwargs['pending_list'] = [first_future, second_future]
-
- # Create a task that depends on the tasks passed to the executor
- # and just regular nonpending kwargs.
- ReturnKwargsTask(
- self.transfer_coordinator,
- main_kwargs=main_kwargs,
- pending_main_kwargs=pending_kwargs,
- is_final=True,
- )()
- # The result should have all of the kwargs (both pending and
- # nonpending)
- self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
-
- def test_single_failed_pending_future(self):
- pending_kwargs = {}
-
- # Pass some tasks to an executor. Make one successful and the other
- # a failure.
- with futures.ThreadPoolExecutor(1) as executor:
- pending_kwargs['foo'] = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={'return_value': 'bar'},
- )
- )
- pending_kwargs['baz'] = executor.submit(
- FailureTask(self.transfer_coordinator)
- )
-
- # Create a task that depends on the tasks passed to the executor
- ReturnKwargsTask(
- self.transfer_coordinator,
- pending_main_kwargs=pending_kwargs,
- is_final=True,
- )()
- # The end result should raise the exception from the initial
- # pending future value
- with self.assertRaises(TaskFailureException):
- self.transfer_coordinator.result()
-
- def test_single_failed_pending_future_in_list(self):
- pending_kwargs = {}
-
- # Pass some tasks to an executor. Make one successful and the other
- # a failure.
- with futures.ThreadPoolExecutor(1) as executor:
- first_future = executor.submit(
- SuccessTask(
- self.transfer_coordinator,
- main_kwargs={'return_value': 'bar'},
- )
- )
- second_future = executor.submit(
- FailureTask(self.transfer_coordinator)
- )
-
- pending_kwargs['pending_list'] = [first_future, second_future]
-
- # Create a task that depends on the tasks passed to the executor
- ReturnKwargsTask(
- self.transfer_coordinator,
- pending_main_kwargs=pending_kwargs,
- is_final=True,
- )()
- # The end result should raise the exception from the initial
- # pending future value in the list
- with self.assertRaises(TaskFailureException):
- self.transfer_coordinator.result()
-
-
-class BaseMultipartTaskTest(BaseTaskTest):
- def setUp(self):
- super().setUp()
- self.bucket = 'mybucket'
- self.key = 'foo'
-
-
-class TestCreateMultipartUploadTask(BaseMultipartTaskTest):
- def test_main(self):
- upload_id = 'foo'
- extra_args = {'Metadata': {'foo': 'bar'}}
- response = {'UploadId': upload_id}
- task = self.get_task(
- CreateMultipartUploadTask,
- main_kwargs={
- 'client': self.client,
- 'bucket': self.bucket,
- 'key': self.key,
- 'extra_args': extra_args,
- },
- )
- self.stubber.add_response(
- method='create_multipart_upload',
- service_response=response,
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'Metadata': {'foo': 'bar'},
- },
- )
- result_id = task()
- self.stubber.assert_no_pending_responses()
- # Ensure the upload id returned is correct
- self.assertEqual(upload_id, result_id)
-
- # Make sure that the abort was added as a cleanup failure
- self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 1)
-
- # Make sure if it is called, it will abort correctly
- self.stubber.add_response(
- method='abort_multipart_upload',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': upload_id,
- },
- )
- self.transfer_coordinator.failure_cleanups[0]()
- self.stubber.assert_no_pending_responses()
-
-
-class TestCompleteMultipartUploadTask(BaseMultipartTaskTest):
- def test_main(self):
- upload_id = 'my-id'
- parts = [{'ETag': 'etag', 'PartNumber': 0}]
- task = self.get_task(
- CompleteMultipartUploadTask,
- main_kwargs={
- 'client': self.client,
- 'bucket': self.bucket,
- 'key': self.key,
- 'upload_id': upload_id,
- 'parts': parts,
- 'extra_args': {},
- },
- )
- self.stubber.add_response(
- method='complete_multipart_upload',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': upload_id,
- 'MultipartUpload': {'Parts': parts},
- },
- )
- task()
- self.stubber.assert_no_pending_responses()
-
- def test_includes_extra_args(self):
- upload_id = 'my-id'
- parts = [{'ETag': 'etag', 'PartNumber': 0}]
- task = self.get_task(
- CompleteMultipartUploadTask,
- main_kwargs={
- 'client': self.client,
- 'bucket': self.bucket,
- 'key': self.key,
- 'upload_id': upload_id,
- 'parts': parts,
- 'extra_args': {'RequestPayer': 'requester'},
- },
- )
- self.stubber.add_response(
- method='complete_multipart_upload',
- service_response={},
- expected_params={
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': upload_id,
- 'MultipartUpload': {'Parts': parts},
- 'RequestPayer': 'requester',
- },
- )
- task()
- self.stubber.assert_no_pending_responses()
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from concurrent import futures
+from functools import partial
+from threading import Event
+
+from s3transfer.futures import BoundedExecutor, TransferCoordinator
+from s3transfer.subscribers import BaseSubscriber
+from s3transfer.tasks import (
+ CompleteMultipartUploadTask,
+ CreateMultipartUploadTask,
+ SubmissionTask,
+ Task,
+)
+from s3transfer.utils import CallArgs, FunctionContainer, get_callbacks
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ RecordingSubscriber,
+ unittest,
+)
+
+
+class TaskFailureException(Exception):
+ pass
+
+
+class SuccessTask(Task):
+ def _main(
+ self, return_value='success', callbacks=None, failure_cleanups=None
+ ):
+ if callbacks:
+ for callback in callbacks:
+ callback()
+ if failure_cleanups:
+ for failure_cleanup in failure_cleanups:
+ self._transfer_coordinator.add_failure_cleanup(failure_cleanup)
+ return return_value
+
+
+class FailureTask(Task):
+ def _main(self, exception=TaskFailureException):
+ raise exception()
+
+
+class ReturnKwargsTask(Task):
+ def _main(self, **kwargs):
+ return kwargs
+
+
+class SubmitMoreTasksTask(Task):
+ def _main(self, executor, tasks_to_submit):
+ for task_to_submit in tasks_to_submit:
+ self._transfer_coordinator.submit(executor, task_to_submit)
+
+
+class NOOPSubmissionTask(SubmissionTask):
+ def _submit(self, transfer_future, **kwargs):
+ pass
+
+
+class ExceptionSubmissionTask(SubmissionTask):
+ def _submit(
+ self,
+ transfer_future,
+ executor=None,
+ tasks_to_submit=None,
+ additional_callbacks=None,
+ exception=TaskFailureException,
+ ):
+ if executor and tasks_to_submit:
+ for task_to_submit in tasks_to_submit:
+ self._transfer_coordinator.submit(executor, task_to_submit)
+ if additional_callbacks:
+ for callback in additional_callbacks:
+ callback()
+ raise exception()
+
+
+class StatusRecordingTransferCoordinator(TransferCoordinator):
+ def __init__(self, transfer_id=None):
+ super().__init__(transfer_id)
+ self.status_changes = [self._status]
+
+ def set_status_to_queued(self):
+ super().set_status_to_queued()
+ self._record_status_change()
+
+ def set_status_to_running(self):
+ super().set_status_to_running()
+ self._record_status_change()
+
+ def _record_status_change(self):
+ self.status_changes.append(self._status)
+
+
+class RecordingStateSubscriber(BaseSubscriber):
+ def __init__(self, transfer_coordinator):
+ self._transfer_coordinator = transfer_coordinator
+ self.status_during_on_queued = None
+
+ def on_queued(self, **kwargs):
+ self.status_during_on_queued = self._transfer_coordinator.status
+
+
+class TestSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.executor = BoundedExecutor(1000, 5)
+ self.call_args = CallArgs(subscribers=[])
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.main_kwargs = {'transfer_future': self.transfer_future}
+
+ def test_transitions_from_not_started_to_queued_to_running(self):
+ self.transfer_coordinator = StatusRecordingTransferCoordinator()
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+        # Status should stay 'not-started' until the submission task runs.
+ self.assertEqual(self.transfer_coordinator.status, 'not-started')
+
+ submission_task()
+        # Once the submission task has run, the status should be running.
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # Ensure the transitions were as expected as well.
+ self.assertEqual(
+ self.transfer_coordinator.status_changes,
+ ['not-started', 'queued', 'running'],
+ )
+
+ def test_on_queued_callbacks(self):
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+ submission_task()
+ # Make sure the on_queued callback of the subscriber is called.
+ self.assertEqual(
+ subscriber.on_queued_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_on_queued_status_in_callbacks(self):
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingStateSubscriber(self.transfer_coordinator)
+ self.call_args.subscribers.append(subscriber)
+ submission_task()
+ # Make sure the status was queued during on_queued callback.
+ self.assertEqual(subscriber.status_during_on_queued, 'queued')
+
+ def test_sets_exception_from_submit(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+ # Make sure the status of the future is failed
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the future propagates the exception encountered in the
+ # submission task.
+ with self.assertRaises(TaskFailureException):
+ self.transfer_future.result()
+
+ def test_catches_and_sets_keyboard_interrupt_exception_from_submit(self):
+ self.main_kwargs['exception'] = KeyboardInterrupt
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ with self.assertRaises(KeyboardInterrupt):
+ self.transfer_future.result()
+
+ def test_calls_done_callbacks_on_exception(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+
+ # Add the done callback to the callbacks to be invoked when the
+ # transfer is done.
+ done_callbacks = get_callbacks(self.transfer_future, 'done')
+ for done_callback in done_callbacks:
+ self.transfer_coordinator.add_done_callback(done_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the on_done callback of the subscriber is called.
+ self.assertEqual(
+ subscriber.on_done_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_calls_failure_cleanups_on_exception(self):
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ # Add the callback to the callbacks to be invoked when the
+ # transfer fails.
+ invocations_of_cleanup = []
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+ self.transfer_coordinator.add_failure_cleanup(cleanup_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_cleanups_only_ran_once_on_exception(self):
+ # We want to be able to handle the case where the final task completes
+        # and announces done but there is an error in the submission task
+        # which will cause it to need to announce done as well. In this case,
+        # we do not want the done callbacks to be invoked more than once.
+
+ final_task = self.get_task(FailureTask, is_final=True)
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [final_task]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ subscriber = RecordingSubscriber()
+ self.call_args.subscribers.append(subscriber)
+
+ # Add the done callback to the callbacks to be invoked when the
+ # transfer is done.
+ done_callbacks = get_callbacks(self.transfer_future, 'done')
+ for done_callback in done_callbacks:
+ self.transfer_coordinator.add_done_callback(done_callback)
+
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the on_done callback of the subscriber is called only once.
+ self.assertEqual(
+ subscriber.on_done_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_done_callbacks_only_ran_once_on_exception(self):
+ # We want to be able to handle the case where the final task completes
+        # and announces done but there is an error in the submission task
+ # which will cause it to need to announce done as well. In this case,
+ # we do not want the failure cleanups to be invoked more than once.
+
+ final_task = self.get_task(FailureTask, is_final=True)
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [final_task]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ # Add the callback to the callbacks to be invoked when the
+ # transfer fails.
+ invocations_of_cleanup = []
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+ self.transfer_coordinator.add_failure_cleanup(cleanup_callback)
+ submission_task()
+
+ # Make sure the task failed to start
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called only once.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_handles_cleanups_submitted_in_other_tasks(self):
+ invocations_of_cleanup = []
+ event = Event()
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+ # We want the cleanup to be added in the execution of the task and
+ # still be executed by the submission task when it fails.
+ task = self.get_task(
+ SuccessTask,
+ main_kwargs={
+ 'callbacks': [event.set],
+ 'failure_cleanups': [cleanup_callback],
+ },
+ )
+
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [task]
+ self.main_kwargs['additional_callbacks'] = [event.wait]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ submission_task()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Make sure the cleanup was called even though the callback got
+ # added in a completely different task.
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_waits_for_tasks_submitted_by_other_tasks_on_exception(self):
+ # In this test, we want to make sure that any tasks that may be
+ # submitted in another task complete before we start performing
+ # cleanups.
+ #
+ # This is tested by doing the following:
+ #
+ # ExceptionSubmissionTask
+ # |
+ # +--submits-->SubmitMoreTasksTask
+ # |
+ # +--submits-->SuccessTask
+ # |
+ # +-->sleeps-->adds failure cleanup
+ #
+ # In the end, the failure cleanup of the SuccessTask should be run
+ # when the ExceptionSubmissionTask fails. If the
+ # ExceptionSubmissionTask did not run the failure cleanup, it is most
+ # likely that it did not wait for the SuccessTask to complete, which
+ # it needs to because the ExceptionSubmissionTask does not know
+ # what failure cleanups it needs to run until all spawned tasks have
+ # completed.
+ invocations_of_cleanup = []
+ event = Event()
+ cleanup_callback = FunctionContainer(
+ invocations_of_cleanup.append, 'cleanup happened'
+ )
+
+ cleanup_task = self.get_task(
+ SuccessTask,
+ main_kwargs={
+ 'callbacks': [event.set],
+ 'failure_cleanups': [cleanup_callback],
+ },
+ )
+ task_for_submitting_cleanup_task = self.get_task(
+ SubmitMoreTasksTask,
+ main_kwargs={
+ 'executor': self.executor,
+ 'tasks_to_submit': [cleanup_task],
+ },
+ )
+
+ self.main_kwargs['executor'] = self.executor
+ self.main_kwargs['tasks_to_submit'] = [
+ task_for_submitting_cleanup_task
+ ]
+ self.main_kwargs['additional_callbacks'] = [event.wait]
+
+ submission_task = self.get_task(
+ ExceptionSubmissionTask, main_kwargs=self.main_kwargs
+ )
+
+ submission_task()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+ self.assertEqual(invocations_of_cleanup, ['cleanup happened'])
+
+ def test_submission_task_announces_done_if_cancelled_before_main(self):
+ invocations_of_done = []
+ done_callback = FunctionContainer(
+ invocations_of_done.append, 'done announced'
+ )
+ self.transfer_coordinator.add_done_callback(done_callback)
+
+ self.transfer_coordinator.cancel()
+ submission_task = self.get_task(
+ NOOPSubmissionTask, main_kwargs=self.main_kwargs
+ )
+ submission_task()
+
+ # Because the submission task was cancelled before being run,
+ # it did not submit any extra tasks. As a result, it is responsible
+ # for making sure it announces done, as nothing else will.
+ self.assertEqual(invocations_of_done, ['done announced'])
+
+
+class TestTask(unittest.TestCase):
+ def setUp(self):
+ self.transfer_id = 1
+ self.transfer_coordinator = TransferCoordinator(
+ transfer_id=self.transfer_id
+ )
+
+ def test_repr(self):
+ main_kwargs = {'bucket': 'mybucket', 'param_to_not_include': 'foo'}
+ task = ReturnKwargsTask(
+ self.transfer_coordinator, main_kwargs=main_kwargs
+ )
+ # The repr should not include the other parameter because it is not
+ # a desired parameter to include.
+ self.assertEqual(
+ repr(task),
+ 'ReturnKwargsTask(transfer_id={}, {})'.format(
+ self.transfer_id, {'bucket': 'mybucket'}
+ ),
+ )
+
+ def test_transfer_id(self):
+ task = SuccessTask(self.transfer_coordinator)
+ # Make sure that the task's transfer id matches the id associated
+ # with the transfer coordinator.
+ self.assertEqual(task.transfer_id, self.transfer_id)
+
+ def test_context_status_transitioning_success(self):
+ # The status should be set to running.
+ self.transfer_coordinator.set_status_to_running()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # If a task is called, the status still should be running.
+ SuccessTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # Once the final task is called, the status should be set to success.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(self.transfer_coordinator.status, 'success')
+
+ def test_context_status_transitioning_failed(self):
+ self.transfer_coordinator.set_status_to_running()
+
+ SuccessTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'running')
+
+ # A failure task should result in the failed status
+ FailureTask(self.transfer_coordinator)()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ # Even if the final task comes in and succeeds, it should stay failed.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(self.transfer_coordinator.status, 'failed')
+
+ def test_result_setting_for_success(self):
+ override_return = 'foo'
+ SuccessTask(self.transfer_coordinator)()
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': override_return},
+ is_final=True,
+ )()
+
+ # The return value for the transfer future should be that of the
+ # final task.
+ self.assertEqual(self.transfer_coordinator.result(), override_return)
+
+ def test_result_setting_for_error(self):
+ FailureTask(self.transfer_coordinator)()
+
+ # If another failure comes in, the result should still raise the
+ # original exception when result() is eventually called.
+ FailureTask(
+ self.transfer_coordinator, main_kwargs={'exception': Exception}
+ )()
+
+ # Even if a success task comes along, the result of the future
+ # should be the original exception
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+ def test_done_callbacks_success(self):
+ callback_results = []
+ SuccessTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'first'),
+ partial(callback_results.append, 'second'),
+ ],
+ )()
+ # For successful tasks, the done callbacks should get called.
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ def test_done_callbacks_failure(self):
+ callback_results = []
+ FailureTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'first'),
+ partial(callback_results.append, 'second'),
+ ],
+ )()
+ # Even for failed tasks, the done callbacks should get called.
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ # Callbacks should continue to be called even after a related failure
+ SuccessTask(
+ self.transfer_coordinator,
+ done_callbacks=[
+ partial(callback_results.append, 'third'),
+ partial(callback_results.append, 'fourth'),
+ ],
+ )()
+ self.assertEqual(
+ callback_results, ['first', 'second', 'third', 'fourth']
+ )
+
+ def test_failure_cleanups_on_failure(self):
+ callback_results = []
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'first'
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'second'
+ )
+ FailureTask(self.transfer_coordinator)()
+ # The failure callbacks should not have been called yet because the
+ # failed task is not the final task
+ self.assertEqual(callback_results, [])
+
+ # Now the failure callbacks should get called.
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ self.assertEqual(callback_results, ['first', 'second'])
+
+ def test_no_failure_cleanups_on_success(self):
+ callback_results = []
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'first'
+ )
+ self.transfer_coordinator.add_failure_cleanup(
+ callback_results.append, 'second'
+ )
+ SuccessTask(self.transfer_coordinator, is_final=True)()
+ # The failure cleanups should not have been called because no task
+ # failed for the transfer context.
+ self.assertEqual(callback_results, [])
+
+ def test_passing_main_kwargs(self):
+ main_kwargs = {'foo': 'bar', 'baz': 'biz'}
+ ReturnKwargsTask(
+ self.transfer_coordinator, main_kwargs=main_kwargs, is_final=True
+ )()
+ # The kwargs should have been passed to the main()
+ self.assertEqual(self.transfer_coordinator.result(), main_kwargs)
+
+ def test_passing_pending_kwargs_single_futures(self):
+ pending_kwargs = {}
+ ref_main_kwargs = {'foo': 'bar', 'baz': 'biz'}
+
+ # Pass some tasks to an executor
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['foo'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo']},
+ )
+ )
+ pending_kwargs['baz'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['baz']},
+ )
+ )
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have the pending keyword arg values flushed
+ # out.
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_passing_pending_kwargs_list_of_futures(self):
+ pending_kwargs = {}
+ ref_main_kwargs = {'foo': ['first', 'second']}
+
+ # Pass some tasks to an executor
+ with futures.ThreadPoolExecutor(1) as executor:
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo'][0]},
+ )
+ )
+ second_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': ref_main_kwargs['foo'][1]},
+ )
+ )
+ # Make the pending keyword arg value a list
+ pending_kwargs['foo'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have the pending keyword arg values flushed
+ # out in the expected order.
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_passing_pending_and_non_pending_kwargs(self):
+ main_kwargs = {'nonpending_value': 'foo'}
+ pending_kwargs = {}
+ ref_main_kwargs = {
+ 'nonpending_value': 'foo',
+ 'pending_value': 'bar',
+ 'pending_list': ['first', 'second'],
+ }
+
+ # Create the pending tasks
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['pending_value'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_value']
+ },
+ )
+ )
+
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_list'][0]
+ },
+ )
+ )
+ second_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={
+ 'return_value': ref_main_kwargs['pending_list'][1]
+ },
+ )
+ )
+ # Make the pending keyword arg value a list
+ pending_kwargs['pending_list'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ # and just regular nonpending kwargs.
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ main_kwargs=main_kwargs,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The result should have all of the kwargs (both pending and
+ # nonpending)
+ self.assertEqual(self.transfer_coordinator.result(), ref_main_kwargs)
+
+ def test_single_failed_pending_future(self):
+ pending_kwargs = {}
+
+ # Pass some tasks to an executor. Make one successful and the other
+ # a failure.
+ with futures.ThreadPoolExecutor(1) as executor:
+ pending_kwargs['foo'] = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': 'bar'},
+ )
+ )
+ pending_kwargs['baz'] = executor.submit(
+ FailureTask(self.transfer_coordinator)
+ )
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The end result should raise the exception from the initial
+ # pending future value
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+ def test_single_failed_pending_future_in_list(self):
+ pending_kwargs = {}
+
+ # Pass some tasks to an executor. Make one successful and the other
+ # a failure.
+ with futures.ThreadPoolExecutor(1) as executor:
+ first_future = executor.submit(
+ SuccessTask(
+ self.transfer_coordinator,
+ main_kwargs={'return_value': 'bar'},
+ )
+ )
+ second_future = executor.submit(
+ FailureTask(self.transfer_coordinator)
+ )
+
+ pending_kwargs['pending_list'] = [first_future, second_future]
+
+ # Create a task that depends on the tasks passed to the executor
+ ReturnKwargsTask(
+ self.transfer_coordinator,
+ pending_main_kwargs=pending_kwargs,
+ is_final=True,
+ )()
+ # The end result should raise the exception from the initial
+ # pending future value in the list
+ with self.assertRaises(TaskFailureException):
+ self.transfer_coordinator.result()
+
+
+class BaseMultipartTaskTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'foo'
+
+
+class TestCreateMultipartUploadTask(BaseMultipartTaskTest):
+ def test_main(self):
+ upload_id = 'foo'
+ extra_args = {'Metadata': {'foo': 'bar'}}
+ response = {'UploadId': upload_id}
+ task = self.get_task(
+ CreateMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='create_multipart_upload',
+ service_response=response,
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Metadata': {'foo': 'bar'},
+ },
+ )
+ result_id = task()
+ self.stubber.assert_no_pending_responses()
+ # Ensure the upload id returned is correct
+ self.assertEqual(upload_id, result_id)
+
+ # Make sure that the abort was added as a failure cleanup
+ self.assertEqual(len(self.transfer_coordinator.failure_cleanups), 1)
+
+ # Make sure if it is called, it will abort correctly
+ self.stubber.add_response(
+ method='abort_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ },
+ )
+ self.transfer_coordinator.failure_cleanups[0]()
+ self.stubber.assert_no_pending_responses()
+
+
+class TestCompleteMultipartUploadTask(BaseMultipartTaskTest):
+ def test_main(self):
+ upload_id = 'my-id'
+ parts = [{'ETag': 'etag', 'PartNumber': 0}]
+ task = self.get_task(
+ CompleteMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'parts': parts,
+ 'extra_args': {},
+ },
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'MultipartUpload': {'Parts': parts},
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
+
+ def test_includes_extra_args(self):
+ upload_id = 'my-id'
+ parts = [{'ETag': 'etag', 'PartNumber': 0}]
+ task = self.get_task(
+ CompleteMultipartUploadTask,
+ main_kwargs={
+ 'client': self.client,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'parts': parts,
+ 'extra_args': {'RequestPayer': 'requester'},
+ },
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload',
+ service_response={},
+ expected_params={
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'MultipartUpload': {'Parts': parts},
+ 'RequestPayer': 'requester',
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_upload.py b/contrib/python/s3transfer/py3/tests/unit/test_upload.py
index 622255417c..1ac38b3616 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_upload.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_upload.py
@@ -1,694 +1,694 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License'). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the 'license' file accompanying this file. This file is
-# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-
-import math
-import os
-import shutil
-import tempfile
-from io import BytesIO
-
-from botocore.stub import ANY
-
-from s3transfer.futures import IN_MEMORY_UPLOAD_TAG
-from s3transfer.manager import TransferConfig
-from s3transfer.upload import (
- AggregatedProgressCallback,
- InterruptReader,
- PutObjectTask,
- UploadFilenameInputManager,
- UploadNonSeekableInputManager,
- UploadPartTask,
- UploadSeekableInputManager,
- UploadSubmissionTask,
-)
-from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils
-from __tests__ import (
- BaseSubmissionTaskTest,
- BaseTaskTest,
- FileSizeProvider,
- NonSeekableReader,
- RecordingExecutor,
- RecordingSubscriber,
- unittest,
-)
-
-
-class InterruptionError(Exception):
- pass
-
-
-class OSUtilsExceptionOnFileSize(OSUtils):
- def get_file_size(self, filename):
- raise AssertionError(
- "The file %s should not have been stated" % filename
- )
-
-
-class BaseUploadTest(BaseTaskTest):
- def setUp(self):
- super().setUp()
- self.bucket = 'mybucket'
- self.key = 'foo'
- self.osutil = OSUtils()
-
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'myfile')
- self.content = b'my content'
- self.subscribers = []
-
- with open(self.filename, 'wb') as f:
- f.write(self.content)
-
- # A list to keep track of all of the bodies sent over the wire
- # and their order.
- self.sent_bodies = []
- self.client.meta.events.register(
- 'before-parameter-build.s3.*', self.collect_body
- )
-
- def tearDown(self):
- super().tearDown()
- shutil.rmtree(self.tempdir)
-
- def collect_body(self, params, **kwargs):
- if 'Body' in params:
- self.sent_bodies.append(params['Body'].read())
-
-
-class TestAggregatedProgressCallback(unittest.TestCase):
- def setUp(self):
- self.aggregated_amounts = []
- self.threshold = 3
- self.aggregated_progress_callback = AggregatedProgressCallback(
- [self.callback], self.threshold
- )
-
- def callback(self, bytes_transferred):
- self.aggregated_amounts.append(bytes_transferred)
-
- def test_under_threshold(self):
- one_under_threshold_amount = self.threshold - 1
- self.aggregated_progress_callback(one_under_threshold_amount)
- self.assertEqual(self.aggregated_amounts, [])
- self.aggregated_progress_callback(1)
- self.assertEqual(self.aggregated_amounts, [self.threshold])
-
- def test_at_threshold(self):
- self.aggregated_progress_callback(self.threshold)
- self.assertEqual(self.aggregated_amounts, [self.threshold])
-
- def test_over_threshold(self):
- over_threshold_amount = self.threshold + 1
- self.aggregated_progress_callback(over_threshold_amount)
- self.assertEqual(self.aggregated_amounts, [over_threshold_amount])
-
- def test_flush(self):
- under_threshold_amount = self.threshold - 1
- self.aggregated_progress_callback(under_threshold_amount)
- self.assertEqual(self.aggregated_amounts, [])
- self.aggregated_progress_callback.flush()
- self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
-
- def test_flush_with_nothing_to_flush(self):
- under_threshold_amount = self.threshold - 1
- self.aggregated_progress_callback(under_threshold_amount)
- self.assertEqual(self.aggregated_amounts, [])
- self.aggregated_progress_callback.flush()
- self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
- # Flushing again should do nothing as it was just flushed
- self.aggregated_progress_callback.flush()
- self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
-
-
-class TestInterruptReader(BaseUploadTest):
- def test_read_raises_exception(self):
- with open(self.filename, 'rb') as f:
- reader = InterruptReader(f, self.transfer_coordinator)
- # Read some bytes to show it can be read.
- self.assertEqual(reader.read(1), self.content[0:1])
- # Then set an exception in the transfer coordinator
- self.transfer_coordinator.set_exception(InterruptionError())
- # The next read should have the exception propagate
- with self.assertRaises(InterruptionError):
- reader.read()
-
- def test_seek(self):
- with open(self.filename, 'rb') as f:
- reader = InterruptReader(f, self.transfer_coordinator)
- # Ensure it can seek correctly
- reader.seek(1)
- self.assertEqual(reader.read(1), self.content[1:2])
-
- def test_tell(self):
- with open(self.filename, 'rb') as f:
- reader = InterruptReader(f, self.transfer_coordinator)
- # Ensure it can tell correctly
- reader.seek(1)
- self.assertEqual(reader.tell(), 1)
-
-
-class BaseUploadInputManagerTest(BaseUploadTest):
- def setUp(self):
- super().setUp()
- self.osutil = OSUtils()
- self.config = TransferConfig()
- self.recording_subscriber = RecordingSubscriber()
- self.subscribers.append(self.recording_subscriber)
-
- def _get_expected_body_for_part(self, part_number):
- # A helper method for retrieving the expected body for a specific
- # part number of the data
- total_size = len(self.content)
- chunk_size = self.config.multipart_chunksize
- start_index = (part_number - 1) * chunk_size
- end_index = part_number * chunk_size
- if end_index >= total_size:
- return self.content[start_index:]
- return self.content[start_index:end_index]
-
-
-class TestUploadFilenameInputManager(BaseUploadInputManagerTest):
- def setUp(self):
- super().setUp()
- self.upload_input_manager = UploadFilenameInputManager(
- self.osutil, self.transfer_coordinator
- )
- self.call_args = CallArgs(
- fileobj=self.filename, subscribers=self.subscribers
- )
- self.future = self.get_transfer_future(self.call_args)
-
- def test_is_compatible(self):
- self.assertTrue(
- self.upload_input_manager.is_compatible(
- self.future.meta.call_args.fileobj
- )
- )
-
- def test_stores_bodies_in_memory_put_object(self):
- self.assertFalse(
- self.upload_input_manager.stores_body_in_memory('put_object')
- )
-
- def test_stores_bodies_in_memory_upload_part(self):
- self.assertFalse(
- self.upload_input_manager.stores_body_in_memory('upload_part')
- )
-
- def test_provide_transfer_size(self):
- self.upload_input_manager.provide_transfer_size(self.future)
- # The provided file size should be equal to the size of the contents of
- # the file.
- self.assertEqual(self.future.meta.size, len(self.content))
-
- def test_requires_multipart_upload(self):
- self.future.meta.provide_transfer_size(len(self.content))
- # With the default multipart threshold, the length of the content
- # should be smaller than the threshold thus not requiring a multipart
- # transfer.
- self.assertFalse(
- self.upload_input_manager.requires_multipart_upload(
- self.future, self.config
- )
- )
- # Decreasing the threshold to the length of the content of
- # the file should trigger the need for a multipart upload.
- self.config.multipart_threshold = len(self.content)
- self.assertTrue(
- self.upload_input_manager.requires_multipart_upload(
- self.future, self.config
- )
- )
-
- def test_get_put_object_body(self):
- self.future.meta.provide_transfer_size(len(self.content))
- read_file_chunk = self.upload_input_manager.get_put_object_body(
- self.future
- )
- read_file_chunk.enable_callback()
- # The file-like object provided back should be the same as the content
- # of the file.
- with read_file_chunk:
- self.assertEqual(read_file_chunk.read(), self.content)
- # The file-like object should also have been wrapped with the
- # on_queued callbacks to track the amount of bytes being transferred.
- self.assertEqual(
- self.recording_subscriber.calculate_bytes_seen(), len(self.content)
- )
-
- def test_get_put_object_body_is_interruptable(self):
- self.future.meta.provide_transfer_size(len(self.content))
- read_file_chunk = self.upload_input_manager.get_put_object_body(
- self.future
- )
-
- # Set an exception in the transfer coordinator
- self.transfer_coordinator.set_exception(InterruptionError)
- # Ensure the returned read file chunk can be interrupted with that
- # error.
- with self.assertRaises(InterruptionError):
- read_file_chunk.read()
-
- def test_yield_upload_part_bodies(self):
- # Adjust the chunk size to something more granular for testing.
- self.config.multipart_chunksize = 4
- self.future.meta.provide_transfer_size(len(self.content))
-
- # Get an iterator that will yield all of the bodies and their
- # respective part number.
- part_iterator = self.upload_input_manager.yield_upload_part_bodies(
- self.future, self.config.multipart_chunksize
- )
- expected_part_number = 1
- for part_number, read_file_chunk in part_iterator:
- # Ensure that the part number is as expected
- self.assertEqual(part_number, expected_part_number)
- read_file_chunk.enable_callback()
- # Ensure that the body is correct for that part.
- with read_file_chunk:
- self.assertEqual(
- read_file_chunk.read(),
- self._get_expected_body_for_part(part_number),
- )
- expected_part_number += 1
-
- # All of the file-like objects should also have been wrapped with the
- # on_queued callbacks to track the amount of bytes being transferred.
- self.assertEqual(
- self.recording_subscriber.calculate_bytes_seen(), len(self.content)
- )
-
- def test_yield_upload_part_bodies_are_interruptable(self):
- # Adjust the chunk size to something more granular for testing.
- self.config.multipart_chunksize = 4
- self.future.meta.provide_transfer_size(len(self.content))
-
- # Get an iterator that will yield all of the bodies and their
- # respective part number.
- part_iterator = self.upload_input_manager.yield_upload_part_bodies(
- self.future, self.config.multipart_chunksize
- )
-
- # Set an exception in the transfer coordinator
- self.transfer_coordinator.set_exception(InterruptionError)
- for _, read_file_chunk in part_iterator:
- # Ensure that each read file chunk yielded can be interrupted
- # with that error.
- with self.assertRaises(InterruptionError):
- read_file_chunk.read()
-
-
-class TestUploadSeekableInputManager(TestUploadFilenameInputManager):
- def setUp(self):
- super().setUp()
- self.upload_input_manager = UploadSeekableInputManager(
- self.osutil, self.transfer_coordinator
- )
- self.fileobj = open(self.filename, 'rb')
- self.call_args = CallArgs(
- fileobj=self.fileobj, subscribers=self.subscribers
- )
- self.future = self.get_transfer_future(self.call_args)
-
- def tearDown(self):
- self.fileobj.close()
- super().tearDown()
-
- def test_is_compatible_bytes_io(self):
- self.assertTrue(self.upload_input_manager.is_compatible(BytesIO()))
-
- def test_not_compatible_for_non_filelike_obj(self):
- self.assertFalse(self.upload_input_manager.is_compatible(object()))
-
- def test_stores_bodies_in_memory_upload_part(self):
- self.assertTrue(
- self.upload_input_manager.stores_body_in_memory('upload_part')
- )
-
- def test_get_put_object_body(self):
- start_pos = 3
- self.fileobj.seek(start_pos)
- adjusted_size = len(self.content) - start_pos
- self.future.meta.provide_transfer_size(adjusted_size)
- read_file_chunk = self.upload_input_manager.get_put_object_body(
- self.future
- )
-
- read_file_chunk.enable_callback()
- # The fact that the file was seeked to a start position should be taken
- # into account in the length and content of the read file chunk.
- with read_file_chunk:
- self.assertEqual(len(read_file_chunk), adjusted_size)
- self.assertEqual(read_file_chunk.read(), self.content[start_pos:])
- self.assertEqual(
- self.recording_subscriber.calculate_bytes_seen(), adjusted_size
- )
-
-
-class TestUploadNonSeekableInputManager(TestUploadFilenameInputManager):
- def setUp(self):
- super().setUp()
- self.upload_input_manager = UploadNonSeekableInputManager(
- self.osutil, self.transfer_coordinator
- )
- self.fileobj = NonSeekableReader(self.content)
- self.call_args = CallArgs(
- fileobj=self.fileobj, subscribers=self.subscribers
- )
- self.future = self.get_transfer_future(self.call_args)
-
- def assert_multipart_parts(self):
- """
- Asserts that the input manager will generate a multipart upload
- and that each part is in order and the correct size.
- """
- # Assert that a multipart upload is required.
- self.assertTrue(
- self.upload_input_manager.requires_multipart_upload(
- self.future, self.config
- )
- )
-
- # Get a list of all the parts that would be sent.
- parts = list(
- self.upload_input_manager.yield_upload_part_bodies(
- self.future, self.config.multipart_chunksize
- )
- )
-
- # Assert that the actual number of parts is what we would expect
- # based on the configuration.
- size = self.config.multipart_chunksize
- num_parts = math.ceil(len(self.content) / size)
- self.assertEqual(len(parts), num_parts)
-
- # Run for every part but the last part.
- for i, part in enumerate(parts[:-1]):
- # Assert the part number is correct.
- self.assertEqual(part[0], i + 1)
- # Assert the part contains the right amount of data.
- data = part[1].read()
- self.assertEqual(len(data), size)
-
- # Assert that the last part is the correct size.
- expected_final_size = len(self.content) - ((num_parts - 1) * size)
- final_part = parts[-1]
- self.assertEqual(len(final_part[1].read()), expected_final_size)
-
- # Assert that the last part has the correct part number.
- self.assertEqual(final_part[0], len(parts))
-
- def test_provide_transfer_size(self):
- self.upload_input_manager.provide_transfer_size(self.future)
- # There is no way to get the size without reading the entire body.
- self.assertEqual(self.future.meta.size, None)
-
- def test_stores_bodies_in_memory_upload_part(self):
- self.assertTrue(
- self.upload_input_manager.stores_body_in_memory('upload_part')
- )
-
- def test_stores_bodies_in_memory_put_object(self):
- self.assertTrue(
- self.upload_input_manager.stores_body_in_memory('put_object')
- )
-
- def test_initial_data_parts_threshold_lesser(self):
- # threshold < size
- self.config.multipart_chunksize = 4
- self.config.multipart_threshold = 2
- self.assert_multipart_parts()
-
- def test_initial_data_parts_threshold_equal(self):
- # threshold == size
- self.config.multipart_chunksize = 4
- self.config.multipart_threshold = 4
- self.assert_multipart_parts()
-
- def test_initial_data_parts_threshold_greater(self):
- # threshold > size
- self.config.multipart_chunksize = 4
- self.config.multipart_threshold = 8
- self.assert_multipart_parts()
-
-
-class TestUploadSubmissionTask(BaseSubmissionTaskTest):
- def setUp(self):
- super().setUp()
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'myfile')
- self.content = b'0' * (MIN_UPLOAD_CHUNKSIZE * 3)
- self.config.multipart_chunksize = MIN_UPLOAD_CHUNKSIZE
- self.config.multipart_threshold = MIN_UPLOAD_CHUNKSIZE * 5
-
- with open(self.filename, 'wb') as f:
- f.write(self.content)
-
- self.bucket = 'mybucket'
- self.key = 'mykey'
- self.extra_args = {}
- self.subscribers = []
-
- # A list to keep track of all of the bodies sent over the wire
- # and their order.
- self.sent_bodies = []
- self.client.meta.events.register(
- 'before-parameter-build.s3.*', self.collect_body
- )
-
- self.call_args = self.get_call_args()
- self.transfer_future = self.get_transfer_future(self.call_args)
- self.submission_main_kwargs = {
- 'client': self.client,
- 'config': self.config,
- 'osutil': self.osutil,
- 'request_executor': self.executor,
- 'transfer_future': self.transfer_future,
- }
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
-
- def tearDown(self):
- super().tearDown()
- shutil.rmtree(self.tempdir)
-
- def collect_body(self, params, **kwargs):
- if 'Body' in params:
- self.sent_bodies.append(params['Body'].read())
-
- def get_call_args(self, **kwargs):
- default_call_args = {
- 'fileobj': self.filename,
- 'bucket': self.bucket,
- 'key': self.key,
- 'extra_args': self.extra_args,
- 'subscribers': self.subscribers,
- }
- default_call_args.update(kwargs)
- return CallArgs(**default_call_args)
-
- def add_multipart_upload_stubbed_responses(self):
- self.stubber.add_response(
- method='create_multipart_upload',
- service_response={'UploadId': 'my-id'},
- )
- self.stubber.add_response(
- method='upload_part', service_response={'ETag': 'etag-1'}
- )
- self.stubber.add_response(
- method='upload_part', service_response={'ETag': 'etag-2'}
- )
- self.stubber.add_response(
- method='upload_part', service_response={'ETag': 'etag-3'}
- )
- self.stubber.add_response(
- method='complete_multipart_upload', service_response={}
- )
-
- def wrap_executor_in_recorder(self):
- self.executor = RecordingExecutor(self.executor)
- self.submission_main_kwargs['request_executor'] = self.executor
-
- def use_fileobj_in_call_args(self, fileobj):
- self.call_args = self.get_call_args(fileobj=fileobj)
- self.transfer_future = self.get_transfer_future(self.call_args)
- self.submission_main_kwargs['transfer_future'] = self.transfer_future
-
- def assert_tag_value_for_put_object(self, tag_value):
- self.assertEqual(self.executor.submissions[0]['tag'], tag_value)
-
- def assert_tag_value_for_upload_parts(self, tag_value):
- for submission in self.executor.submissions[1:-1]:
- self.assertEqual(submission['tag'], tag_value)
-
- def test_provide_file_size_on_put(self):
- self.call_args.subscribers.append(FileSizeProvider(len(self.content)))
- self.stubber.add_response(
- method='put_object',
- service_response={},
- expected_params={
- 'Body': ANY,
- 'Bucket': self.bucket,
- 'Key': self.key,
- },
- )
-
- # With this osutil, the submission will fail if it tries to stat the
- # file, which only happens when a transfer size is not provided.
- self.submission_main_kwargs['osutil'] = OSUtilsExceptionOnFileSize()
-
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
- self.assertEqual(self.sent_bodies, [self.content])
-
- def test_submits_no_tag_for_put_object_filename(self):
- self.wrap_executor_in_recorder()
- self.stubber.add_response('put_object', {})
-
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- # Make sure no tag to limit that task specifically was associated
- # with that task submission.
- self.assert_tag_value_for_put_object(None)
-
- def test_submits_no_tag_for_multipart_filename(self):
- self.wrap_executor_in_recorder()
-
- # Set up for a multipart upload.
- self.add_multipart_upload_stubbed_responses()
- self.config.multipart_threshold = 1
-
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- # Make sure no tags to limit any of the upload part tasks were
- # associated when they were submitted to the executor.
- self.assert_tag_value_for_upload_parts(None)
-
- def test_submits_no_tag_for_put_object_fileobj(self):
- self.wrap_executor_in_recorder()
- self.stubber.add_response('put_object', {})
-
- with open(self.filename, 'rb') as f:
- self.use_fileobj_in_call_args(f)
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- # Make sure no tag to limit that task specifically was associated
- # with that task submission.
- self.assert_tag_value_for_put_object(None)
-
- def test_submits_tag_for_multipart_fileobj(self):
- self.wrap_executor_in_recorder()
-
- # Set up for a multipart upload.
- self.add_multipart_upload_stubbed_responses()
- self.config.multipart_threshold = 1
-
- with open(self.filename, 'rb') as f:
- self.use_fileobj_in_call_args(f)
- self.submission_task = self.get_task(
- UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
- )
- self.submission_task()
- self.transfer_future.result()
- self.stubber.assert_no_pending_responses()
-
- # Make sure tags to limit all of the upload part tasks were
- # associated when they were submitted to the executor, as these tasks
- # will have chunks of data stored with them in memory.
- self.assert_tag_value_for_upload_parts(IN_MEMORY_UPLOAD_TAG)
-
-
-class TestPutObjectTask(BaseUploadTest):
- def test_main(self):
- extra_args = {'Metadata': {'foo': 'bar'}}
- with open(self.filename, 'rb') as fileobj:
- task = self.get_task(
- PutObjectTask,
- main_kwargs={
- 'client': self.client,
- 'fileobj': fileobj,
- 'bucket': self.bucket,
- 'key': self.key,
- 'extra_args': extra_args,
- },
- )
- self.stubber.add_response(
- method='put_object',
- service_response={},
- expected_params={
- 'Body': ANY,
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'Metadata': {'foo': 'bar'},
- },
- )
- task()
- self.stubber.assert_no_pending_responses()
- self.assertEqual(self.sent_bodies, [self.content])
-
-
-class TestUploadPartTask(BaseUploadTest):
- def test_main(self):
- extra_args = {'RequestPayer': 'requester'}
- upload_id = 'my-id'
- part_number = 1
- etag = 'foo'
- with open(self.filename, 'rb') as fileobj:
- task = self.get_task(
- UploadPartTask,
- main_kwargs={
- 'client': self.client,
- 'fileobj': fileobj,
- 'bucket': self.bucket,
- 'key': self.key,
- 'upload_id': upload_id,
- 'part_number': part_number,
- 'extra_args': extra_args,
- },
- )
- self.stubber.add_response(
- method='upload_part',
- service_response={'ETag': etag},
- expected_params={
- 'Body': ANY,
- 'Bucket': self.bucket,
- 'Key': self.key,
- 'UploadId': upload_id,
- 'PartNumber': part_number,
- 'RequestPayer': 'requester',
- },
- )
- rval = task()
- self.stubber.assert_no_pending_responses()
- self.assertEqual(rval, {'ETag': etag, 'PartNumber': part_number})
- self.assertEqual(self.sent_bodies, [self.content])
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import math
+import os
+import shutil
+import tempfile
+from io import BytesIO
+
+from botocore.stub import ANY
+
+from s3transfer.futures import IN_MEMORY_UPLOAD_TAG
+from s3transfer.manager import TransferConfig
+from s3transfer.upload import (
+ AggregatedProgressCallback,
+ InterruptReader,
+ PutObjectTask,
+ UploadFilenameInputManager,
+ UploadNonSeekableInputManager,
+ UploadPartTask,
+ UploadSeekableInputManager,
+ UploadSubmissionTask,
+)
+from s3transfer.utils import MIN_UPLOAD_CHUNKSIZE, CallArgs, OSUtils
+from __tests__ import (
+ BaseSubmissionTaskTest,
+ BaseTaskTest,
+ FileSizeProvider,
+ NonSeekableReader,
+ RecordingExecutor,
+ RecordingSubscriber,
+ unittest,
+)
+
+
+class InterruptionError(Exception):
+ pass
+
+
+class OSUtilsExceptionOnFileSize(OSUtils):
+ def get_file_size(self, filename):
+ raise AssertionError(
+ "The file %s should not have been stated" % filename
+ )
+
+
+class BaseUploadTest(BaseTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.bucket = 'mybucket'
+ self.key = 'foo'
+ self.osutil = OSUtils()
+
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ self.content = b'my content'
+ self.subscribers = []
+
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+ self.sent_bodies = []
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def collect_body(self, params, **kwargs):
+ if 'Body' in params:
+ self.sent_bodies.append(params['Body'].read())
+
+
+class TestAggregatedProgressCallback(unittest.TestCase):
+ def setUp(self):
+ self.aggregated_amounts = []
+ self.threshold = 3
+ self.aggregated_progress_callback = AggregatedProgressCallback(
+ [self.callback], self.threshold
+ )
+
+ def callback(self, bytes_transferred):
+ self.aggregated_amounts.append(bytes_transferred)
+
+ def test_under_threshold(self):
+ one_under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(one_under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback(1)
+ self.assertEqual(self.aggregated_amounts, [self.threshold])
+
+ def test_at_threshold(self):
+ self.aggregated_progress_callback(self.threshold)
+ self.assertEqual(self.aggregated_amounts, [self.threshold])
+
+ def test_over_threshold(self):
+ over_threshold_amount = self.threshold + 1
+ self.aggregated_progress_callback(over_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [over_threshold_amount])
+
+ def test_flush(self):
+ under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+
+ def test_flush_with_nothing_to_flush(self):
+ under_threshold_amount = self.threshold - 1
+ self.aggregated_progress_callback(under_threshold_amount)
+ self.assertEqual(self.aggregated_amounts, [])
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+ # Flushing again should do nothing as it was just flushed
+ self.aggregated_progress_callback.flush()
+ self.assertEqual(self.aggregated_amounts, [under_threshold_amount])
+
+
+class TestInterruptReader(BaseUploadTest):
+ def test_read_raises_exception(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Read some bytes to show it can be read.
+ self.assertEqual(reader.read(1), self.content[0:1])
+ # Then set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError())
+ # The next read should have the exception propagate
+ with self.assertRaises(InterruptionError):
+ reader.read()
+
+ def test_seek(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Ensure it can seek correctly
+ reader.seek(1)
+ self.assertEqual(reader.read(1), self.content[1:2])
+
+ def test_tell(self):
+ with open(self.filename, 'rb') as f:
+ reader = InterruptReader(f, self.transfer_coordinator)
+ # Ensure it can tell correctly
+ reader.seek(1)
+ self.assertEqual(reader.tell(), 1)
+
+
+class BaseUploadInputManagerTest(BaseUploadTest):
+ def setUp(self):
+ super().setUp()
+ self.osutil = OSUtils()
+ self.config = TransferConfig()
+ self.recording_subscriber = RecordingSubscriber()
+ self.subscribers.append(self.recording_subscriber)
+
+ def _get_expected_body_for_part(self, part_number):
+ # A helper method for retrieving the expected body for a specific
+ # part number of the data
+ total_size = len(self.content)
+ chunk_size = self.config.multipart_chunksize
+ start_index = (part_number - 1) * chunk_size
+ end_index = part_number * chunk_size
+ if end_index >= total_size:
+ return self.content[start_index:]
+ return self.content[start_index:end_index]
+
+
+class TestUploadFilenameInputManager(BaseUploadInputManagerTest):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadFilenameInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.call_args = CallArgs(
+ fileobj=self.filename, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def test_is_compatible(self):
+ self.assertTrue(
+ self.upload_input_manager.is_compatible(
+ self.future.meta.call_args.fileobj
+ )
+ )
+
+ def test_stores_bodies_in_memory_put_object(self):
+ self.assertFalse(
+ self.upload_input_manager.stores_body_in_memory('put_object')
+ )
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertFalse(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_provide_transfer_size(self):
+ self.upload_input_manager.provide_transfer_size(self.future)
+ # The provided file size should be equal to the size of the contents of
+ # the file.
+ self.assertEqual(self.future.meta.size, len(self.content))
+
+ def test_requires_multipart_upload(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ # With the default multipart threshold, the length of the content
+ # should be smaller than the threshold thus not requiring a multipart
+ # transfer.
+ self.assertFalse(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+ # Decreasing the threshold to the length of the content of
+ # the file should trigger the need for a multipart upload.
+ self.config.multipart_threshold = len(self.content)
+ self.assertTrue(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+
+ def test_get_put_object_body(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+ read_file_chunk.enable_callback()
+ # The file-like object provided back should be the same as the content
+ # of the file.
+ with read_file_chunk:
+ self.assertEqual(read_file_chunk.read(), self.content)
+ # The file-like object should also have been wrapped with the
+ # on_queued callbacks to track the amount of bytes being transferred.
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ def test_get_put_object_body_is_interruptable(self):
+ self.future.meta.provide_transfer_size(len(self.content))
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+
+ # Set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError)
+ # Ensure the returned read file chunk can be interrupted with that
+ # error.
+ with self.assertRaises(InterruptionError):
+ read_file_chunk.read()
+
+ def test_yield_upload_part_bodies(self):
+ # Adjust the chunk size to something more granular for testing.
+ self.config.multipart_chunksize = 4
+ self.future.meta.provide_transfer_size(len(self.content))
+
+ # Get an iterator that will yield all of the bodies and their
+ # respective part number.
+ part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+ expected_part_number = 1
+ for part_number, read_file_chunk in part_iterator:
+ # Ensure that the part number is as expected
+ self.assertEqual(part_number, expected_part_number)
+ read_file_chunk.enable_callback()
+ # Ensure that the body is correct for that part.
+ with read_file_chunk:
+ self.assertEqual(
+ read_file_chunk.read(),
+ self._get_expected_body_for_part(part_number),
+ )
+ expected_part_number += 1
+
+ # All of the file-like objects should also have been wrapped with the
+ # on_queued callbacks to track the amount of bytes being transferred.
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), len(self.content)
+ )
+
+ def test_yield_upload_part_bodies_are_interruptable(self):
+ # Adjust the chunk size to something more granular for testing.
+ self.config.multipart_chunksize = 4
+ self.future.meta.provide_transfer_size(len(self.content))
+
+ # Get an iterator that will yield all of the bodies and their
+ # respective part number.
+ part_iterator = self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+
+ # Set an exception in the transfer coordinator
+ self.transfer_coordinator.set_exception(InterruptionError)
+ for _, read_file_chunk in part_iterator:
+ # Ensure that each read file chunk yielded can be interrupted
+ # with that error.
+ with self.assertRaises(InterruptionError):
+ read_file_chunk.read()
+
+
+class TestUploadSeekableInputManager(TestUploadFilenameInputManager):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadSeekableInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.fileobj = open(self.filename, 'rb')
+ self.call_args = CallArgs(
+ fileobj=self.fileobj, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def tearDown(self):
+ self.fileobj.close()
+ super().tearDown()
+
+ def test_is_compatible_bytes_io(self):
+ self.assertTrue(self.upload_input_manager.is_compatible(BytesIO()))
+
+ def test_not_compatible_for_non_filelike_obj(self):
+ self.assertFalse(self.upload_input_manager.is_compatible(object()))
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_get_put_object_body(self):
+ start_pos = 3
+ self.fileobj.seek(start_pos)
+ adjusted_size = len(self.content) - start_pos
+ self.future.meta.provide_transfer_size(adjusted_size)
+ read_file_chunk = self.upload_input_manager.get_put_object_body(
+ self.future
+ )
+
+ read_file_chunk.enable_callback()
+ # The fact that the file was seeked to a start position should be taken
+ # into account in the length and content of the read file chunk.
+ with read_file_chunk:
+ self.assertEqual(len(read_file_chunk), adjusted_size)
+ self.assertEqual(read_file_chunk.read(), self.content[start_pos:])
+ self.assertEqual(
+ self.recording_subscriber.calculate_bytes_seen(), adjusted_size
+ )
+
+
+class TestUploadNonSeekableInputManager(TestUploadFilenameInputManager):
+ def setUp(self):
+ super().setUp()
+ self.upload_input_manager = UploadNonSeekableInputManager(
+ self.osutil, self.transfer_coordinator
+ )
+ self.fileobj = NonSeekableReader(self.content)
+ self.call_args = CallArgs(
+ fileobj=self.fileobj, subscribers=self.subscribers
+ )
+ self.future = self.get_transfer_future(self.call_args)
+
+ def assert_multipart_parts(self):
+ """
+ Asserts that the input manager will generate a multipart upload
+ and that each part is in order and the correct size.
+ """
+ # Assert that a multipart upload is required.
+ self.assertTrue(
+ self.upload_input_manager.requires_multipart_upload(
+ self.future, self.config
+ )
+ )
+
+ # Get a list of all the parts that would be sent.
+ parts = list(
+ self.upload_input_manager.yield_upload_part_bodies(
+ self.future, self.config.multipart_chunksize
+ )
+ )
+
+ # Assert that the actual number of parts is what we would expect
+ # based on the configuration.
+ size = self.config.multipart_chunksize
+ num_parts = math.ceil(len(self.content) / size)
+ self.assertEqual(len(parts), num_parts)
+
+ # Run for every part but the last part.
+ for i, part in enumerate(parts[:-1]):
+ # Assert the part number is correct.
+ self.assertEqual(part[0], i + 1)
+ # Assert the part contains the right amount of data.
+ data = part[1].read()
+ self.assertEqual(len(data), size)
+
+ # Assert that the last part is the correct size.
+ expected_final_size = len(self.content) - ((num_parts - 1) * size)
+ final_part = parts[-1]
+ self.assertEqual(len(final_part[1].read()), expected_final_size)
+
+ # Assert that the last part has the correct part number.
+ self.assertEqual(final_part[0], len(parts))
+
+ def test_provide_transfer_size(self):
+ self.upload_input_manager.provide_transfer_size(self.future)
+ # There is no way to get the size without reading the entire body.
+ self.assertEqual(self.future.meta.size, None)
+
+ def test_stores_bodies_in_memory_upload_part(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('upload_part')
+ )
+
+ def test_stores_bodies_in_memory_put_object(self):
+ self.assertTrue(
+ self.upload_input_manager.stores_body_in_memory('put_object')
+ )
+
+ def test_initial_data_parts_threshold_lesser(self):
+ # threshold < size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 2
+ self.assert_multipart_parts()
+
+ def test_initial_data_parts_threshold_equal(self):
+ # threshold == size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 4
+ self.assert_multipart_parts()
+
+ def test_initial_data_parts_threshold_greater(self):
+ # threshold > size
+ self.config.multipart_chunksize = 4
+ self.config.multipart_threshold = 8
+ self.assert_multipart_parts()
+
+
+class TestUploadSubmissionTask(BaseSubmissionTaskTest):
+ def setUp(self):
+ super().setUp()
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'myfile')
+ self.content = b'0' * (MIN_UPLOAD_CHUNKSIZE * 3)
+ self.config.multipart_chunksize = MIN_UPLOAD_CHUNKSIZE
+ self.config.multipart_threshold = MIN_UPLOAD_CHUNKSIZE * 5
+
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+
+ self.bucket = 'mybucket'
+ self.key = 'mykey'
+ self.extra_args = {}
+ self.subscribers = []
+
+ # A list to keep track of all of the bodies sent over the wire
+ # and their order.
+ self.sent_bodies = []
+ self.client.meta.events.register(
+ 'before-parameter-build.s3.*', self.collect_body
+ )
+
+ self.call_args = self.get_call_args()
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs = {
+ 'client': self.client,
+ 'config': self.config,
+ 'osutil': self.osutil,
+ 'request_executor': self.executor,
+ 'transfer_future': self.transfer_future,
+ }
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+
+ def tearDown(self):
+ super().tearDown()
+ shutil.rmtree(self.tempdir)
+
+ def collect_body(self, params, **kwargs):
+ if 'Body' in params:
+ self.sent_bodies.append(params['Body'].read())
+
+ def get_call_args(self, **kwargs):
+ default_call_args = {
+ 'fileobj': self.filename,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': self.extra_args,
+ 'subscribers': self.subscribers,
+ }
+ default_call_args.update(kwargs)
+ return CallArgs(**default_call_args)
+
+ def add_multipart_upload_stubbed_responses(self):
+ self.stubber.add_response(
+ method='create_multipart_upload',
+ service_response={'UploadId': 'my-id'},
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-1'}
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-2'}
+ )
+ self.stubber.add_response(
+ method='upload_part', service_response={'ETag': 'etag-3'}
+ )
+ self.stubber.add_response(
+ method='complete_multipart_upload', service_response={}
+ )
+
+ def wrap_executor_in_recorder(self):
+ self.executor = RecordingExecutor(self.executor)
+ self.submission_main_kwargs['request_executor'] = self.executor
+
+ def use_fileobj_in_call_args(self, fileobj):
+ self.call_args = self.get_call_args(fileobj=fileobj)
+ self.transfer_future = self.get_transfer_future(self.call_args)
+ self.submission_main_kwargs['transfer_future'] = self.transfer_future
+
+ def assert_tag_value_for_put_object(self, tag_value):
+ self.assertEqual(self.executor.submissions[0]['tag'], tag_value)
+
+ def assert_tag_value_for_upload_parts(self, tag_value):
+ for submission in self.executor.submissions[1:-1]:
+ self.assertEqual(submission['tag'], tag_value)
+
+ def test_provide_file_size_on_put(self):
+ self.call_args.subscribers.append(FileSizeProvider(len(self.content)))
+ self.stubber.add_response(
+ method='put_object',
+ service_response={},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ },
+ )
+
+ # With this osutil, the submission will fail if it tries to stat the
+ # file, which only happens when a transfer size is not provided.
+ self.submission_main_kwargs['osutil'] = OSUtilsExceptionOnFileSize()
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(self.sent_bodies, [self.content])
+
+ def test_submits_no_tag_for_put_object_filename(self):
+ self.wrap_executor_in_recorder()
+ self.stubber.add_response('put_object', {})
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+ # Make sure no tag to limit that task specifically was associated
+ # with that task submission.
+ self.assert_tag_value_for_put_object(None)
+
+ def test_submits_no_tag_for_multipart_filename(self):
+ self.wrap_executor_in_recorder()
+
+ # Set up for a multipart upload.
+ self.add_multipart_upload_stubbed_responses()
+ self.config.multipart_threshold = 1
+
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+        # Make sure no tags to limit any of the upload part tasks were
+        # associated when they were submitted to the executor.
+ self.assert_tag_value_for_upload_parts(None)
+
+ def test_submits_no_tag_for_put_object_fileobj(self):
+ self.wrap_executor_in_recorder()
+ self.stubber.add_response('put_object', {})
+
+ with open(self.filename, 'rb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+        # Make sure no tag to limit that task specifically was associated
+        # with that task submission.
+ self.assert_tag_value_for_put_object(None)
+
+ def test_submits_tag_for_multipart_fileobj(self):
+ self.wrap_executor_in_recorder()
+
+ # Set up for a multipart upload.
+ self.add_multipart_upload_stubbed_responses()
+ self.config.multipart_threshold = 1
+
+ with open(self.filename, 'rb') as f:
+ self.use_fileobj_in_call_args(f)
+ self.submission_task = self.get_task(
+ UploadSubmissionTask, main_kwargs=self.submission_main_kwargs
+ )
+ self.submission_task()
+ self.transfer_future.result()
+ self.stubber.assert_no_pending_responses()
+
+        # Make sure tags to limit all of the upload part tasks were
+        # associated when they were submitted to the executor, as these tasks
+        # will have chunks of data stored with them in memory.
+ self.assert_tag_value_for_upload_parts(IN_MEMORY_UPLOAD_TAG)
+
+
+class TestPutObjectTask(BaseUploadTest):
+ def test_main(self):
+ extra_args = {'Metadata': {'foo': 'bar'}}
+ with open(self.filename, 'rb') as fileobj:
+ task = self.get_task(
+ PutObjectTask,
+ main_kwargs={
+ 'client': self.client,
+ 'fileobj': fileobj,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='put_object',
+ service_response={},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'Metadata': {'foo': 'bar'},
+ },
+ )
+ task()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(self.sent_bodies, [self.content])
+
+
+class TestUploadPartTask(BaseUploadTest):
+ def test_main(self):
+ extra_args = {'RequestPayer': 'requester'}
+ upload_id = 'my-id'
+ part_number = 1
+ etag = 'foo'
+ with open(self.filename, 'rb') as fileobj:
+ task = self.get_task(
+ UploadPartTask,
+ main_kwargs={
+ 'client': self.client,
+ 'fileobj': fileobj,
+ 'bucket': self.bucket,
+ 'key': self.key,
+ 'upload_id': upload_id,
+ 'part_number': part_number,
+ 'extra_args': extra_args,
+ },
+ )
+ self.stubber.add_response(
+ method='upload_part',
+ service_response={'ETag': etag},
+ expected_params={
+ 'Body': ANY,
+ 'Bucket': self.bucket,
+ 'Key': self.key,
+ 'UploadId': upload_id,
+ 'PartNumber': part_number,
+ 'RequestPayer': 'requester',
+ },
+ )
+ rval = task()
+ self.stubber.assert_no_pending_responses()
+ self.assertEqual(rval, {'ETag': etag, 'PartNumber': part_number})
+ self.assertEqual(self.sent_bodies, [self.content])
diff --git a/contrib/python/s3transfer/py3/tests/unit/test_utils.py b/contrib/python/s3transfer/py3/tests/unit/test_utils.py
index b7a51c7845..a1ff904e7a 100644
--- a/contrib/python/s3transfer/py3/tests/unit/test_utils.py
+++ b/contrib/python/s3transfer/py3/tests/unit/test_utils.py
@@ -1,1189 +1,1189 @@
-# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-# http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-import io
-import os.path
-import random
-import re
-import shutil
-import tempfile
-import threading
-import time
-from io import BytesIO, StringIO
-
-from s3transfer.futures import TransferFuture, TransferMeta
-from s3transfer.utils import (
- MAX_PARTS,
- MAX_SINGLE_UPLOAD_SIZE,
- MIN_UPLOAD_CHUNKSIZE,
- CallArgs,
- ChunksizeAdjuster,
- CountCallbackInvoker,
- DeferredOpenFile,
- FunctionContainer,
- NoResourcesAvailable,
- OSUtils,
- ReadFileChunk,
- SlidingWindowSemaphore,
- StreamReaderProgress,
- TaskSemaphore,
- calculate_num_parts,
- calculate_range_parameter,
- get_callbacks,
- get_filtered_dict,
- invoke_progress_callbacks,
- random_file_extension,
-)
-from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest
-
-
-class TestGetCallbacks(unittest.TestCase):
- def setUp(self):
- self.subscriber = RecordingSubscriber()
- self.second_subscriber = RecordingSubscriber()
- self.call_args = CallArgs(
- subscribers=[self.subscriber, self.second_subscriber]
- )
- self.transfer_meta = TransferMeta(self.call_args)
- self.transfer_future = TransferFuture(self.transfer_meta)
-
- def test_get_callbacks(self):
- callbacks = get_callbacks(self.transfer_future, 'queued')
- # Make sure two callbacks were added as both subscribers had
- # an on_queued method.
- self.assertEqual(len(callbacks), 2)
-
- # Ensure that the callback was injected with the future by calling
- # one of them and checking that the future was used in the call.
- callbacks[0]()
- self.assertEqual(
- self.subscriber.on_queued_calls, [{'future': self.transfer_future}]
- )
-
- def test_get_callbacks_for_missing_type(self):
- callbacks = get_callbacks(self.transfer_future, 'fake_state')
- # There should be no callbacks as the subscribers will not have the
- # on_fake_state method
- self.assertEqual(len(callbacks), 0)
-
-
-class TestGetFilteredDict(unittest.TestCase):
- def test_get_filtered_dict(self):
- original = {'Include': 'IncludeValue', 'NotInlude': 'NotIncludeValue'}
- whitelist = ['Include']
- self.assertEqual(
- get_filtered_dict(original, whitelist), {'Include': 'IncludeValue'}
- )
-
-
-class TestCallArgs(unittest.TestCase):
- def test_call_args(self):
- call_args = CallArgs(foo='bar', biz='baz')
- self.assertEqual(call_args.foo, 'bar')
- self.assertEqual(call_args.biz, 'baz')
-
-
-class TestFunctionContainer(unittest.TestCase):
- def get_args_kwargs(self, *args, **kwargs):
- return args, kwargs
-
- def test_call(self):
- func_container = FunctionContainer(
- self.get_args_kwargs, 'foo', bar='baz'
- )
- self.assertEqual(func_container(), (('foo',), {'bar': 'baz'}))
-
- def test_repr(self):
- func_container = FunctionContainer(
- self.get_args_kwargs, 'foo', bar='baz'
- )
- self.assertEqual(
- str(func_container),
- 'Function: {} with args {} and kwargs {}'.format(
- self.get_args_kwargs, ('foo',), {'bar': 'baz'}
- ),
- )
-
-
-class TestCountCallbackInvoker(unittest.TestCase):
- def invoke_callback(self):
- self.ref_results.append('callback invoked')
-
- def assert_callback_invoked(self):
- self.assertEqual(self.ref_results, ['callback invoked'])
-
- def assert_callback_not_invoked(self):
- self.assertEqual(self.ref_results, [])
-
- def setUp(self):
- self.ref_results = []
- self.invoker = CountCallbackInvoker(self.invoke_callback)
-
- def test_increment(self):
- self.invoker.increment()
- self.assertEqual(self.invoker.current_count, 1)
-
- def test_decrement(self):
- self.invoker.increment()
- self.invoker.increment()
- self.invoker.decrement()
- self.assertEqual(self.invoker.current_count, 1)
-
- def test_count_cannot_go_below_zero(self):
- with self.assertRaises(RuntimeError):
- self.invoker.decrement()
-
- def test_callback_invoked_only_once_finalized(self):
- self.invoker.increment()
- self.invoker.decrement()
- self.assert_callback_not_invoked()
- self.invoker.finalize()
- # Callback should only be invoked once finalized
- self.assert_callback_invoked()
-
- def test_callback_invoked_after_finalizing_and_count_reaching_zero(self):
- self.invoker.increment()
- self.invoker.finalize()
- # Make sure that it does not get invoked immediately after
- # finalizing as the count is currently one
- self.assert_callback_not_invoked()
- self.invoker.decrement()
- self.assert_callback_invoked()
-
- def test_cannot_increment_after_finalization(self):
- self.invoker.finalize()
- with self.assertRaises(RuntimeError):
- self.invoker.increment()
-
-
-class TestRandomFileExtension(unittest.TestCase):
- def test_has_proper_length(self):
- self.assertEqual(len(random_file_extension(num_digits=4)), 4)
-
-
-class TestInvokeProgressCallbacks(unittest.TestCase):
- def test_invoke_progress_callbacks(self):
- recording_subscriber = RecordingSubscriber()
- invoke_progress_callbacks([recording_subscriber.on_progress], 2)
- self.assertEqual(recording_subscriber.calculate_bytes_seen(), 2)
-
- def test_invoke_progress_callbacks_with_no_progress(self):
- recording_subscriber = RecordingSubscriber()
- invoke_progress_callbacks([recording_subscriber.on_progress], 0)
- self.assertEqual(len(recording_subscriber.on_progress_calls), 0)
-
-
-class TestCalculateNumParts(unittest.TestCase):
- def test_calculate_num_parts_divisible(self):
- self.assertEqual(calculate_num_parts(size=4, part_size=2), 2)
-
- def test_calculate_num_parts_not_divisible(self):
- self.assertEqual(calculate_num_parts(size=3, part_size=2), 2)
-
-
-class TestCalculateRangeParameter(unittest.TestCase):
- def setUp(self):
- self.part_size = 5
- self.part_index = 1
- self.num_parts = 3
-
- def test_calculate_range_paramter(self):
- range_val = calculate_range_parameter(
- self.part_size, self.part_index, self.num_parts
- )
- self.assertEqual(range_val, 'bytes=5-9')
-
- def test_last_part_with_no_total_size(self):
- range_val = calculate_range_parameter(
- self.part_size, self.part_index, num_parts=2
- )
- self.assertEqual(range_val, 'bytes=5-')
-
- def test_last_part_with_total_size(self):
- range_val = calculate_range_parameter(
- self.part_size, self.part_index, num_parts=2, total_size=8
- )
- self.assertEqual(range_val, 'bytes=5-7')
-
-
-class BaseUtilsTest(unittest.TestCase):
- def setUp(self):
- self.tempdir = tempfile.mkdtemp()
- self.filename = os.path.join(self.tempdir, 'foo')
- self.content = b'abc'
- with open(self.filename, 'wb') as f:
- f.write(self.content)
- self.amounts_seen = []
- self.num_close_callback_calls = 0
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
-
- def callback(self, bytes_transferred):
- self.amounts_seen.append(bytes_transferred)
-
- def close_callback(self):
- self.num_close_callback_calls += 1
-
-
-class TestOSUtils(BaseUtilsTest):
- def test_get_file_size(self):
- self.assertEqual(
- OSUtils().get_file_size(self.filename), len(self.content)
- )
-
- def test_open_file_chunk_reader(self):
- reader = OSUtils().open_file_chunk_reader(
- self.filename, 0, 3, [self.callback]
- )
-
- # The returned reader should be a ReadFileChunk.
- self.assertIsInstance(reader, ReadFileChunk)
- # The content of the reader should be correct.
- self.assertEqual(reader.read(), self.content)
-        # Callbacks should be disabled despite being passed in.
- self.assertEqual(self.amounts_seen, [])
-
- def test_open_file_chunk_reader_from_fileobj(self):
- with open(self.filename, 'rb') as f:
- reader = OSUtils().open_file_chunk_reader_from_fileobj(
- f, len(self.content), len(self.content), [self.callback]
- )
-
- # The returned reader should be a ReadFileChunk.
- self.assertIsInstance(reader, ReadFileChunk)
- # The content of the reader should be correct.
- self.assertEqual(reader.read(), self.content)
- reader.close()
-        # Callbacks should be disabled despite being passed in.
- self.assertEqual(self.amounts_seen, [])
- self.assertEqual(self.num_close_callback_calls, 0)
-
- def test_open_file(self):
- fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
- self.assertTrue(hasattr(fileobj, 'write'))
-
- def test_remove_file_ignores_errors(self):
- non_existent_file = os.path.join(self.tempdir, 'no-exist')
- # This should not exist to start.
- self.assertFalse(os.path.exists(non_existent_file))
- try:
- OSUtils().remove_file(non_existent_file)
- except OSError as e:
- self.fail('OSError should have been caught: %s' % e)
-
- def test_remove_file_proxies_remove_file(self):
- OSUtils().remove_file(self.filename)
- self.assertFalse(os.path.exists(self.filename))
-
- def test_rename_file(self):
- new_filename = os.path.join(self.tempdir, 'newfoo')
- OSUtils().rename_file(self.filename, new_filename)
- self.assertFalse(os.path.exists(self.filename))
- self.assertTrue(os.path.exists(new_filename))
-
- def test_is_special_file_for_normal_file(self):
- self.assertFalse(OSUtils().is_special_file(self.filename))
-
- def test_is_special_file_for_non_existant_file(self):
- non_existant_filename = os.path.join(self.tempdir, 'no-exist')
- self.assertFalse(os.path.exists(non_existant_filename))
- self.assertFalse(OSUtils().is_special_file(non_existant_filename))
-
- def test_get_temp_filename(self):
- filename = 'myfile'
- self.assertIsNotNone(
- re.match(
- r'%s\.[0-9A-Fa-f]{8}$' % filename,
- OSUtils().get_temp_filename(filename),
- )
- )
-
- def test_get_temp_filename_len_255(self):
- filename = 'a' * 255
- temp_filename = OSUtils().get_temp_filename(filename)
- self.assertLessEqual(len(temp_filename), 255)
-
- def test_get_temp_filename_len_gt_255(self):
- filename = 'a' * 280
- temp_filename = OSUtils().get_temp_filename(filename)
- self.assertLessEqual(len(temp_filename), 255)
-
- def test_allocate(self):
- truncate_size = 1
- OSUtils().allocate(self.filename, truncate_size)
- with open(self.filename, 'rb') as f:
- self.assertEqual(len(f.read()), truncate_size)
-
- @mock.patch('s3transfer.utils.fallocate')
- def test_allocate_with_io_error(self, mock_fallocate):
- mock_fallocate.side_effect = IOError()
- with self.assertRaises(IOError):
- OSUtils().allocate(self.filename, 1)
- self.assertFalse(os.path.exists(self.filename))
-
- @mock.patch('s3transfer.utils.fallocate')
- def test_allocate_with_os_error(self, mock_fallocate):
- mock_fallocate.side_effect = OSError()
- with self.assertRaises(OSError):
- OSUtils().allocate(self.filename, 1)
- self.assertFalse(os.path.exists(self.filename))
-
-
-class TestDeferredOpenFile(BaseUtilsTest):
- def setUp(self):
- super().setUp()
- self.filename = os.path.join(self.tempdir, 'foo')
- self.contents = b'my contents'
- with open(self.filename, 'wb') as f:
- f.write(self.contents)
- self.deferred_open_file = DeferredOpenFile(
- self.filename, open_function=self.recording_open_function
- )
- self.open_call_args = []
-
- def tearDown(self):
- self.deferred_open_file.close()
- super().tearDown()
-
- def recording_open_function(self, filename, mode):
- self.open_call_args.append((filename, mode))
- return open(filename, mode)
-
- def open_nonseekable(self, filename, mode):
- self.open_call_args.append((filename, mode))
- return NonSeekableWriter(BytesIO(self.content))
-
- def test_instantiation_does_not_open_file(self):
- DeferredOpenFile(
- self.filename, open_function=self.recording_open_function
- )
- self.assertEqual(len(self.open_call_args), 0)
-
- def test_name(self):
- self.assertEqual(self.deferred_open_file.name, self.filename)
-
- def test_read(self):
- content = self.deferred_open_file.read(2)
- self.assertEqual(content, self.contents[0:2])
- content = self.deferred_open_file.read(2)
- self.assertEqual(content, self.contents[2:4])
- self.assertEqual(len(self.open_call_args), 1)
-
- def test_write(self):
- self.deferred_open_file = DeferredOpenFile(
- self.filename,
- mode='wb',
- open_function=self.recording_open_function,
- )
-
- write_content = b'foo'
- self.deferred_open_file.write(write_content)
- self.deferred_open_file.write(write_content)
- self.deferred_open_file.close()
- # Both of the writes should now be in the file.
- with open(self.filename, 'rb') as f:
- self.assertEqual(f.read(), write_content * 2)
- # Open should have only been called once.
- self.assertEqual(len(self.open_call_args), 1)
-
- def test_seek(self):
- self.deferred_open_file.seek(2)
- content = self.deferred_open_file.read(2)
- self.assertEqual(content, self.contents[2:4])
- self.assertEqual(len(self.open_call_args), 1)
-
- def test_open_does_not_seek_with_zero_start_byte(self):
- self.deferred_open_file = DeferredOpenFile(
- self.filename,
- mode='wb',
- start_byte=0,
- open_function=self.open_nonseekable,
- )
-
- try:
- # If this seeks, an UnsupportedOperation error will be raised.
- self.deferred_open_file.write(b'data')
- except io.UnsupportedOperation:
- self.fail('DeferredOpenFile seeked upon opening')
-
- def test_open_seeks_with_nonzero_start_byte(self):
- self.deferred_open_file = DeferredOpenFile(
- self.filename,
- mode='wb',
- start_byte=5,
- open_function=self.open_nonseekable,
- )
-
-        # Since a non-seekable file is being opened, calling seek() will raise
-        # an UnsupportedOperation error.
- with self.assertRaises(io.UnsupportedOperation):
- self.deferred_open_file.write(b'data')
-
- def test_tell(self):
- self.deferred_open_file.tell()
- # tell() should not have opened the file if it has not been seeked
- # or read because we know the start bytes upfront.
- self.assertEqual(len(self.open_call_args), 0)
-
- self.deferred_open_file.seek(2)
- self.assertEqual(self.deferred_open_file.tell(), 2)
- self.assertEqual(len(self.open_call_args), 1)
-
- def test_open_args(self):
- self.deferred_open_file = DeferredOpenFile(
- self.filename,
- mode='ab+',
- open_function=self.recording_open_function,
- )
- # Force an open
- self.deferred_open_file.write(b'data')
- self.assertEqual(len(self.open_call_args), 1)
- self.assertEqual(self.open_call_args[0], (self.filename, 'ab+'))
-
- def test_context_handler(self):
- with self.deferred_open_file:
- self.assertEqual(len(self.open_call_args), 1)
-
-
-class TestReadFileChunk(BaseUtilsTest):
- def test_read_entire_chunk(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=0, chunk_size=3
- )
- self.assertEqual(chunk.read(), b'one')
- self.assertEqual(chunk.read(), b'')
-
- def test_read_with_amount_size(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=11, chunk_size=4
- )
- self.assertEqual(chunk.read(1), b'f')
- self.assertEqual(chunk.read(1), b'o')
- self.assertEqual(chunk.read(1), b'u')
- self.assertEqual(chunk.read(1), b'r')
- self.assertEqual(chunk.read(1), b'')
-
- def test_reset_stream_emulation(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=11, chunk_size=4
- )
- self.assertEqual(chunk.read(), b'four')
- chunk.seek(0)
- self.assertEqual(chunk.read(), b'four')
-
- def test_read_past_end_of_file(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=36, chunk_size=100000
- )
- self.assertEqual(chunk.read(), b'ten')
- self.assertEqual(chunk.read(), b'')
- self.assertEqual(len(chunk), 3)
-
- def test_tell_and_seek(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'onetwothreefourfivesixseveneightnineten')
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=36, chunk_size=100000
- )
- self.assertEqual(chunk.tell(), 0)
- self.assertEqual(chunk.read(), b'ten')
- self.assertEqual(chunk.tell(), 3)
- chunk.seek(0)
- self.assertEqual(chunk.tell(), 0)
- chunk.seek(1, whence=1)
- self.assertEqual(chunk.tell(), 1)
- chunk.seek(-1, whence=1)
- self.assertEqual(chunk.tell(), 0)
- chunk.seek(-1, whence=2)
- self.assertEqual(chunk.tell(), 2)
-
- def test_tell_and_seek_boundaries(self):
- # Test to ensure ReadFileChunk behaves the same as the
- # Python standard library around seeking and reading out
- # of bounds in a file object.
- data = b'abcdefghij12345678klmnopqrst'
- start_pos = 10
- chunk_size = 8
-
- # Create test file
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(data)
-
-        # The ReadFileChunk should expose only the all-digit substring of the data
- file_objects = [
- ReadFileChunk.from_filename(
- filename, start_byte=start_pos, chunk_size=chunk_size
- )
- ]
-
- # Uncomment next line to validate we match Python's io.BytesIO
- # file_objects.append(io.BytesIO(data[start_pos:start_pos+chunk_size]))
-
- for obj in file_objects:
- self._assert_whence_start_behavior(obj)
- self._assert_whence_end_behavior(obj)
- self._assert_whence_relative_behavior(obj)
- self._assert_boundary_behavior(obj)
-
- def _assert_whence_start_behavior(self, file_obj):
- self.assertEqual(file_obj.tell(), 0)
-
- file_obj.seek(1, 0)
- self.assertEqual(file_obj.tell(), 1)
-
- file_obj.seek(1)
- self.assertEqual(file_obj.tell(), 1)
- self.assertEqual(file_obj.read(), b'2345678')
-
- file_obj.seek(3, 0)
- self.assertEqual(file_obj.tell(), 3)
-
- file_obj.seek(0, 0)
- self.assertEqual(file_obj.tell(), 0)
-
- def _assert_whence_relative_behavior(self, file_obj):
- self.assertEqual(file_obj.tell(), 0)
-
- file_obj.seek(2, 1)
- self.assertEqual(file_obj.tell(), 2)
-
- file_obj.seek(1, 1)
- self.assertEqual(file_obj.tell(), 3)
- self.assertEqual(file_obj.read(), b'45678')
-
- file_obj.seek(20, 1)
- self.assertEqual(file_obj.tell(), 28)
-
- file_obj.seek(-30, 1)
- self.assertEqual(file_obj.tell(), 0)
- self.assertEqual(file_obj.read(), b'12345678')
-
- file_obj.seek(-8, 1)
- self.assertEqual(file_obj.tell(), 0)
-
- def _assert_whence_end_behavior(self, file_obj):
- self.assertEqual(file_obj.tell(), 0)
-
- file_obj.seek(-1, 2)
- self.assertEqual(file_obj.tell(), 7)
-
- file_obj.seek(1, 2)
- self.assertEqual(file_obj.tell(), 9)
-
- file_obj.seek(3, 2)
- self.assertEqual(file_obj.tell(), 11)
- self.assertEqual(file_obj.read(), b'')
-
- file_obj.seek(-15, 2)
- self.assertEqual(file_obj.tell(), 0)
- self.assertEqual(file_obj.read(), b'12345678')
-
- file_obj.seek(-8, 2)
- self.assertEqual(file_obj.tell(), 0)
-
- def _assert_boundary_behavior(self, file_obj):
- # Verify we're at the start
- self.assertEqual(file_obj.tell(), 0)
-
- # Verify we can't move backwards beyond start of file
- file_obj.seek(-10, 1)
- self.assertEqual(file_obj.tell(), 0)
-
- # Verify we *can* move after end of file, but return nothing
- file_obj.seek(10, 2)
- self.assertEqual(file_obj.tell(), 18)
- self.assertEqual(file_obj.read(), b'')
- self.assertEqual(file_obj.read(10), b'')
-
- # Verify we can partially rewind
- file_obj.seek(-12, 1)
- self.assertEqual(file_obj.tell(), 6)
- self.assertEqual(file_obj.read(), b'78')
- self.assertEqual(file_obj.tell(), 8)
-
- # Verify we can rewind to start
- file_obj.seek(0)
- self.assertEqual(file_obj.tell(), 0)
-
- def test_file_chunk_supports_context_manager(self):
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(b'abc')
- with ReadFileChunk.from_filename(
- filename, start_byte=0, chunk_size=2
- ) as chunk:
- val = chunk.read()
- self.assertEqual(val, b'ab')
-
- def test_iter_is_always_empty(self):
- # This tests the workaround for the httplib bug (see
- # the source for more info).
- filename = os.path.join(self.tempdir, 'foo')
- open(filename, 'wb').close()
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=0, chunk_size=10
- )
- self.assertEqual(list(chunk), [])
-
- def test_callback_is_invoked_on_read(self):
- chunk = ReadFileChunk.from_filename(
- self.filename,
- start_byte=0,
- chunk_size=3,
- callbacks=[self.callback],
- )
- chunk.read(1)
- chunk.read(1)
- chunk.read(1)
- self.assertEqual(self.amounts_seen, [1, 1, 1])
-
- def test_all_callbacks_invoked_on_read(self):
- chunk = ReadFileChunk.from_filename(
- self.filename,
- start_byte=0,
- chunk_size=3,
- callbacks=[self.callback, self.callback],
- )
- chunk.read(1)
- chunk.read(1)
- chunk.read(1)
- # The list should be twice as long because there are two callbacks
- # recording the amount read.
- self.assertEqual(self.amounts_seen, [1, 1, 1, 1, 1, 1])
-
- def test_callback_can_be_disabled(self):
- chunk = ReadFileChunk.from_filename(
- self.filename,
- start_byte=0,
- chunk_size=3,
- callbacks=[self.callback],
- )
- chunk.disable_callback()
- # Now reading from the ReadFileChunk should not invoke
- # the callback.
- chunk.read()
- self.assertEqual(self.amounts_seen, [])
-
- def test_callback_will_also_be_triggered_by_seek(self):
- chunk = ReadFileChunk.from_filename(
- self.filename,
- start_byte=0,
- chunk_size=3,
- callbacks=[self.callback],
- )
- chunk.read(2)
- chunk.seek(0)
- chunk.read(2)
- chunk.seek(1)
- chunk.read(2)
- self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
-
- def test_callback_triggered_by_out_of_bound_seeks(self):
- data = b'abcdefghij1234567890klmnopqr'
-
- # Create test file
- filename = os.path.join(self.tempdir, 'foo')
- with open(filename, 'wb') as f:
- f.write(data)
- chunk = ReadFileChunk.from_filename(
- filename, start_byte=10, chunk_size=10, callbacks=[self.callback]
- )
-
- # Seek calls that generate "0" progress are skipped by
- # invoke_progress_callbacks and won't appear in the list.
- expected_callback_prog = [10, -5, 5, -1, 1, -1, 1, -5, 5, -10]
-
- self._assert_out_of_bound_start_seek(chunk, expected_callback_prog)
- self._assert_out_of_bound_relative_seek(chunk, expected_callback_prog)
- self._assert_out_of_bound_end_seek(chunk, expected_callback_prog)
-
- def _assert_out_of_bound_start_seek(self, chunk, expected):
- # clear amounts_seen
- self.amounts_seen = []
- self.assertEqual(self.amounts_seen, [])
-
- # (position, change)
- chunk.seek(20) # (20, 10)
- chunk.seek(5) # (5, -5)
- chunk.seek(20) # (20, 5)
- chunk.seek(9) # (9, -1)
- chunk.seek(20) # (20, 1)
- chunk.seek(11) # (11, 0)
- chunk.seek(20) # (20, 0)
- chunk.seek(9) # (9, -1)
- chunk.seek(20) # (20, 1)
- chunk.seek(5) # (5, -5)
- chunk.seek(20) # (20, 5)
- chunk.seek(0) # (0, -10)
- chunk.seek(0) # (0, 0)
-
- self.assertEqual(self.amounts_seen, expected)
-
- def _assert_out_of_bound_relative_seek(self, chunk, expected):
- # clear amounts_seen
- self.amounts_seen = []
- self.assertEqual(self.amounts_seen, [])
-
- # (position, change)
- chunk.seek(20, 1) # (20, 10)
- chunk.seek(-15, 1) # (5, -5)
- chunk.seek(15, 1) # (20, 5)
- chunk.seek(-11, 1) # (9, -1)
- chunk.seek(11, 1) # (20, 1)
- chunk.seek(-9, 1) # (11, 0)
- chunk.seek(9, 1) # (20, 0)
- chunk.seek(-11, 1) # (9, -1)
- chunk.seek(11, 1) # (20, 1)
- chunk.seek(-15, 1) # (5, -5)
- chunk.seek(15, 1) # (20, 5)
- chunk.seek(-20, 1) # (0, -10)
- chunk.seek(-1000, 1) # (0, 0)
-
- self.assertEqual(self.amounts_seen, expected)
-
- def _assert_out_of_bound_end_seek(self, chunk, expected):
- # clear amounts_seen
- self.amounts_seen = []
- self.assertEqual(self.amounts_seen, [])
-
- # (position, change)
- chunk.seek(10, 2) # (20, 10)
- chunk.seek(-5, 2) # (5, -5)
- chunk.seek(10, 2) # (20, 5)
- chunk.seek(-1, 2) # (9, -1)
- chunk.seek(10, 2) # (20, 1)
- chunk.seek(1, 2) # (11, 0)
- chunk.seek(10, 2) # (20, 0)
- chunk.seek(-1, 2) # (9, -1)
- chunk.seek(10, 2) # (20, 1)
- chunk.seek(-5, 2) # (5, -5)
- chunk.seek(10, 2) # (20, 5)
- chunk.seek(-10, 2) # (0, -10)
- chunk.seek(-1000, 2) # (0, 0)
-
- self.assertEqual(self.amounts_seen, expected)
-
- def test_close_callbacks(self):
- with open(self.filename) as f:
- chunk = ReadFileChunk(
- f,
- chunk_size=1,
- full_file_size=3,
- close_callbacks=[self.close_callback],
- )
- chunk.close()
- self.assertEqual(self.num_close_callback_calls, 1)
-
- def test_close_callbacks_when_not_enabled(self):
- with open(self.filename) as f:
- chunk = ReadFileChunk(
- f,
- chunk_size=1,
- full_file_size=3,
- enable_callbacks=False,
- close_callbacks=[self.close_callback],
- )
- chunk.close()
- self.assertEqual(self.num_close_callback_calls, 0)
-
- def test_close_callbacks_when_context_handler_is_used(self):
- with open(self.filename) as f:
- with ReadFileChunk(
- f,
- chunk_size=1,
- full_file_size=3,
- close_callbacks=[self.close_callback],
- ) as chunk:
- chunk.read(1)
- self.assertEqual(self.num_close_callback_calls, 1)
-
- def test_signal_transferring(self):
- chunk = ReadFileChunk.from_filename(
- self.filename,
- start_byte=0,
- chunk_size=3,
- callbacks=[self.callback],
- )
- chunk.signal_not_transferring()
- chunk.read(1)
- self.assertEqual(self.amounts_seen, [])
- chunk.signal_transferring()
- chunk.read(1)
- self.assertEqual(self.amounts_seen, [1])
-
- def test_signal_transferring_to_underlying_fileobj(self):
- underlying_stream = mock.Mock()
- underlying_stream.tell.return_value = 0
- chunk = ReadFileChunk(underlying_stream, 3, 3)
- chunk.signal_transferring()
- self.assertTrue(underlying_stream.signal_transferring.called)
-
- def test_no_call_signal_transferring_to_underlying_fileobj(self):
- underlying_stream = mock.Mock(io.RawIOBase)
- underlying_stream.tell.return_value = 0
- chunk = ReadFileChunk(underlying_stream, 3, 3)
- try:
- chunk.signal_transferring()
- except AttributeError:
- self.fail(
- 'The stream should not have tried to call signal_transferring '
- 'to the underlying stream.'
- )
-
- def test_signal_not_transferring_to_underlying_fileobj(self):
- underlying_stream = mock.Mock()
- underlying_stream.tell.return_value = 0
- chunk = ReadFileChunk(underlying_stream, 3, 3)
- chunk.signal_not_transferring()
- self.assertTrue(underlying_stream.signal_not_transferring.called)
-
- def test_no_call_signal_not_transferring_to_underlying_fileobj(self):
- underlying_stream = mock.Mock(io.RawIOBase)
- underlying_stream.tell.return_value = 0
- chunk = ReadFileChunk(underlying_stream, 3, 3)
- try:
- chunk.signal_not_transferring()
- except AttributeError:
- self.fail(
- 'The stream should not have tried to call '
- 'signal_not_transferring to the underlying stream.'
- )
-
-
-class TestStreamReaderProgress(BaseUtilsTest):
- def test_proxies_to_wrapped_stream(self):
- original_stream = StringIO('foobarbaz')
- wrapped = StreamReaderProgress(original_stream)
- self.assertEqual(wrapped.read(), 'foobarbaz')
-
- def test_callback_invoked(self):
- original_stream = StringIO('foobarbaz')
- wrapped = StreamReaderProgress(
- original_stream, [self.callback, self.callback]
- )
- self.assertEqual(wrapped.read(), 'foobarbaz')
- self.assertEqual(self.amounts_seen, [9, 9])
-
-
-class TestTaskSemaphore(unittest.TestCase):
- def setUp(self):
- self.semaphore = TaskSemaphore(1)
-
- def test_should_block_at_max_capacity(self):
- self.semaphore.acquire('a', blocking=False)
- with self.assertRaises(NoResourcesAvailable):
- self.semaphore.acquire('a', blocking=False)
-
- def test_release_capacity(self):
- acquire_token = self.semaphore.acquire('a', blocking=False)
- self.semaphore.release('a', acquire_token)
- try:
- self.semaphore.acquire('a', blocking=False)
- except NoResourcesAvailable:
- self.fail(
- 'The release of the semaphore should have allowed for '
- 'the second acquire to not be blocked'
- )
-
-
-class TestSlidingWindowSemaphore(unittest.TestCase):
-    # These tests use block=False so that tests will fail
-    # instead of hanging the test runner in the case of
-    # incorrect behavior.
- def test_acquire_release_basic_case(self):
- sem = SlidingWindowSemaphore(1)
- # Count is 1
-
- num = sem.acquire('a', blocking=False)
- self.assertEqual(num, 0)
- sem.release('a', 0)
- # Count now back to 1.
-
- def test_can_acquire_release_multiple_times(self):
- sem = SlidingWindowSemaphore(1)
- num = sem.acquire('a', blocking=False)
- self.assertEqual(num, 0)
- sem.release('a', num)
-
- num = sem.acquire('a', blocking=False)
- self.assertEqual(num, 1)
- sem.release('a', num)
-
- def test_can_acquire_a_range(self):
- sem = SlidingWindowSemaphore(3)
- self.assertEqual(sem.acquire('a', blocking=False), 0)
- self.assertEqual(sem.acquire('a', blocking=False), 1)
- self.assertEqual(sem.acquire('a', blocking=False), 2)
- sem.release('a', 0)
- sem.release('a', 1)
- sem.release('a', 2)
- # Now we're reset so we should be able to acquire the same
- # sequence again.
- self.assertEqual(sem.acquire('a', blocking=False), 3)
- self.assertEqual(sem.acquire('a', blocking=False), 4)
- self.assertEqual(sem.acquire('a', blocking=False), 5)
- self.assertEqual(sem.current_count(), 0)
-
- def test_counter_release_only_on_min_element(self):
- sem = SlidingWindowSemaphore(3)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
-
- # The count only increases when we free the min
- # element. This means if we're currently failing to
- # acquire now:
- with self.assertRaises(NoResourcesAvailable):
- sem.acquire('a', blocking=False)
-
- # Then freeing a non-min element:
- sem.release('a', 1)
-
- # doesn't change anything. We still fail to acquire.
- with self.assertRaises(NoResourcesAvailable):
- sem.acquire('a', blocking=False)
- self.assertEqual(sem.current_count(), 0)
-
- def test_raises_error_when_count_is_zero(self):
- sem = SlidingWindowSemaphore(3)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
-
- # Count is now 0 so trying to acquire should fail.
- with self.assertRaises(NoResourcesAvailable):
- sem.acquire('a', blocking=False)
-
- def test_release_counters_can_increment_counter_repeatedly(self):
- sem = SlidingWindowSemaphore(3)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
-
- # These two releases don't increment the counter
- # because we're waiting on 0.
- sem.release('a', 1)
- sem.release('a', 2)
- self.assertEqual(sem.current_count(), 0)
- # But as soon as we release 0, we free up 0, 1, and 2.
- sem.release('a', 0)
- self.assertEqual(sem.current_count(), 3)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
-
- def test_error_to_release_unknown_tag(self):
- sem = SlidingWindowSemaphore(3)
- with self.assertRaises(ValueError):
- sem.release('a', 0)
-
- def test_can_track_multiple_tags(self):
- sem = SlidingWindowSemaphore(3)
- self.assertEqual(sem.acquire('a', blocking=False), 0)
- self.assertEqual(sem.acquire('b', blocking=False), 0)
- self.assertEqual(sem.acquire('a', blocking=False), 1)
-
- # We're at our max of 3 even though 2 are for A and 1 is for B.
- with self.assertRaises(NoResourcesAvailable):
- sem.acquire('a', blocking=False)
- with self.assertRaises(NoResourcesAvailable):
- sem.acquire('b', blocking=False)
-
- def test_can_handle_multiple_tags_released(self):
- sem = SlidingWindowSemaphore(4)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
- sem.acquire('b', blocking=False)
- sem.acquire('b', blocking=False)
-
- sem.release('b', 1)
- sem.release('a', 1)
- self.assertEqual(sem.current_count(), 0)
-
- sem.release('b', 0)
- self.assertEqual(sem.acquire('a', blocking=False), 2)
-
- sem.release('a', 0)
- self.assertEqual(sem.acquire('b', blocking=False), 2)
-
- def test_is_error_to_release_unknown_sequence_number(self):
- sem = SlidingWindowSemaphore(3)
- sem.acquire('a', blocking=False)
- with self.assertRaises(ValueError):
- sem.release('a', 1)
-
- def test_is_error_to_double_release(self):
- # This is different than other error tests because
- # we're verifying we can reset the state after an
- # acquire/release cycle.
- sem = SlidingWindowSemaphore(2)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
- sem.release('a', 0)
- sem.release('a', 1)
- self.assertEqual(sem.current_count(), 2)
- with self.assertRaises(ValueError):
- sem.release('a', 0)
-
- def test_can_check_in_partial_range(self):
- sem = SlidingWindowSemaphore(4)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
-
- sem.release('a', 1)
- sem.release('a', 3)
- sem.release('a', 0)
- self.assertEqual(sem.current_count(), 2)
-
-
-class TestThreadingPropertiesForSlidingWindowSemaphore(unittest.TestCase):
-    # These tests focus on multithreaded properties of the range
- # semaphore. Basic functionality is tested in TestSlidingWindowSemaphore.
- def setUp(self):
- self.threads = []
-
- def tearDown(self):
- self.join_threads()
-
- def join_threads(self):
- for thread in self.threads:
- thread.join()
- self.threads = []
-
- def start_threads(self):
- for thread in self.threads:
- thread.start()
-
- def test_acquire_blocks_until_release_is_called(self):
- sem = SlidingWindowSemaphore(2)
- sem.acquire('a', blocking=False)
- sem.acquire('a', blocking=False)
-
- def acquire():
- # This next call to acquire will block.
- self.assertEqual(sem.acquire('a', blocking=True), 2)
-
- t = threading.Thread(target=acquire)
- self.threads.append(t)
- # Starting the thread will block the sem.acquire()
- # in the acquire function above.
- t.start()
- # This still will keep the thread blocked.
- sem.release('a', 1)
- # Releasing the min element will unblock the thread.
- sem.release('a', 0)
- t.join()
- sem.release('a', 2)
-
- def test_stress_invariants_random_order(self):
- sem = SlidingWindowSemaphore(100)
- for _ in range(10):
- recorded = []
- for _ in range(100):
- recorded.append(sem.acquire('a', blocking=False))
- # Release them in randomized order. As long as we
- # eventually free all 100, we should have all the
- # resources released.
- random.shuffle(recorded)
- for i in recorded:
- sem.release('a', i)
-
- # Everything's freed so should be back at count == 100
- self.assertEqual(sem.current_count(), 100)
-
- def test_blocking_stress(self):
- sem = SlidingWindowSemaphore(5)
- num_threads = 10
- num_iterations = 50
-
- def acquire():
- for _ in range(num_iterations):
- num = sem.acquire('a', blocking=True)
- time.sleep(0.001)
- sem.release('a', num)
-
- for i in range(num_threads):
- t = threading.Thread(target=acquire)
- self.threads.append(t)
- self.start_threads()
- self.join_threads()
- # Should have all the available resources freed.
- self.assertEqual(sem.current_count(), 5)
- # Should have acquired num_threads * num_iterations
- self.assertEqual(
- sem.acquire('a', blocking=False), num_threads * num_iterations
- )
-
-
-class TestAdjustChunksize(unittest.TestCase):
- def setUp(self):
- self.adjuster = ChunksizeAdjuster()
-
- def test_valid_chunksize(self):
- chunksize = 7 * (1024 ** 2)
- file_size = 8 * (1024 ** 2)
- new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
- self.assertEqual(new_size, chunksize)
-
- def test_chunksize_below_minimum(self):
- chunksize = MIN_UPLOAD_CHUNKSIZE - 1
- file_size = 3 * MIN_UPLOAD_CHUNKSIZE
- new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
- self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE)
-
- def test_chunksize_above_maximum(self):
- chunksize = MAX_SINGLE_UPLOAD_SIZE + 1
- file_size = MAX_SINGLE_UPLOAD_SIZE * 2
- new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
- self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE)
-
- def test_chunksize_too_small(self):
- chunksize = 7 * (1024 ** 2)
- file_size = 5 * (1024 ** 4)
- # If we try to upload a 5TB file, we'll need to use 896MB part
- # sizes.
- new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
- self.assertEqual(new_size, 896 * (1024 ** 2))
- num_parts = file_size / new_size
- self.assertLessEqual(num_parts, MAX_PARTS)
-
- def test_unknown_file_size_with_valid_chunksize(self):
- chunksize = 7 * (1024 ** 2)
- new_size = self.adjuster.adjust_chunksize(chunksize)
- self.assertEqual(new_size, chunksize)
-
- def test_unknown_file_size_below_minimum(self):
- chunksize = MIN_UPLOAD_CHUNKSIZE - 1
- new_size = self.adjuster.adjust_chunksize(chunksize)
- self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE)
-
- def test_unknown_file_size_above_maximum(self):
- chunksize = MAX_SINGLE_UPLOAD_SIZE + 1
- new_size = self.adjuster.adjust_chunksize(chunksize)
- self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE)
+# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import io
+import os.path
+import random
+import re
+import shutil
+import tempfile
+import threading
+import time
+from io import BytesIO, StringIO
+
+from s3transfer.futures import TransferFuture, TransferMeta
+from s3transfer.utils import (
+ MAX_PARTS,
+ MAX_SINGLE_UPLOAD_SIZE,
+ MIN_UPLOAD_CHUNKSIZE,
+ CallArgs,
+ ChunksizeAdjuster,
+ CountCallbackInvoker,
+ DeferredOpenFile,
+ FunctionContainer,
+ NoResourcesAvailable,
+ OSUtils,
+ ReadFileChunk,
+ SlidingWindowSemaphore,
+ StreamReaderProgress,
+ TaskSemaphore,
+ calculate_num_parts,
+ calculate_range_parameter,
+ get_callbacks,
+ get_filtered_dict,
+ invoke_progress_callbacks,
+ random_file_extension,
+)
+from __tests__ import NonSeekableWriter, RecordingSubscriber, mock, unittest
+
+
+class TestGetCallbacks(unittest.TestCase):
+ def setUp(self):
+ self.subscriber = RecordingSubscriber()
+ self.second_subscriber = RecordingSubscriber()
+ self.call_args = CallArgs(
+ subscribers=[self.subscriber, self.second_subscriber]
+ )
+ self.transfer_meta = TransferMeta(self.call_args)
+ self.transfer_future = TransferFuture(self.transfer_meta)
+
+ def test_get_callbacks(self):
+ callbacks = get_callbacks(self.transfer_future, 'queued')
+ # Make sure two callbacks were added as both subscribers had
+ # an on_queued method.
+ self.assertEqual(len(callbacks), 2)
+
+ # Ensure that the callback was injected with the future by calling
+ # one of them and checking that the future was used in the call.
+ callbacks[0]()
+ self.assertEqual(
+ self.subscriber.on_queued_calls, [{'future': self.transfer_future}]
+ )
+
+ def test_get_callbacks_for_missing_type(self):
+ callbacks = get_callbacks(self.transfer_future, 'fake_state')
+ # There should be no callbacks as the subscribers will not have the
+ # on_fake_state method
+ self.assertEqual(len(callbacks), 0)
+
+
+class TestGetFilteredDict(unittest.TestCase):
+ def test_get_filtered_dict(self):
+ original = {'Include': 'IncludeValue', 'NotInlude': 'NotIncludeValue'}
+ whitelist = ['Include']
+ self.assertEqual(
+ get_filtered_dict(original, whitelist), {'Include': 'IncludeValue'}
+ )
+
+
+class TestCallArgs(unittest.TestCase):
+ def test_call_args(self):
+ call_args = CallArgs(foo='bar', biz='baz')
+ self.assertEqual(call_args.foo, 'bar')
+ self.assertEqual(call_args.biz, 'baz')
+
+
+class TestFunctionContainer(unittest.TestCase):
+ def get_args_kwargs(self, *args, **kwargs):
+ return args, kwargs
+
+ def test_call(self):
+ func_container = FunctionContainer(
+ self.get_args_kwargs, 'foo', bar='baz'
+ )
+ self.assertEqual(func_container(), (('foo',), {'bar': 'baz'}))
+
+ def test_repr(self):
+ func_container = FunctionContainer(
+ self.get_args_kwargs, 'foo', bar='baz'
+ )
+ self.assertEqual(
+ str(func_container),
+ 'Function: {} with args {} and kwargs {}'.format(
+ self.get_args_kwargs, ('foo',), {'bar': 'baz'}
+ ),
+ )
+
+
+class TestCountCallbackInvoker(unittest.TestCase):
+ def invoke_callback(self):
+ self.ref_results.append('callback invoked')
+
+ def assert_callback_invoked(self):
+ self.assertEqual(self.ref_results, ['callback invoked'])
+
+ def assert_callback_not_invoked(self):
+ self.assertEqual(self.ref_results, [])
+
+ def setUp(self):
+ self.ref_results = []
+ self.invoker = CountCallbackInvoker(self.invoke_callback)
+
+ def test_increment(self):
+ self.invoker.increment()
+ self.assertEqual(self.invoker.current_count, 1)
+
+ def test_decrement(self):
+ self.invoker.increment()
+ self.invoker.increment()
+ self.invoker.decrement()
+ self.assertEqual(self.invoker.current_count, 1)
+
+ def test_count_cannot_go_below_zero(self):
+ with self.assertRaises(RuntimeError):
+ self.invoker.decrement()
+
+ def test_callback_invoked_only_once_finalized(self):
+ self.invoker.increment()
+ self.invoker.decrement()
+ self.assert_callback_not_invoked()
+ self.invoker.finalize()
+ # Callback should only be invoked once finalized
+ self.assert_callback_invoked()
+
+ def test_callback_invoked_after_finalizing_and_count_reaching_zero(self):
+ self.invoker.increment()
+ self.invoker.finalize()
+ # Make sure that it does not get invoked immediately after
+ # finalizing as the count is currently one
+ self.assert_callback_not_invoked()
+ self.invoker.decrement()
+ self.assert_callback_invoked()
+
+ def test_cannot_increment_after_finalization(self):
+ self.invoker.finalize()
+ with self.assertRaises(RuntimeError):
+ self.invoker.increment()
+
+
+class TestRandomFileExtension(unittest.TestCase):
+ def test_has_proper_length(self):
+ self.assertEqual(len(random_file_extension(num_digits=4)), 4)
+
+
+class TestInvokeProgressCallbacks(unittest.TestCase):
+ def test_invoke_progress_callbacks(self):
+ recording_subscriber = RecordingSubscriber()
+ invoke_progress_callbacks([recording_subscriber.on_progress], 2)
+ self.assertEqual(recording_subscriber.calculate_bytes_seen(), 2)
+
+ def test_invoke_progress_callbacks_with_no_progress(self):
+ recording_subscriber = RecordingSubscriber()
+ invoke_progress_callbacks([recording_subscriber.on_progress], 0)
+ self.assertEqual(len(recording_subscriber.on_progress_calls), 0)
+
+
+class TestCalculateNumParts(unittest.TestCase):
+ def test_calculate_num_parts_divisible(self):
+ self.assertEqual(calculate_num_parts(size=4, part_size=2), 2)
+
+ def test_calculate_num_parts_not_divisible(self):
+ self.assertEqual(calculate_num_parts(size=3, part_size=2), 2)
+
+
+class TestCalculateRangeParameter(unittest.TestCase):
+ def setUp(self):
+ self.part_size = 5
+ self.part_index = 1
+ self.num_parts = 3
+
+ def test_calculate_range_paramter(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, self.num_parts
+ )
+ self.assertEqual(range_val, 'bytes=5-9')
+
+ def test_last_part_with_no_total_size(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, num_parts=2
+ )
+ self.assertEqual(range_val, 'bytes=5-')
+
+ def test_last_part_with_total_size(self):
+ range_val = calculate_range_parameter(
+ self.part_size, self.part_index, num_parts=2, total_size=8
+ )
+ self.assertEqual(range_val, 'bytes=5-7')
+
+
+class BaseUtilsTest(unittest.TestCase):
+ def setUp(self):
+ self.tempdir = tempfile.mkdtemp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ self.content = b'abc'
+ with open(self.filename, 'wb') as f:
+ f.write(self.content)
+ self.amounts_seen = []
+ self.num_close_callback_calls = 0
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir)
+
+ def callback(self, bytes_transferred):
+ self.amounts_seen.append(bytes_transferred)
+
+ def close_callback(self):
+ self.num_close_callback_calls += 1
+
+
+class TestOSUtils(BaseUtilsTest):
+ def test_get_file_size(self):
+ self.assertEqual(
+ OSUtils().get_file_size(self.filename), len(self.content)
+ )
+
+ def test_open_file_chunk_reader(self):
+ reader = OSUtils().open_file_chunk_reader(
+ self.filename, 0, 3, [self.callback]
+ )
+
+ # The returned reader should be a ReadFileChunk.
+ self.assertIsInstance(reader, ReadFileChunk)
+ # The content of the reader should be correct.
+ self.assertEqual(reader.read(), self.content)
+        # Callbacks should be disabled despite being passed in.
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_open_file_chunk_reader_from_fileobj(self):
+ with open(self.filename, 'rb') as f:
+ reader = OSUtils().open_file_chunk_reader_from_fileobj(
+ f, len(self.content), len(self.content), [self.callback]
+ )
+
+ # The returned reader should be a ReadFileChunk.
+ self.assertIsInstance(reader, ReadFileChunk)
+ # The content of the reader should be correct.
+ self.assertEqual(reader.read(), self.content)
+ reader.close()
+        # Callbacks should be disabled despite being passed in.
+ self.assertEqual(self.amounts_seen, [])
+ self.assertEqual(self.num_close_callback_calls, 0)
+
+ def test_open_file(self):
+ fileobj = OSUtils().open(os.path.join(self.tempdir, 'foo'), 'w')
+ self.assertTrue(hasattr(fileobj, 'write'))
+
+ def test_remove_file_ignores_errors(self):
+ non_existent_file = os.path.join(self.tempdir, 'no-exist')
+ # This should not exist to start.
+ self.assertFalse(os.path.exists(non_existent_file))
+ try:
+ OSUtils().remove_file(non_existent_file)
+ except OSError as e:
+ self.fail('OSError should have been caught: %s' % e)
+
+ def test_remove_file_proxies_remove_file(self):
+ OSUtils().remove_file(self.filename)
+ self.assertFalse(os.path.exists(self.filename))
+
+ def test_rename_file(self):
+ new_filename = os.path.join(self.tempdir, 'newfoo')
+ OSUtils().rename_file(self.filename, new_filename)
+ self.assertFalse(os.path.exists(self.filename))
+ self.assertTrue(os.path.exists(new_filename))
+
+ def test_is_special_file_for_normal_file(self):
+ self.assertFalse(OSUtils().is_special_file(self.filename))
+
+ def test_is_special_file_for_non_existant_file(self):
+ non_existant_filename = os.path.join(self.tempdir, 'no-exist')
+ self.assertFalse(os.path.exists(non_existant_filename))
+ self.assertFalse(OSUtils().is_special_file(non_existant_filename))
+
+ def test_get_temp_filename(self):
+ filename = 'myfile'
+ self.assertIsNotNone(
+ re.match(
+ r'%s\.[0-9A-Fa-f]{8}$' % filename,
+ OSUtils().get_temp_filename(filename),
+ )
+ )
+
+ def test_get_temp_filename_len_255(self):
+ filename = 'a' * 255
+ temp_filename = OSUtils().get_temp_filename(filename)
+ self.assertLessEqual(len(temp_filename), 255)
+
+ def test_get_temp_filename_len_gt_255(self):
+ filename = 'a' * 280
+ temp_filename = OSUtils().get_temp_filename(filename)
+ self.assertLessEqual(len(temp_filename), 255)
+
+ def test_allocate(self):
+ truncate_size = 1
+ OSUtils().allocate(self.filename, truncate_size)
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(len(f.read()), truncate_size)
+
+ @mock.patch('s3transfer.utils.fallocate')
+ def test_allocate_with_io_error(self, mock_fallocate):
+ mock_fallocate.side_effect = IOError()
+ with self.assertRaises(IOError):
+ OSUtils().allocate(self.filename, 1)
+ self.assertFalse(os.path.exists(self.filename))
+
+ @mock.patch('s3transfer.utils.fallocate')
+ def test_allocate_with_os_error(self, mock_fallocate):
+ mock_fallocate.side_effect = OSError()
+ with self.assertRaises(OSError):
+ OSUtils().allocate(self.filename, 1)
+ self.assertFalse(os.path.exists(self.filename))
+
+
+class TestDeferredOpenFile(BaseUtilsTest):
+ def setUp(self):
+ super().setUp()
+ self.filename = os.path.join(self.tempdir, 'foo')
+ self.contents = b'my contents'
+ with open(self.filename, 'wb') as f:
+ f.write(self.contents)
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename, open_function=self.recording_open_function
+ )
+ self.open_call_args = []
+
+ def tearDown(self):
+ self.deferred_open_file.close()
+ super().tearDown()
+
+ def recording_open_function(self, filename, mode):
+ self.open_call_args.append((filename, mode))
+ return open(filename, mode)
+
+ def open_nonseekable(self, filename, mode):
+ self.open_call_args.append((filename, mode))
+ return NonSeekableWriter(BytesIO(self.content))
+
+ def test_instantiation_does_not_open_file(self):
+ DeferredOpenFile(
+ self.filename, open_function=self.recording_open_function
+ )
+ self.assertEqual(len(self.open_call_args), 0)
+
+ def test_name(self):
+ self.assertEqual(self.deferred_open_file.name, self.filename)
+
+ def test_read(self):
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[0:2])
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[2:4])
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_write(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ open_function=self.recording_open_function,
+ )
+
+ write_content = b'foo'
+ self.deferred_open_file.write(write_content)
+ self.deferred_open_file.write(write_content)
+ self.deferred_open_file.close()
+ # Both of the writes should now be in the file.
+ with open(self.filename, 'rb') as f:
+ self.assertEqual(f.read(), write_content * 2)
+ # Open should have only been called once.
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_seek(self):
+ self.deferred_open_file.seek(2)
+ content = self.deferred_open_file.read(2)
+ self.assertEqual(content, self.contents[2:4])
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_open_does_not_seek_with_zero_start_byte(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ start_byte=0,
+ open_function=self.open_nonseekable,
+ )
+
+ try:
+ # If this seeks, an UnsupportedOperation error will be raised.
+ self.deferred_open_file.write(b'data')
+ except io.UnsupportedOperation:
+ self.fail('DeferredOpenFile seeked upon opening')
+
+ def test_open_seeks_with_nonzero_start_byte(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='wb',
+ start_byte=5,
+ open_function=self.open_nonseekable,
+ )
+
+        # Since a non-seekable file is being opened, calling seek() will raise
+        # an UnsupportedOperation error.
+ with self.assertRaises(io.UnsupportedOperation):
+ self.deferred_open_file.write(b'data')
+
+ def test_tell(self):
+ self.deferred_open_file.tell()
+ # tell() should not have opened the file if it has not been seeked
+ # or read because we know the start bytes upfront.
+ self.assertEqual(len(self.open_call_args), 0)
+
+ self.deferred_open_file.seek(2)
+ self.assertEqual(self.deferred_open_file.tell(), 2)
+ self.assertEqual(len(self.open_call_args), 1)
+
+ def test_open_args(self):
+ self.deferred_open_file = DeferredOpenFile(
+ self.filename,
+ mode='ab+',
+ open_function=self.recording_open_function,
+ )
+ # Force an open
+ self.deferred_open_file.write(b'data')
+ self.assertEqual(len(self.open_call_args), 1)
+ self.assertEqual(self.open_call_args[0], (self.filename, 'ab+'))
+
+ def test_context_handler(self):
+ with self.deferred_open_file:
+ self.assertEqual(len(self.open_call_args), 1)
+
+
+class TestReadFileChunk(BaseUtilsTest):
+ def test_read_entire_chunk(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=3
+ )
+ self.assertEqual(chunk.read(), b'one')
+ self.assertEqual(chunk.read(), b'')
+
+ def test_read_with_amount_size(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(1), b'f')
+ self.assertEqual(chunk.read(1), b'o')
+ self.assertEqual(chunk.read(1), b'u')
+ self.assertEqual(chunk.read(1), b'r')
+ self.assertEqual(chunk.read(1), b'')
+
+ def test_reset_stream_emulation(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=11, chunk_size=4
+ )
+ self.assertEqual(chunk.read(), b'four')
+ chunk.seek(0)
+ self.assertEqual(chunk.read(), b'four')
+
+ def test_read_past_end_of_file(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.read(), b'')
+ self.assertEqual(len(chunk), 3)
+
+ def test_tell_and_seek(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'onetwothreefourfivesixseveneightnineten')
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=36, chunk_size=100000
+ )
+ self.assertEqual(chunk.tell(), 0)
+ self.assertEqual(chunk.read(), b'ten')
+ self.assertEqual(chunk.tell(), 3)
+ chunk.seek(0)
+ self.assertEqual(chunk.tell(), 0)
+ chunk.seek(1, whence=1)
+ self.assertEqual(chunk.tell(), 1)
+ chunk.seek(-1, whence=1)
+ self.assertEqual(chunk.tell(), 0)
+ chunk.seek(-1, whence=2)
+ self.assertEqual(chunk.tell(), 2)
+
+ def test_tell_and_seek_boundaries(self):
+ # Test to ensure ReadFileChunk behaves the same as the
+ # Python standard library around seeking and reading out
+ # of bounds in a file object.
+ data = b'abcdefghij12345678klmnopqrst'
+ start_pos = 10
+ chunk_size = 8
+
+ # Create test file
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(data)
+
+        # The ReadFileChunk should expose only the all-digit substring of the data
+ file_objects = [
+ ReadFileChunk.from_filename(
+ filename, start_byte=start_pos, chunk_size=chunk_size
+ )
+ ]
+
+ # Uncomment next line to validate we match Python's io.BytesIO
+ # file_objects.append(io.BytesIO(data[start_pos:start_pos+chunk_size]))
+
+ for obj in file_objects:
+ self._assert_whence_start_behavior(obj)
+ self._assert_whence_end_behavior(obj)
+ self._assert_whence_relative_behavior(obj)
+ self._assert_boundary_behavior(obj)
+
+ def _assert_whence_start_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(1, 0)
+ self.assertEqual(file_obj.tell(), 1)
+
+ file_obj.seek(1)
+ self.assertEqual(file_obj.tell(), 1)
+ self.assertEqual(file_obj.read(), b'2345678')
+
+ file_obj.seek(3, 0)
+ self.assertEqual(file_obj.tell(), 3)
+
+ file_obj.seek(0, 0)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_whence_relative_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(2, 1)
+ self.assertEqual(file_obj.tell(), 2)
+
+ file_obj.seek(1, 1)
+ self.assertEqual(file_obj.tell(), 3)
+ self.assertEqual(file_obj.read(), b'45678')
+
+ file_obj.seek(20, 1)
+ self.assertEqual(file_obj.tell(), 28)
+
+ file_obj.seek(-30, 1)
+ self.assertEqual(file_obj.tell(), 0)
+ self.assertEqual(file_obj.read(), b'12345678')
+
+ file_obj.seek(-8, 1)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_whence_end_behavior(self, file_obj):
+ self.assertEqual(file_obj.tell(), 0)
+
+ file_obj.seek(-1, 2)
+ self.assertEqual(file_obj.tell(), 7)
+
+ file_obj.seek(1, 2)
+ self.assertEqual(file_obj.tell(), 9)
+
+ file_obj.seek(3, 2)
+ self.assertEqual(file_obj.tell(), 11)
+ self.assertEqual(file_obj.read(), b'')
+
+ file_obj.seek(-15, 2)
+ self.assertEqual(file_obj.tell(), 0)
+ self.assertEqual(file_obj.read(), b'12345678')
+
+ file_obj.seek(-8, 2)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def _assert_boundary_behavior(self, file_obj):
+ # Verify we're at the start
+ self.assertEqual(file_obj.tell(), 0)
+
+ # Verify we can't move backwards beyond start of file
+ file_obj.seek(-10, 1)
+ self.assertEqual(file_obj.tell(), 0)
+
+ # Verify we *can* move after end of file, but return nothing
+ file_obj.seek(10, 2)
+ self.assertEqual(file_obj.tell(), 18)
+ self.assertEqual(file_obj.read(), b'')
+ self.assertEqual(file_obj.read(10), b'')
+
+ # Verify we can partially rewind
+ file_obj.seek(-12, 1)
+ self.assertEqual(file_obj.tell(), 6)
+ self.assertEqual(file_obj.read(), b'78')
+ self.assertEqual(file_obj.tell(), 8)
+
+ # Verify we can rewind to start
+ file_obj.seek(0)
+ self.assertEqual(file_obj.tell(), 0)
+
+ def test_file_chunk_supports_context_manager(self):
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(b'abc')
+ with ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=2
+ ) as chunk:
+ val = chunk.read()
+ self.assertEqual(val, b'ab')
+
+ def test_iter_is_always_empty(self):
+ # This tests the workaround for the httplib bug (see
+ # the source for more info).
+ filename = os.path.join(self.tempdir, 'foo')
+ open(filename, 'wb').close()
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=0, chunk_size=10
+ )
+ self.assertEqual(list(chunk), [])
+
+ def test_callback_is_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1, 1, 1])
+
+ def test_all_callbacks_invoked_on_read(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback, self.callback],
+ )
+ chunk.read(1)
+ chunk.read(1)
+ chunk.read(1)
+ # The list should be twice as long because there are two callbacks
+ # recording the amount read.
+ self.assertEqual(self.amounts_seen, [1, 1, 1, 1, 1, 1])
+
+ def test_callback_can_be_disabled(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.disable_callback()
+ # Now reading from the ReadFileChunk should not invoke
+ # the callback.
+ chunk.read()
+ self.assertEqual(self.amounts_seen, [])
+
+ def test_callback_will_also_be_triggered_by_seek(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.read(2)
+ chunk.seek(0)
+ chunk.read(2)
+ chunk.seek(1)
+ chunk.read(2)
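+ # A short trace of the expected progress amounts (derived from the
+ # calls above): read(2) -> +2, seek(0) rewinds two bytes -> -2,
+ # read(2) -> +2, seek(1) from position 2 -> -1, read(2) -> +2.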
+ self.assertEqual(self.amounts_seen, [2, -2, 2, -1, 2])
+
+ def test_callback_triggered_by_out_of_bound_seeks(self):
+ data = b'abcdefghij1234567890klmnopqr'
+
+ # Create test file
+ filename = os.path.join(self.tempdir, 'foo')
+ with open(filename, 'wb') as f:
+ f.write(data)
+ chunk = ReadFileChunk.from_filename(
+ filename, start_byte=10, chunk_size=10, callbacks=[self.callback]
+ )
+
+ # Seek calls that generate "0" progress are skipped by
+ # invoke_progress_callbacks and won't appear in the list.
+ expected_callback_prog = [10, -5, 5, -1, 1, -1, 1, -5, 5, -10]
+
+ self._assert_out_of_bound_start_seek(chunk, expected_callback_prog)
+ self._assert_out_of_bound_relative_seek(chunk, expected_callback_prog)
+ self._assert_out_of_bound_end_seek(chunk, expected_callback_prog)
+
+ def _assert_out_of_bound_start_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(20) # (20, 10)
+ chunk.seek(5) # (5, -5)
+ chunk.seek(20) # (20, 5)
+ chunk.seek(9) # (9, -1)
+ chunk.seek(20) # (20, 1)
+ chunk.seek(11) # (11, 0)
+ chunk.seek(20) # (20, 0)
+ chunk.seek(9) # (9, -1)
+ chunk.seek(20) # (20, 1)
+ chunk.seek(5) # (5, -5)
+ chunk.seek(20) # (20, 5)
+ chunk.seek(0) # (0, -10)
+ chunk.seek(0) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def _assert_out_of_bound_relative_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(20, 1) # (20, 10)
+ chunk.seek(-15, 1) # (5, -5)
+ chunk.seek(15, 1) # (20, 5)
+ chunk.seek(-11, 1) # (9, -1)
+ chunk.seek(11, 1) # (20, 1)
+ chunk.seek(-9, 1) # (11, 0)
+ chunk.seek(9, 1) # (20, 0)
+ chunk.seek(-11, 1) # (9, -1)
+ chunk.seek(11, 1) # (20, 1)
+ chunk.seek(-15, 1) # (5, -5)
+ chunk.seek(15, 1) # (20, 5)
+ chunk.seek(-20, 1) # (0, -10)
+ chunk.seek(-1000, 1) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def _assert_out_of_bound_end_seek(self, chunk, expected):
+ # clear amounts_seen
+ self.amounts_seen = []
+ self.assertEqual(self.amounts_seen, [])
+
+ # (position, change)
+ chunk.seek(10, 2) # (20, 10)
+ chunk.seek(-5, 2) # (5, -5)
+ chunk.seek(10, 2) # (20, 5)
+ chunk.seek(-1, 2) # (9, -1)
+ chunk.seek(10, 2) # (20, 1)
+ chunk.seek(1, 2) # (11, 0)
+ chunk.seek(10, 2) # (20, 0)
+ chunk.seek(-1, 2) # (9, -1)
+ chunk.seek(10, 2) # (20, 1)
+ chunk.seek(-5, 2) # (5, -5)
+ chunk.seek(10, 2) # (20, 5)
+ chunk.seek(-10, 2) # (0, -10)
+ chunk.seek(-1000, 2) # (0, 0)
+
+ self.assertEqual(self.amounts_seen, expected)
+
+ def test_close_callbacks(self):
+ with open(self.filename) as f:
+ chunk = ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ close_callbacks=[self.close_callback],
+ )
+ chunk.close()
+ self.assertEqual(self.num_close_callback_calls, 1)
+
+ def test_close_callbacks_when_not_enabled(self):
+ with open(self.filename) as f:
+ chunk = ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ enable_callbacks=False,
+ close_callbacks=[self.close_callback],
+ )
+ chunk.close()
+ self.assertEqual(self.num_close_callback_calls, 0)
+
+ def test_close_callbacks_when_context_handler_is_used(self):
+ with open(self.filename) as f:
+ with ReadFileChunk(
+ f,
+ chunk_size=1,
+ full_file_size=3,
+ close_callbacks=[self.close_callback],
+ ) as chunk:
+ chunk.read(1)
+ self.assertEqual(self.num_close_callback_calls, 1)
+
+ def test_signal_transferring(self):
+ chunk = ReadFileChunk.from_filename(
+ self.filename,
+ start_byte=0,
+ chunk_size=3,
+ callbacks=[self.callback],
+ )
+ chunk.signal_not_transferring()
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [])
+ chunk.signal_transferring()
+ chunk.read(1)
+ self.assertEqual(self.amounts_seen, [1])
+
+ def test_signal_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock()
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ chunk.signal_transferring()
+ self.assertTrue(underlying_stream.signal_transferring.called)
+
+ def test_no_call_signal_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock(io.RawIOBase)
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ try:
+ chunk.signal_transferring()
+ except AttributeError:
+ self.fail(
+ 'The stream should not have tried to call signal_transferring '
+ 'to the underlying stream.'
+ )
+
+ def test_signal_not_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock()
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ chunk.signal_not_transferring()
+ self.assertTrue(underlying_stream.signal_not_transferring.called)
+
+ def test_no_call_signal_not_transferring_to_underlying_fileobj(self):
+ underlying_stream = mock.Mock(io.RawIOBase)
+ underlying_stream.tell.return_value = 0
+ chunk = ReadFileChunk(underlying_stream, 3, 3)
+ try:
+ chunk.signal_not_transferring()
+ except AttributeError:
+ self.fail(
+ 'The stream should not have tried to call '
+ 'signal_not_transferring to the underlying stream.'
+ )
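+ # (Note: mock.Mock(io.RawIOBase) is spec'd to the io.RawIOBase
+ # interface, which defines neither signal_transferring nor
+ # signal_not_transferring, so an unconditional forward of either call
+ # would surface as the AttributeError these two tests guard against.)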
+
+
+class TestStreamReaderProgress(BaseUtilsTest):
+ def test_proxies_to_wrapped_stream(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(original_stream)
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+
+ def test_callback_invoked(self):
+ original_stream = StringIO('foobarbaz')
+ wrapped = StreamReaderProgress(
+ original_stream, [self.callback, self.callback]
+ )
+ self.assertEqual(wrapped.read(), 'foobarbaz')
+ self.assertEqual(self.amounts_seen, [9, 9])
+
+
+class TestTaskSemaphore(unittest.TestCase):
+ def setUp(self):
+ self.semaphore = TaskSemaphore(1)
+
+ def test_should_block_at_max_capacity(self):
+ self.semaphore.acquire('a', blocking=False)
+ with self.assertRaises(NoResourcesAvailable):
+ self.semaphore.acquire('a', blocking=False)
+
+ def test_release_capacity(self):
+ acquire_token = self.semaphore.acquire('a', blocking=False)
+ self.semaphore.release('a', acquire_token)
+ try:
+ self.semaphore.acquire('a', blocking=False)
+ except NoResourcesAvailable:
+ self.fail(
+ 'The release of the semaphore should have allowed for '
+ 'the second acquire to not be blocked'
+ )
+
+
+class TestSlidingWindowSemaphore(unittest.TestCase):
+ # These tests use blocking=False so that tests will fail
+ # instead of hanging the test runner in the case of
+ # incorrect behavior.
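+ # Rough model of the behavior exercised below (a sketch, not the
+ # implementation): acquire() hands out monotonically increasing
+ # sequence numbers, and capacity is only reclaimed when the lowest
+ # outstanding sequence number is released, at which point the window
+ # slides forward over any later numbers that were already released.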
+ def test_acquire_release_basic_case(self):
+ sem = SlidingWindowSemaphore(1)
+ # Count is 1
+
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 0)
+ sem.release('a', 0)
+ # Count now back to 1.
+
+ def test_can_acquire_release_multiple_times(self):
+ sem = SlidingWindowSemaphore(1)
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 0)
+ sem.release('a', num)
+
+ num = sem.acquire('a', blocking=False)
+ self.assertEqual(num, 1)
+ sem.release('a', num)
+
+ def test_can_acquire_a_range(self):
+ sem = SlidingWindowSemaphore(3)
+ self.assertEqual(sem.acquire('a', blocking=False), 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 1)
+ self.assertEqual(sem.acquire('a', blocking=False), 2)
+ sem.release('a', 0)
+ sem.release('a', 1)
+ sem.release('a', 2)
+ # Now that everything has been released we should be able to
+ # acquire the next three sequence numbers.
+ self.assertEqual(sem.acquire('a', blocking=False), 3)
+ self.assertEqual(sem.acquire('a', blocking=False), 4)
+ self.assertEqual(sem.acquire('a', blocking=False), 5)
+ self.assertEqual(sem.current_count(), 0)
+
+ def test_counter_release_only_on_min_element(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # The count only increases when we free the min
+ # element. This means that if we're failing to
+ # acquire now:
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+
+ # Then freeing a non-min element:
+ sem.release('a', 1)
+
+ # doesn't change anything. We still fail to acquire.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+ self.assertEqual(sem.current_count(), 0)
+
+ def test_raises_error_when_count_is_zero(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # Count is now 0 so trying to acquire should fail.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+
+ def test_release_counters_can_increment_counter_repeatedly(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ # These two releases don't increment the counter
+ # because we're waiting on 0.
+ sem.release('a', 1)
+ sem.release('a', 2)
+ self.assertEqual(sem.current_count(), 0)
+ # But as soon as we release 0, we free up 0, 1, and 2.
+ sem.release('a', 0)
+ self.assertEqual(sem.current_count(), 3)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ def test_error_to_release_unknown_tag(self):
+ sem = SlidingWindowSemaphore(3)
+ with self.assertRaises(ValueError):
+ sem.release('a', 0)
+
+ def test_can_track_multiple_tags(self):
+ sem = SlidingWindowSemaphore(3)
+ self.assertEqual(sem.acquire('a', blocking=False), 0)
+ self.assertEqual(sem.acquire('b', blocking=False), 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 1)
+
+ # We're at our max of 3 even though 2 are for A and 1 is for B.
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('a', blocking=False)
+ with self.assertRaises(NoResourcesAvailable):
+ sem.acquire('b', blocking=False)
+
+ def test_can_handle_multiple_tags_released(self):
+ sem = SlidingWindowSemaphore(4)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('b', blocking=False)
+ sem.acquire('b', blocking=False)
+
+ sem.release('b', 1)
+ sem.release('a', 1)
+ self.assertEqual(sem.current_count(), 0)
+
+ sem.release('b', 0)
+ self.assertEqual(sem.acquire('a', blocking=False), 2)
+
+ sem.release('a', 0)
+ self.assertEqual(sem.acquire('b', blocking=False), 2)
+
+ def test_is_error_to_release_unknown_sequence_number(self):
+ sem = SlidingWindowSemaphore(3)
+ sem.acquire('a', blocking=False)
+ with self.assertRaises(ValueError):
+ sem.release('a', 1)
+
+ def test_is_error_to_double_release(self):
+ # This is different than other error tests because
+ # we're verifying we can reset the state after an
+ # acquire/release cycle.
+ sem = SlidingWindowSemaphore(2)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.release('a', 0)
+ sem.release('a', 1)
+ self.assertEqual(sem.current_count(), 2)
+ with self.assertRaises(ValueError):
+ sem.release('a', 0)
+
+ def test_can_check_in_partial_range(self):
+ sem = SlidingWindowSemaphore(4)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ sem.release('a', 1)
+ sem.release('a', 3)
+ sem.release('a', 0)
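+ # Releasing 0 lets the window slide past 0 and the already-released 1,
+ # but it stops at 2, which is still outstanding, so only two slots are
+ # reclaimed even though 3 has also been released.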
+ self.assertEqual(sem.current_count(), 2)
+
+
+class TestThreadingPropertiesForSlidingWindowSemaphore(unittest.TestCase):
+ # These tests focus on multithreaded properties of the range
+ # semaphore. Basic functionality is tested in TestSlidingWindowSemaphore.
+ def setUp(self):
+ self.threads = []
+
+ def tearDown(self):
+ self.join_threads()
+
+ def join_threads(self):
+ for thread in self.threads:
+ thread.join()
+ self.threads = []
+
+ def start_threads(self):
+ for thread in self.threads:
+ thread.start()
+
+ def test_acquire_blocks_until_release_is_called(self):
+ sem = SlidingWindowSemaphore(2)
+ sem.acquire('a', blocking=False)
+ sem.acquire('a', blocking=False)
+
+ def acquire():
+ # This next call to acquire will block.
+ self.assertEqual(sem.acquire('a', blocking=True), 2)
+
+ t = threading.Thread(target=acquire)
+ self.threads.append(t)
+ # Once started, the thread will block on the sem.acquire()
+ # call in the acquire function above.
+ t.start()
+ # This still will keep the thread blocked.
+ sem.release('a', 1)
+ # Releasing the min element will unblock the thread.
+ sem.release('a', 0)
+ t.join()
+ sem.release('a', 2)
+
+ def test_stress_invariants_random_order(self):
+ sem = SlidingWindowSemaphore(100)
+ for _ in range(10):
+ recorded = []
+ for _ in range(100):
+ recorded.append(sem.acquire('a', blocking=False))
+ # Release them in randomized order. Once all 100 have
+ # been released, the full capacity should be available
+ # again.
+ random.shuffle(recorded)
+ for i in recorded:
+ sem.release('a', i)
+
+ # Everything's freed so should be back at count == 100
+ self.assertEqual(sem.current_count(), 100)
+
+ def test_blocking_stress(self):
+ sem = SlidingWindowSemaphore(5)
+ num_threads = 10
+ num_iterations = 50
+
+ def acquire():
+ for _ in range(num_iterations):
+ num = sem.acquire('a', blocking=True)
+ time.sleep(0.001)
+ sem.release('a', num)
+
+ for i in range(num_threads):
+ t = threading.Thread(target=acquire)
+ self.threads.append(t)
+ self.start_threads()
+ self.join_threads()
+ # Should have all the available resources freed.
+ self.assertEqual(sem.current_count(), 5)
+ # Should have acquired num_threads * num_iterations
+ self.assertEqual(
+ sem.acquire('a', blocking=False), num_threads * num_iterations
+ )
+
+
+class TestAdjustChunksize(unittest.TestCase):
+ def setUp(self):
+ self.adjuster = ChunksizeAdjuster()
+
+ def test_valid_chunksize(self):
+ chunksize = 7 * (1024 ** 2)
+ file_size = 8 * (1024 ** 2)
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, chunksize)
+
+ def test_chunksize_below_minimum(self):
+ chunksize = MIN_UPLOAD_CHUNKSIZE - 1
+ file_size = 3 * MIN_UPLOAD_CHUNKSIZE
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE)
+
+ def test_chunksize_above_maximum(self):
+ chunksize = MAX_SINGLE_UPLOAD_SIZE + 1
+ file_size = MAX_SINGLE_UPLOAD_SIZE * 2
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE)
+
+ def test_chunksize_too_small(self):
+ chunksize = 7 * (1024 ** 2)
+ file_size = 5 * (1024 ** 4)
+ # If we try to upload a 5TB file, we'll need to use 896MB part
+ # sizes.
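+ # A sketch of the arithmetic (assuming the adjuster keeps doubling
+ # the chunksize until the part count fits under MAX_PARTS):
+ #   7 MiB -> 14 -> 28 -> 56 -> 112 -> 224 -> 448 -> 896 MiB
+ #   5 TiB / 448 MiB ~= 11,704 parts (> the 10,000 part limit)
+ #   5 TiB / 896 MiB ~= 5,852 parts (fits)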
+ new_size = self.adjuster.adjust_chunksize(chunksize, file_size)
+ self.assertEqual(new_size, 896 * (1024 ** 2))
+ num_parts = file_size / new_size
+ self.assertLessEqual(num_parts, MAX_PARTS)
+
+ def test_unknown_file_size_with_valid_chunksize(self):
+ chunksize = 7 * (1024 ** 2)
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, chunksize)
+
+ def test_unknown_file_size_below_minimum(self):
+ chunksize = MIN_UPLOAD_CHUNKSIZE - 1
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, MIN_UPLOAD_CHUNKSIZE)
+
+ def test_unknown_file_size_above_maximum(self):
+ chunksize = MAX_SINGLE_UPLOAD_SIZE + 1
+ new_size = self.adjuster.adjust_chunksize(chunksize)
+ self.assertEqual(new_size, MAX_SINGLE_UPLOAD_SIZE)
diff --git a/contrib/python/s3transfer/py3/tests/ya.make b/contrib/python/s3transfer/py3/tests/ya.make
index 2fdeaf8ca2..fdbf22b0c5 100644
--- a/contrib/python/s3transfer/py3/tests/ya.make
+++ b/contrib/python/s3transfer/py3/tests/ya.make
@@ -1,44 +1,44 @@
-PY3TEST()
-
-OWNER(g:python-contrib)
-
-SIZE(MEDIUM)
-
-FORK_TESTS()
-
-PEERDIR(
- contrib/python/mock
- contrib/python/s3transfer
-)
-
-TEST_SRCS(
- functional/__init__.py
- functional/test_copy.py
- functional/test_crt.py
- functional/test_delete.py
- functional/test_download.py
- functional/test_manager.py
- functional/test_processpool.py
- functional/test_upload.py
- functional/test_utils.py
- __init__.py
- unit/__init__.py
- unit/test_bandwidth.py
- unit/test_compat.py
- unit/test_copies.py
- unit/test_crt.py
- unit/test_delete.py
- unit/test_download.py
- unit/test_futures.py
- unit/test_manager.py
- unit/test_processpool.py
- unit/test_s3transfer.py
- unit/test_subscribers.py
- unit/test_tasks.py
- unit/test_upload.py
- unit/test_utils.py
-)
-
-NO_LINT()
-
-END()
+PY3TEST()
+
+OWNER(g:python-contrib)
+
+SIZE(MEDIUM)
+
+FORK_TESTS()
+
+PEERDIR(
+ contrib/python/mock
+ contrib/python/s3transfer
+)
+
+TEST_SRCS(
+ functional/__init__.py
+ functional/test_copy.py
+ functional/test_crt.py
+ functional/test_delete.py
+ functional/test_download.py
+ functional/test_manager.py
+ functional/test_processpool.py
+ functional/test_upload.py
+ functional/test_utils.py
+ __init__.py
+ unit/__init__.py
+ unit/test_bandwidth.py
+ unit/test_compat.py
+ unit/test_copies.py
+ unit/test_crt.py
+ unit/test_delete.py
+ unit/test_download.py
+ unit/test_futures.py
+ unit/test_manager.py
+ unit/test_processpool.py
+ unit/test_s3transfer.py
+ unit/test_subscribers.py
+ unit/test_tasks.py
+ unit/test_upload.py
+ unit/test_utils.py
+)
+
+NO_LINT()
+
+END()
diff --git a/contrib/python/s3transfer/py3/ya.make b/contrib/python/s3transfer/py3/ya.make
index 73a52aafa9..964a630639 100644
--- a/contrib/python/s3transfer/py3/ya.make
+++ b/contrib/python/s3transfer/py3/ya.make
@@ -1,51 +1,51 @@
-# Generated by devtools/yamaker (pypi).
-
-PY3_LIBRARY()
+# Generated by devtools/yamaker (pypi).
-OWNER(gebetix g:python-contrib)
+PY3_LIBRARY()
-VERSION(0.5.1)
+OWNER(gebetix g:python-contrib)
-LICENSE(Apache-2.0)
+VERSION(0.5.1)
+
+LICENSE(Apache-2.0)
PEERDIR(
contrib/python/botocore
)
-NO_LINT()
-
-NO_CHECK_IMPORTS(
- s3transfer.crt
-)
-
+NO_LINT()
+
+NO_CHECK_IMPORTS(
+ s3transfer.crt
+)
+
PY_SRCS(
TOP_LEVEL
s3transfer/__init__.py
s3transfer/bandwidth.py
s3transfer/compat.py
- s3transfer/constants.py
- s3transfer/copies.py
- s3transfer/crt.py
+ s3transfer/constants.py
+ s3transfer/copies.py
+ s3transfer/crt.py
s3transfer/delete.py
s3transfer/download.py
s3transfer/exceptions.py
s3transfer/futures.py
s3transfer/manager.py
- s3transfer/processpool.py
+ s3transfer/processpool.py
s3transfer/subscribers.py
s3transfer/tasks.py
s3transfer/upload.py
s3transfer/utils.py
)
-RESOURCE_FILES(
- PREFIX contrib/python/s3transfer/py3/
- .dist-info/METADATA
- .dist-info/top_level.txt
-)
-
+RESOURCE_FILES(
+ PREFIX contrib/python/s3transfer/py3/
+ .dist-info/METADATA
+ .dist-info/top_level.txt
+)
+
END()
-
-RECURSE_FOR_TESTS(
- tests
-)
+
+RECURSE_FOR_TESTS(
+ tests
+)