diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source')
94 files changed, 22595 insertions, 0 deletions
diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonSerializableWebServiceRequest.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonSerializableWebServiceRequest.cpp new file mode 100644 index 0000000000..0c401b01b2 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonSerializableWebServiceRequest.cpp @@ -0,0 +1,24 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/AmazonSerializableWebServiceRequest.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> + +using namespace Aws; + +std::shared_ptr<Aws::IOStream> AmazonSerializableWebServiceRequest::GetBody() const +{ + Aws::String&& payload = SerializePayload(); + std::shared_ptr<Aws::IOStream> payloadBody; + + if (!payload.empty()) + { + payloadBody = Aws::MakeShared<Aws::StringStream>("AmazonSerializableWebServiceRequest"); + *payloadBody << payload; + } + + return payloadBody; +} + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonStreamingWebServiceRequest.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonStreamingWebServiceRequest.cpp new file mode 100644 index 0000000000..92e61c7ad4 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonStreamingWebServiceRequest.cpp @@ -0,0 +1,12 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/AmazonStreamingWebServiceRequest.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> + +using namespace Aws; + +AmazonStreamingWebServiceRequest::~AmazonStreamingWebServiceRequest() {} + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonWebServiceRequest.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonWebServiceRequest.cpp new file mode 100644 index 0000000000..a6b0406683 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/AmazonWebServiceRequest.cpp @@ -0,0 +1,20 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/AmazonWebServiceRequest.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> + +using namespace Aws; + +AmazonWebServiceRequest::AmazonWebServiceRequest() : + m_responseStreamFactory(Aws::Utils::Stream::DefaultResponseStreamFactoryMethod), + m_onDataReceived(nullptr), + m_onDataSent(nullptr), + m_continueRequest(nullptr), + m_onRequestSigned(nullptr), + m_requestRetryHandler(nullptr) +{ +} + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Aws.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Aws.cpp new file mode 100644 index 0000000000..1eaa477fca --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Aws.cpp @@ -0,0 +1,134 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/Version.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/Aws.h> +#include <aws/core/client/CoreErrors.h> +#include <aws/core/utils/logging/AWSLogging.h> +#include <aws/core/utils/logging/DefaultLogSystem.h> +#include <aws/core/Globals.h> +#include <aws/core/external/cjson/cJSON.h> +#include <aws/core/monitoring/MonitoringManager.h> +#include <aws/core/net/Net.h> +#include <aws/core/config/AWSProfileConfigLoader.h> +#include <aws/core/internal/AWSHttpResourceClient.h> + +namespace Aws +{ + static const char* ALLOCATION_TAG = "Aws_Init_Cleanup"; + + void InitAPI(const SDKOptions &options) + { +#ifdef USE_AWS_MEMORY_MANAGEMENT + if(options.memoryManagementOptions.memoryManager) + { + Aws::Utils::Memory::InitializeAWSMemorySystem(*options.memoryManagementOptions.memoryManager); + } +#endif // USE_AWS_MEMORY_MANAGEMENT + Aws::Client::CoreErrorsMapper::InitCoreErrorsMapper(); + if(options.loggingOptions.logLevel != Aws::Utils::Logging::LogLevel::Off) + { + if(options.loggingOptions.logger_create_fn) + { + Aws::Utils::Logging::InitializeAWSLogging(options.loggingOptions.logger_create_fn()); + } + else + { + Aws::Utils::Logging::InitializeAWSLogging( + Aws::MakeShared<Aws::Utils::Logging::DefaultLogSystem>(ALLOCATION_TAG, options.loggingOptions.logLevel, options.loggingOptions.defaultLogPrefix)); + } + // For users to better debugging in case multiple versions of SDK installed + AWS_LOGSTREAM_INFO(ALLOCATION_TAG, "Initiate AWS SDK for C++ with Version:" << Aws::String(Aws::Version::GetVersionString())); + } + + Aws::Config::InitConfigAndCredentialsCacheManager(); + + if (options.cryptoOptions.aes_CBCFactory_create_fn) + { + Aws::Utils::Crypto::SetAES_CBCFactory(options.cryptoOptions.aes_CBCFactory_create_fn()); + } + + if(options.cryptoOptions.aes_CTRFactory_create_fn) + { + Aws::Utils::Crypto::SetAES_CTRFactory(options.cryptoOptions.aes_CTRFactory_create_fn()); + } + + if(options.cryptoOptions.aes_GCMFactory_create_fn) + { + Aws::Utils::Crypto::SetAES_GCMFactory(options.cryptoOptions.aes_GCMFactory_create_fn()); + } + + if(options.cryptoOptions.md5Factory_create_fn) + { + Aws::Utils::Crypto::SetMD5Factory(options.cryptoOptions.md5Factory_create_fn()); + } + + if(options.cryptoOptions.sha256Factory_create_fn) + { + Aws::Utils::Crypto::SetSha256Factory(options.cryptoOptions.sha256Factory_create_fn()); + } + + if(options.cryptoOptions.sha256HMACFactory_create_fn) + { + Aws::Utils::Crypto::SetSha256HMACFactory(options.cryptoOptions.sha256HMACFactory_create_fn()); + } + + if (options.cryptoOptions.aes_KeyWrapFactory_create_fn) + { + Aws::Utils::Crypto::SetAES_KeyWrapFactory(options.cryptoOptions.aes_KeyWrapFactory_create_fn()); + } + + if(options.cryptoOptions.secureRandomFactory_create_fn) + { + Aws::Utils::Crypto::SetSecureRandomFactory(options.cryptoOptions.secureRandomFactory_create_fn()); + } + + Aws::Utils::Crypto::SetInitCleanupOpenSSLFlag(options.cryptoOptions.initAndCleanupOpenSSL); + Aws::Utils::Crypto::InitCrypto(); + + if(options.httpOptions.httpClientFactory_create_fn) + { + Aws::Http::SetHttpClientFactory(options.httpOptions.httpClientFactory_create_fn()); + } + + Aws::Http::SetInitCleanupCurlFlag(options.httpOptions.initAndCleanupCurl); + Aws::Http::SetInstallSigPipeHandlerFlag(options.httpOptions.installSigPipeHandler); + Aws::Http::InitHttp(); + Aws::InitializeEnumOverflowContainer(); + cJSON_Hooks hooks; + hooks.malloc_fn = [](size_t sz) { return Aws::Malloc("cJSON_Tag", sz); }; + hooks.free_fn = Aws::Free; + cJSON_InitHooks(&hooks); + Aws::Net::InitNetwork(); + Aws::Internal::InitEC2MetadataClient(); + Aws::Monitoring::InitMonitoring(options.monitoringOptions.customizedMonitoringFactory_create_fn); + } + + void ShutdownAPI(const SDKOptions& options) + { + Aws::Monitoring::CleanupMonitoring(); + Aws::Internal::CleanupEC2MetadataClient(); + Aws::Net::CleanupNetwork(); + Aws::CleanupEnumOverflowContainer(); + Aws::Http::CleanupHttp(); + Aws::Utils::Crypto::CleanupCrypto(); + + Aws::Config::CleanupConfigAndCredentialsCacheManager(); + + if(options.loggingOptions.logLevel != Aws::Utils::Logging::LogLevel::Off) + { + Aws::Utils::Logging::ShutdownAWSLogging(); + } + + Aws::Client::CoreErrorsMapper::CleanupCoreErrorsMapper(); + +#ifdef USE_AWS_MEMORY_MANAGEMENT + if(options.memoryManagementOptions.memoryManager) + { + Aws::Utils::Memory::ShutdownAWSMemorySystem(); + } +#endif // USE_AWS_MEMORY_MANAGEMENT + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Globals.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Globals.cpp new file mode 100644 index 0000000000..55f2ee9220 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Globals.cpp @@ -0,0 +1,28 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/Globals.h> +#include <aws/core/utils/EnumParseOverflowContainer.h> +#include <aws/core/utils/memory/AWSMemory.h> + +namespace Aws +{ + static const char TAG[] = "GlobalEnumOverflowContainer"; + static Utils::EnumParseOverflowContainer* g_enumOverflow; + + Utils::EnumParseOverflowContainer* GetEnumOverflowContainer() + { + return g_enumOverflow; + } + + void InitializeEnumOverflowContainer() + { + g_enumOverflow = Aws::New<Aws::Utils::EnumParseOverflowContainer>(TAG); + } + + void CleanupEnumOverflowContainer() + { + Aws::Delete(g_enumOverflow); + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Region.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Region.cpp new file mode 100644 index 0000000000..4b18bf2a2a --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Region.cpp @@ -0,0 +1,36 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/memory/stl/AWSString.h> +#include <aws/core/Region.h> +namespace Aws +{ + namespace Region + { + Aws::String ComputeSignerRegion(const Aws::String& region) + { + if (region == Aws::Region::AWS_GLOBAL) + { + return Aws::Region::US_EAST_1; + } + else if (region == "s3-external-1") + { + return Aws::Region::US_EAST_1; + } + else if (region.size() >= 5 && region.compare(0, 5, "fips-") == 0) + { + return region.substr(5); + } + else if (region.size() >= 5 && region.compare(region.size() - 5, 5, "-fips") == 0) + { + return region.substr(0, region.size() - 5); + } + else + { + return region; + } + } + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Version.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Version.cpp new file mode 100644 index 0000000000..35291906b7 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/Version.cpp @@ -0,0 +1,53 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/Version.h> +#include <aws/core/VersionConfig.h> + +namespace Aws +{ +namespace Version +{ + const char* GetVersionString() + { + return AWS_SDK_VERSION_STRING; + } + + unsigned GetVersionMajor() + { + return AWS_SDK_VERSION_MAJOR; + } + + unsigned GetVersionMinor() + { + return AWS_SDK_VERSION_MINOR; + } + + unsigned GetVersionPatch() + { + return AWS_SDK_VERSION_PATCH; + } + + + const char* GetCompilerVersionString() + { +#define xstr(s) str(s) +#define str(s) #s +#if defined(_MSC_VER) + return "MSVC/" xstr(_MSC_VER); +#elif defined(__clang__) + return "Clang/" xstr(__clang_major__) "." xstr(__clang_minor__) "." xstr(__clang_patchlevel__); +#elif defined(__GNUC__) + return "GCC/" xstr(__GNUC__) "." xstr(__GNUC_MINOR__) "." xstr(__GNUC_PATCHLEVEL__); +#else + return "UnknownCompiler"; +#endif +#undef str +#undef xstr + } +} //namespace Version +} //namespace Aws + + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSAuthSigner.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSAuthSigner.cpp new file mode 100644 index 0000000000..de4826fa5b --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSAuthSigner.cpp @@ -0,0 +1,806 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/auth/AWSAuthSigner.h> + +#include <aws/core/auth/AWSCredentialsProvider.h> +#include <aws/core/client/ClientConfiguration.h> +#include <aws/core/http/HttpRequest.h> +#include <aws/core/http/HttpResponse.h> +#include <aws/core/utils/DateTime.h> +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/utils/Outcome.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/memory/AWSMemory.h> +#include <aws/core/utils/crypto/Sha256.h> +#include <aws/core/utils/crypto/Sha256HMAC.h> +#include <aws/core/utils/stream/PreallocatedStreamBuf.h> +#include <aws/core/utils/event/EventMessage.h> +#include <aws/core/utils/event/EventHeader.h> + +#include <cstdio> +#include <iomanip> +#include <math.h> +#include <cstring> + +using namespace Aws; +using namespace Aws::Client; +using namespace Aws::Auth; +using namespace Aws::Http; +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; + +static const char* EQ = "="; +static const char* AWS_HMAC_SHA256 = "AWS4-HMAC-SHA256"; +static const char* EVENT_STREAM_CONTENT_SHA256 = "STREAMING-AWS4-HMAC-SHA256-EVENTS"; +static const char* EVENT_STREAM_PAYLOAD = "AWS4-HMAC-SHA256-PAYLOAD"; +static const char* AWS4_REQUEST = "aws4_request"; +static const char* SIGNED_HEADERS = "SignedHeaders"; +static const char* CREDENTIAL = "Credential"; +static const char* NEWLINE = "\n"; +static const char* X_AMZ_SIGNED_HEADERS = "X-Amz-SignedHeaders"; +static const char* X_AMZ_ALGORITHM = "X-Amz-Algorithm"; +static const char* X_AMZ_CREDENTIAL = "X-Amz-Credential"; +static const char* UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD"; +static const char* X_AMZ_SIGNATURE = "X-Amz-Signature"; +static const char* X_AMZN_TRACE_ID = "x-amzn-trace-id"; +static const char* X_AMZ_CONTENT_SHA256 = "x-amz-content-sha256"; +static const char* USER_AGENT = "user-agent"; +static const char* SIGNING_KEY = "AWS4"; +static const char* SIMPLE_DATE_FORMAT_STR = "%Y%m%d"; +static const char* EMPTY_STRING_SHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + +static const char v4LogTag[] = "AWSAuthV4Signer"; +static const char v4StreamingLogTag[] = "AWSAuthEventStreamV4Signer"; + +namespace Aws +{ + namespace Auth + { + const char SIGNATURE[] = "Signature"; + const char SIGV4_SIGNER[] = "SignatureV4"; + const char EVENTSTREAM_SIGV4_SIGNER[] = "EventStreamSignatureV4"; + const char EVENTSTREAM_SIGNATURE_HEADER[] = ":chunk-signature"; + const char EVENTSTREAM_DATE_HEADER[] = ":date"; + const char NULL_SIGNER[] = "NullSigner"; + } +} + +static Aws::String CanonicalizeRequestSigningString(HttpRequest& request, bool urlEscapePath) +{ + request.CanonicalizeRequest(); + Aws::StringStream signingStringStream; + signingStringStream << HttpMethodMapper::GetNameForHttpMethod(request.GetMethod()); + + URI uriCpy = request.GetUri(); + // Many AWS services do not decode the URL before calculating SignatureV4 on their end. + // This results in the signature getting calculated with a double encoded URL. + // That means we have to double encode it here for the signature to match on the service side. + if(urlEscapePath) + { + // RFC3986 is how we encode the URL before sending it on the wire. + auto rfc3986EncodedPath = URI::URLEncodePathRFC3986(uriCpy.GetPath()); + uriCpy.SetPath(rfc3986EncodedPath); + // However, SignatureV4 uses this URL encoding scheme + signingStringStream << NEWLINE << uriCpy.GetURLEncodedPath() << NEWLINE; + } + else + { + // For the services that DO decode the URL first; we don't need to double encode it. + uriCpy.SetPath(uriCpy.GetURLEncodedPath()); + signingStringStream << NEWLINE << uriCpy.GetPath() << NEWLINE; + } + + if (request.GetQueryString().find('=') != std::string::npos) + { + signingStringStream << request.GetQueryString().substr(1) << NEWLINE; + } + else if (request.GetQueryString().size() > 1) + { + signingStringStream << request.GetQueryString().substr(1) << "=" << NEWLINE; + } + else + { + signingStringStream << NEWLINE; + } + + return signingStringStream.str(); +} + +static Http::HeaderValueCollection CanonicalizeHeaders(Http::HeaderValueCollection&& headers) +{ + Http::HeaderValueCollection canonicalHeaders; + for (const auto& header : headers) + { + auto trimmedHeaderName = StringUtils::Trim(header.first.c_str()); + auto trimmedHeaderValue = StringUtils::Trim(header.second.c_str()); + + //multiline gets converted to line1,line2,etc... + auto headerMultiLine = StringUtils::SplitOnLine(trimmedHeaderValue); + Aws::String headerValue = headerMultiLine.size() == 0 ? "" : headerMultiLine[0]; + + if (headerMultiLine.size() > 1) + { + for(size_t i = 1; i < headerMultiLine.size(); ++i) + { + headerValue += ","; + headerValue += StringUtils::Trim(headerMultiLine[i].c_str()); + } + } + + //duplicate spaces need to be converted to one. + Aws::String::iterator new_end = + std::unique(headerValue.begin(), headerValue.end(), + [=](char lhs, char rhs) { return (lhs == rhs) && (lhs == ' '); } + ); + headerValue.erase(new_end, headerValue.end()); + + canonicalHeaders[trimmedHeaderName] = headerValue; + } + + return canonicalHeaders; +} + +AWSAuthV4Signer::AWSAuthV4Signer(const std::shared_ptr<Auth::AWSCredentialsProvider>& credentialsProvider, + const char* serviceName, const Aws::String& region, PayloadSigningPolicy signingPolicy, bool urlEscapePath) : + m_includeSha256HashHeader(true), + m_credentialsProvider(credentialsProvider), + m_serviceName(serviceName), + m_region(region), + m_hash(Aws::MakeUnique<Aws::Utils::Crypto::Sha256>(v4LogTag)), + m_HMAC(Aws::MakeUnique<Aws::Utils::Crypto::Sha256HMAC>(v4LogTag)), + m_unsignedHeaders({USER_AGENT, X_AMZN_TRACE_ID}), + m_payloadSigningPolicy(signingPolicy), + m_urlEscapePath(urlEscapePath) +{ + //go ahead and warm up the signing cache. + ComputeHash(credentialsProvider->GetAWSCredentials().GetAWSSecretKey(), DateTime::CalculateGmtTimestampAsString(SIMPLE_DATE_FORMAT_STR), region, m_serviceName); +} + +AWSAuthV4Signer::~AWSAuthV4Signer() +{ + // empty destructor in .cpp file to keep from needing the implementation of (AWSCredentialsProvider, Sha256, Sha256HMAC) in the header file +} + + +bool AWSAuthV4Signer::ShouldSignHeader(const Aws::String& header) const +{ + return m_unsignedHeaders.find(Aws::Utils::StringUtils::ToLower(header.c_str())) == m_unsignedHeaders.cend(); +} + +bool AWSAuthV4Signer::SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, bool signBody) const +{ + AWSCredentials credentials = m_credentialsProvider->GetAWSCredentials(); + + //don't sign anonymous requests + if (credentials.GetAWSAccessKeyId().empty() || credentials.GetAWSSecretKey().empty()) + { + return true; + } + + if (!credentials.GetSessionToken().empty()) + { + request.SetAwsSessionToken(credentials.GetSessionToken()); + } + + Aws::String payloadHash(UNSIGNED_PAYLOAD); + switch(m_payloadSigningPolicy) + { + case PayloadSigningPolicy::Always: + signBody = true; + break; + case PayloadSigningPolicy::Never: + signBody = false; + break; + case PayloadSigningPolicy::RequestDependent: + // respect the request setting + default: + break; + } + + if(signBody || request.GetUri().GetScheme() != Http::Scheme::HTTPS) + { + payloadHash = ComputePayloadHash(request); + if (payloadHash.empty()) + { + return false; + } + } + else + { + AWS_LOGSTREAM_DEBUG(v4LogTag, "Note: Http payloads are not being signed. signPayloads=" << signBody + << " http scheme=" << Http::SchemeMapper::ToString(request.GetUri().GetScheme())); + } + + if(m_includeSha256HashHeader) + { + request.SetHeaderValue(X_AMZ_CONTENT_SHA256, payloadHash); + } + + //calculate date header to use in internal signature (this also goes into date header). + DateTime now = GetSigningTimestamp(); + Aws::String dateHeaderValue = now.ToGmtString(DateFormat::ISO_8601_BASIC); + request.SetHeaderValue(AWS_DATE_HEADER, dateHeaderValue); + + Aws::StringStream headersStream; + Aws::StringStream signedHeadersStream; + + for (const auto& header : CanonicalizeHeaders(request.GetHeaders())) + { + if(ShouldSignHeader(header.first)) + { + headersStream << header.first.c_str() << ":" << header.second.c_str() << NEWLINE; + signedHeadersStream << header.first.c_str() << ";"; + } + } + + Aws::String canonicalHeadersString = headersStream.str(); + AWS_LOGSTREAM_DEBUG(v4LogTag, "Canonical Header String: " << canonicalHeadersString); + + //calculate signed headers parameter + Aws::String signedHeadersValue = signedHeadersStream.str(); + //remove that last semi-colon + if (!signedHeadersValue.empty()) + { + signedHeadersValue.pop_back(); + } + + AWS_LOGSTREAM_DEBUG(v4LogTag, "Signed Headers value:" << signedHeadersValue); + + //generate generalized canonicalized request string. + Aws::String canonicalRequestString = CanonicalizeRequestSigningString(request, m_urlEscapePath); + + //append v4 stuff to the canonical request string. + canonicalRequestString.append(canonicalHeadersString); + canonicalRequestString.append(NEWLINE); + canonicalRequestString.append(signedHeadersValue); + canonicalRequestString.append(NEWLINE); + canonicalRequestString.append(payloadHash); + + AWS_LOGSTREAM_DEBUG(v4LogTag, "Canonical Request String: " << canonicalRequestString); + + //now compute sha256 on that request string + auto hashResult = m_hash->Calculate(canonicalRequestString); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4LogTag, "Failed to hash (sha256) request string"); + AWS_LOGSTREAM_DEBUG(v4LogTag, "The request string is: \"" << canonicalRequestString << "\""); + return false; + } + + auto sha256Digest = hashResult.GetResult(); + Aws::String canonicalRequestHash = HashingUtils::HexEncode(sha256Digest); + Aws::String simpleDate = now.ToGmtString(SIMPLE_DATE_FORMAT_STR); + + Aws::String signingRegion = region ? region : m_region; + Aws::String signingServiceName = serviceName ? serviceName : m_serviceName; + Aws::String stringToSign = GenerateStringToSign(dateHeaderValue, simpleDate, canonicalRequestHash, signingRegion, signingServiceName); + auto finalSignature = GenerateSignature(credentials, stringToSign, simpleDate, signingRegion, signingServiceName); + + Aws::StringStream ss; + ss << AWS_HMAC_SHA256 << " " << CREDENTIAL << EQ << credentials.GetAWSAccessKeyId() << "/" << simpleDate + << "/" << signingRegion << "/" << signingServiceName << "/" << AWS4_REQUEST << ", " << SIGNED_HEADERS << EQ + << signedHeadersValue << ", " << SIGNATURE << EQ << finalSignature; + + auto awsAuthString = ss.str(); + AWS_LOGSTREAM_DEBUG(v4LogTag, "Signing request with: " << awsAuthString); + request.SetAwsAuthorization(awsAuthString); + request.SetSigningAccessKey(credentials.GetAWSAccessKeyId()); + request.SetSigningRegion(signingRegion); + return true; +} + +bool AWSAuthV4Signer::PresignRequest(Aws::Http::HttpRequest& request, long long expirationTimeInSeconds) const +{ + return PresignRequest(request, m_region.c_str(), expirationTimeInSeconds); +} + +bool AWSAuthV4Signer::PresignRequest(Aws::Http::HttpRequest& request, const char* region, long long expirationInSeconds) const +{ + return PresignRequest(request, region, m_serviceName.c_str(), expirationInSeconds); +} + +bool AWSAuthV4Signer::PresignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, long long expirationTimeInSeconds) const +{ + AWSCredentials credentials = m_credentialsProvider->GetAWSCredentials(); + + //don't sign anonymous requests + if (credentials.GetAWSAccessKeyId().empty() || credentials.GetAWSSecretKey().empty()) + { + return true; + } + + Aws::StringStream intConversionStream; + intConversionStream << expirationTimeInSeconds; + request.AddQueryStringParameter(Http::X_AMZ_EXPIRES_HEADER, intConversionStream.str()); + + if (!credentials.GetSessionToken().empty()) + { + request.AddQueryStringParameter(Http::AWS_SECURITY_TOKEN, credentials.GetSessionToken()); + } + + //calculate date header to use in internal signature (this also goes into date header). + DateTime now = GetSigningTimestamp(); + Aws::String dateQueryValue = now.ToGmtString(DateFormat::ISO_8601_BASIC); + request.AddQueryStringParameter(Http::AWS_DATE_HEADER, dateQueryValue); + + Aws::StringStream headersStream; + Aws::StringStream signedHeadersStream; + for (const auto& header : CanonicalizeHeaders(request.GetHeaders())) + { + if(ShouldSignHeader(header.first)) + { + headersStream << header.first.c_str() << ":" << header.second.c_str() << NEWLINE; + signedHeadersStream << header.first.c_str() << ";"; + } + } + + Aws::String canonicalHeadersString = headersStream.str(); + AWS_LOGSTREAM_DEBUG(v4LogTag, "Canonical Header String: " << canonicalHeadersString); + + //calculate signed headers parameter + Aws::String signedHeadersValue(signedHeadersStream.str()); + //remove that last semi-colon + if (!signedHeadersValue.empty()) + { + signedHeadersValue.pop_back(); + } + + request.AddQueryStringParameter(X_AMZ_SIGNED_HEADERS, signedHeadersValue); + AWS_LOGSTREAM_DEBUG(v4LogTag, "Signed Headers value: " << signedHeadersValue); + + Aws::StringStream ss; + Aws::String signingRegion = region ? region : m_region; + Aws::String signingServiceName = serviceName ? serviceName : m_serviceName; + Aws::String simpleDate = now.ToGmtString(SIMPLE_DATE_FORMAT_STR); + ss << credentials.GetAWSAccessKeyId() << "/" << simpleDate + << "/" << signingRegion << "/" << signingServiceName << "/" << AWS4_REQUEST; + + request.AddQueryStringParameter(X_AMZ_ALGORITHM, AWS_HMAC_SHA256); + request.AddQueryStringParameter(X_AMZ_CREDENTIAL, ss.str()); + ss.str(""); + + request.SetSigningAccessKey(credentials.GetAWSAccessKeyId()); + request.SetSigningRegion(signingRegion); + + //generate generalized canonicalized request string. + Aws::String canonicalRequestString = CanonicalizeRequestSigningString(request, m_urlEscapePath); + + //append v4 stuff to the canonical request string. + canonicalRequestString.append(canonicalHeadersString); + canonicalRequestString.append(NEWLINE); + canonicalRequestString.append(signedHeadersValue); + canonicalRequestString.append(NEWLINE); + if (ServiceRequireUnsignedPayload(signingServiceName)) + { + canonicalRequestString.append(UNSIGNED_PAYLOAD); + } + else + { + canonicalRequestString.append(EMPTY_STRING_SHA256); + } + AWS_LOGSTREAM_DEBUG(v4LogTag, "Canonical Request String: " << canonicalRequestString); + + //now compute sha256 on that request string + auto hashResult = m_hash->Calculate(canonicalRequestString); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4LogTag, "Failed to hash (sha256) request string"); + AWS_LOGSTREAM_DEBUG(v4LogTag, "The request string is: \"" << canonicalRequestString << "\""); + return false; + } + + auto sha256Digest = hashResult.GetResult(); + auto canonicalRequestHash = HashingUtils::HexEncode(sha256Digest); + + auto stringToSign = GenerateStringToSign(dateQueryValue, simpleDate, canonicalRequestHash, signingRegion, signingServiceName); + auto finalSigningHash = GenerateSignature(credentials, stringToSign, simpleDate, signingRegion, signingServiceName); + if (finalSigningHash.empty()) + { + return false; + } + + //add that the signature to the query string + request.AddQueryStringParameter(X_AMZ_SIGNATURE, finalSigningHash); + + return true; +} + +bool AWSAuthV4Signer::ServiceRequireUnsignedPayload(const Aws::String& serviceName) const +{ + // S3 uses a magic string (instead of the empty string) for its body hash for presigned URLs as outlined here: + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + // this is true for PUT, POST, GET, DELETE and HEAD operations. + // However, other services (for example RDS) implement the specification as outlined here: + // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + // which states that body-less requests should use the empty-string SHA256 hash. + return "s3" == serviceName; +} + +Aws::String AWSAuthV4Signer::GenerateSignature(const AWSCredentials& credentials, const Aws::String& stringToSign, + const Aws::String& simpleDate, const Aws::String& region, const Aws::String& serviceName) const +{ + auto key = ComputeHash(credentials.GetAWSSecretKey(), simpleDate, region, serviceName); + return GenerateSignature(stringToSign, key); +} + +Aws::String AWSAuthV4Signer::GenerateSignature(const Aws::String& stringToSign, const ByteBuffer& key) const +{ + AWS_LOGSTREAM_DEBUG(v4LogTag, "Final String to sign: " << stringToSign); + + Aws::StringStream ss; + + auto hashResult = m_HMAC->Calculate(ByteBuffer((unsigned char*)stringToSign.c_str(), stringToSign.length()), key); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4LogTag, "Unable to hmac (sha256) final string"); + AWS_LOGSTREAM_DEBUG(v4LogTag, "The final string is: \"" << stringToSign << "\""); + return {}; + } + + //now we finally sign our request string with our hex encoded derived hash. + auto finalSigningDigest = hashResult.GetResult(); + + auto finalSigningHash = HashingUtils::HexEncode(finalSigningDigest); + AWS_LOGSTREAM_DEBUG(v4LogTag, "Final computed signing hash: " << finalSigningHash); + + return finalSigningHash; +} + +Aws::String AWSAuthV4Signer::ComputePayloadHash(Aws::Http::HttpRequest& request) const +{ + if (!request.GetContentBody()) + { + AWS_LOGSTREAM_DEBUG(v4LogTag, "Using cached empty string sha256 " << EMPTY_STRING_SHA256 << " because payload is empty."); + return EMPTY_STRING_SHA256; + } + + //compute hash on payload if it exists. + auto hashResult = m_hash->Calculate(*request.GetContentBody()); + + if(request.GetContentBody()) + { + request.GetContentBody()->clear(); + request.GetContentBody()->seekg(0); + } + + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4LogTag, "Unable to hash (sha256) request body"); + return {}; + } + + auto sha256Digest = hashResult.GetResult(); + + Aws::String payloadHash(HashingUtils::HexEncode(sha256Digest)); + AWS_LOGSTREAM_DEBUG(v4LogTag, "Calculated sha256 " << payloadHash << " for payload."); + return payloadHash; +} + +Aws::String AWSAuthV4Signer::GenerateStringToSign(const Aws::String& dateValue, const Aws::String& simpleDate, + const Aws::String& canonicalRequestHash, const Aws::String& region, const Aws::String& serviceName) const +{ + //generate the actual string we will use in signing the final request. + Aws::StringStream ss; + + ss << AWS_HMAC_SHA256 << NEWLINE << dateValue << NEWLINE << simpleDate << "/" << region << "/" + << serviceName << "/" << AWS4_REQUEST << NEWLINE << canonicalRequestHash; + + return ss.str(); +} + +Aws::Utils::ByteBuffer AWSAuthV4Signer::ComputeHash(const Aws::String& secretKey, + const Aws::String& simpleDate, const Aws::String& region, const Aws::String& serviceName) const +{ + Aws::String signingKey(SIGNING_KEY); + signingKey.append(secretKey); + auto hashResult = m_HMAC->Calculate(ByteBuffer((unsigned char*)simpleDate.c_str(), simpleDate.length()), + ByteBuffer((unsigned char*)signingKey.c_str(), signingKey.length())); + + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4LogTag, "Failed to HMAC (SHA256) date string \"" << simpleDate << "\""); + return {}; + } + + auto kDate = hashResult.GetResult(); + hashResult = m_HMAC->Calculate(ByteBuffer((unsigned char*)region.c_str(), region.length()), kDate); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4LogTag, "Failed to HMAC (SHA256) region string \"" << region << "\""); + return {}; + } + + auto kRegion = hashResult.GetResult(); + hashResult = m_HMAC->Calculate(ByteBuffer((unsigned char*)serviceName.c_str(), serviceName.length()), kRegion); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4LogTag, "Failed to HMAC (SHA256) service string \"" << m_serviceName << "\""); + return {}; + } + + auto kService = hashResult.GetResult(); + hashResult = m_HMAC->Calculate(ByteBuffer((unsigned char*)AWS4_REQUEST, strlen(AWS4_REQUEST)), kService); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4LogTag, "Unable to HMAC (SHA256) request string"); + AWS_LOGSTREAM_DEBUG(v4LogTag, "The request string is: \"" << AWS4_REQUEST << "\""); + return {}; + } + return hashResult.GetResult(); +} + +AWSAuthEventStreamV4Signer::AWSAuthEventStreamV4Signer(const std::shared_ptr<Auth::AWSCredentialsProvider>& + credentialsProvider, const char* serviceName, const Aws::String& region) : + m_serviceName(serviceName), + m_region(region), + m_credentialsProvider(credentialsProvider) +{ + + m_unsignedHeaders.emplace_back(X_AMZN_TRACE_ID); + m_unsignedHeaders.emplace_back(USER_AGENT_HEADER); +} + +bool AWSAuthEventStreamV4Signer::SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, bool /* signBody */) const +{ + AWSCredentials credentials = m_credentialsProvider->GetAWSCredentials(); + + //don't sign anonymous requests + if (credentials.GetAWSAccessKeyId().empty() || credentials.GetAWSSecretKey().empty()) + { + return true; + } + + if (!credentials.GetSessionToken().empty()) + { + request.SetAwsSessionToken(credentials.GetSessionToken()); + } + + request.SetHeaderValue(X_AMZ_CONTENT_SHA256, EVENT_STREAM_CONTENT_SHA256); + + //calculate date header to use in internal signature (this also goes into date header). + DateTime now = GetSigningTimestamp(); + Aws::String dateHeaderValue = now.ToGmtString(DateFormat::ISO_8601_BASIC); + request.SetHeaderValue(AWS_DATE_HEADER, dateHeaderValue); + + Aws::StringStream headersStream; + Aws::StringStream signedHeadersStream; + + for (const auto& header : CanonicalizeHeaders(request.GetHeaders())) + { + if(ShouldSignHeader(header.first)) + { + headersStream << header.first.c_str() << ":" << header.second.c_str() << NEWLINE; + signedHeadersStream << header.first.c_str() << ";"; + } + } + + Aws::String canonicalHeadersString = headersStream.str(); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Canonical Header String: " << canonicalHeadersString); + + //calculate signed headers parameter + Aws::String signedHeadersValue = signedHeadersStream.str(); + //remove that last semi-colon + if (!signedHeadersValue.empty()) + { + signedHeadersValue.pop_back(); + } + + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Signed Headers value:" << signedHeadersValue); + + //generate generalized canonicalized request string. + Aws::String canonicalRequestString = CanonicalizeRequestSigningString(request, true/* m_urlEscapePath */); + + //append v4 stuff to the canonical request string. + canonicalRequestString.append(canonicalHeadersString); + canonicalRequestString.append(NEWLINE); + canonicalRequestString.append(signedHeadersValue); + canonicalRequestString.append(NEWLINE); + canonicalRequestString.append(EVENT_STREAM_CONTENT_SHA256); + + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Canonical Request String: " << canonicalRequestString); + + //now compute sha256 on that request string + auto hashResult = m_hash.Calculate(canonicalRequestString); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) request string"); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "The request string is: \"" << canonicalRequestString << "\""); + return false; + } + + auto sha256Digest = hashResult.GetResult(); + Aws::String canonicalRequestHash = HashingUtils::HexEncode(sha256Digest); + Aws::String simpleDate = now.ToGmtString(SIMPLE_DATE_FORMAT_STR); + + Aws::String signingRegion = region ? region : m_region; + Aws::String signingServiceName = serviceName ? serviceName : m_serviceName; + Aws::String stringToSign = GenerateStringToSign(dateHeaderValue, simpleDate, canonicalRequestHash, signingRegion, signingServiceName); + auto finalSignature = GenerateSignature(credentials, stringToSign, simpleDate, signingRegion, signingServiceName); + + Aws::StringStream ss; + ss << AWS_HMAC_SHA256 << " " << CREDENTIAL << EQ << credentials.GetAWSAccessKeyId() << "/" << simpleDate + << "/" << signingRegion << "/" << signingServiceName << "/" << AWS4_REQUEST << ", " << SIGNED_HEADERS << EQ + << signedHeadersValue << ", " << SIGNATURE << EQ << HashingUtils::HexEncode(finalSignature); + + auto awsAuthString = ss.str(); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Signing request with: " << awsAuthString); + request.SetAwsAuthorization(awsAuthString); + request.SetSigningAccessKey(credentials.GetAWSAccessKeyId()); + request.SetSigningRegion(signingRegion); + return true; +} + +// this works regardless if the current machine is Big/Little Endian +static void WriteBigEndian(Aws::String& str, uint64_t n) +{ + int shift = 56; + while(shift >= 0) + { + str.push_back((n >> shift) & 0xFF); + shift -= 8; + } +} + +bool AWSAuthEventStreamV4Signer::SignEventMessage(Event::Message& message, Aws::String& priorSignature) const +{ + using Event::EventHeaderValue; + + Aws::StringStream stringToSign; + stringToSign << EVENT_STREAM_PAYLOAD << NEWLINE; + const DateTime now = GetSigningTimestamp(); + const auto simpleDate = now.ToGmtString(SIMPLE_DATE_FORMAT_STR); + stringToSign << now.ToGmtString(DateFormat::ISO_8601_BASIC) << NEWLINE + << simpleDate << "/" << m_region << "/" + << m_serviceName << "/aws4_request" << NEWLINE << priorSignature << NEWLINE; + + + Aws::String nonSignatureHeaders; + nonSignatureHeaders.push_back(char(sizeof(EVENTSTREAM_DATE_HEADER) - 1)); // length of the string + nonSignatureHeaders += EVENTSTREAM_DATE_HEADER; + nonSignatureHeaders.push_back(static_cast<char>(EventHeaderValue::EventHeaderType::TIMESTAMP)); // type of the value + WriteBigEndian(nonSignatureHeaders, static_cast<uint64_t>(now.Millis())); // the value of the timestamp in big-endian + + auto hashOutcome = m_hash.Calculate(nonSignatureHeaders); + if (!hashOutcome.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) non-signature headers."); + return false; + } + + const auto nonSignatureHeadersHash = hashOutcome.GetResult(); + stringToSign << HashingUtils::HexEncode(nonSignatureHeadersHash) << NEWLINE; + + if (message.GetEventPayload().empty()) + { + AWS_LOGSTREAM_WARN(v4StreamingLogTag, "Attempting to sign an empty message (no payload and no headers). " + "It is unlikely that this is the intended behavior."); + } + else + { + // use a preallocatedStreamBuf to avoid making a copy. + // The Hashing API requires either Aws::String or IStream as input. + // TODO: the hashing API should be accept 'unsigned char*' as input. + Utils::Stream::PreallocatedStreamBuf streamBuf(message.GetEventPayload().data(), message.GetEventPayload().size()); + Aws::IOStream payload(&streamBuf); + hashOutcome = m_hash.Calculate(payload); + + if (!hashOutcome.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to hash (sha256) non-signature headers."); + return false; + } + const auto payloadHash = hashOutcome.GetResult(); + stringToSign << HashingUtils::HexEncode(payloadHash); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Payload hash - " << HashingUtils::HexEncode(payloadHash)); + } + + Utils::ByteBuffer finalSignatureDigest = GenerateSignature(m_credentialsProvider->GetAWSCredentials(), stringToSign.str(), simpleDate, m_region, m_serviceName); + const auto finalSignature = HashingUtils::HexEncode(finalSignatureDigest); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Final computed signing hash: " << finalSignature); + priorSignature = finalSignature; + + message.InsertEventHeader(EVENTSTREAM_DATE_HEADER, EventHeaderValue(now.Millis(), EventHeaderValue::EventHeaderType::TIMESTAMP)); + message.InsertEventHeader(EVENTSTREAM_SIGNATURE_HEADER, std::move(finalSignatureDigest)); + + AWS_LOGSTREAM_INFO(v4StreamingLogTag, "Event chunk final signature - " << finalSignature); + return true; +} + +bool AWSAuthEventStreamV4Signer::ShouldSignHeader(const Aws::String& header) const +{ + return std::find(m_unsignedHeaders.cbegin(), m_unsignedHeaders.cend(), Aws::Utils::StringUtils::ToLower(header.c_str())) == m_unsignedHeaders.cend(); +} + +Utils::ByteBuffer AWSAuthEventStreamV4Signer::GenerateSignature(const AWSCredentials& credentials, const Aws::String& stringToSign, + const Aws::String& simpleDate, const Aws::String& region, const Aws::String& serviceName) const +{ + Utils::Threading::ReaderLockGuard guard(m_derivedKeyLock); + const auto& secretKey = credentials.GetAWSSecretKey(); + if (secretKey != m_currentSecretKey || simpleDate != m_currentDateStr) + { + guard.UpgradeToWriterLock(); + // double-checked lock to prevent updating twice + if (m_currentDateStr != simpleDate || m_currentSecretKey != secretKey) + { + m_currentSecretKey = secretKey; + m_currentDateStr = simpleDate; + m_derivedKey = ComputeHash(m_currentSecretKey, m_currentDateStr, region, serviceName); + } + + } + return GenerateSignature(stringToSign, m_derivedKey); +} + +Utils::ByteBuffer AWSAuthEventStreamV4Signer::GenerateSignature(const Aws::String& stringToSign, const ByteBuffer& key) const +{ + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "Final String to sign: " << stringToSign); + + Aws::StringStream ss; + + auto hashResult = m_HMAC.Calculate(ByteBuffer((unsigned char*)stringToSign.c_str(), stringToSign.length()), key); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Unable to hmac (sha256) final string"); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "The final string is: \"" << stringToSign << "\""); + return {}; + } + + return hashResult.GetResult(); +} + +Aws::String AWSAuthEventStreamV4Signer::GenerateStringToSign(const Aws::String& dateValue, const Aws::String& simpleDate, + const Aws::String& canonicalRequestHash, const Aws::String& region, const Aws::String& serviceName) const +{ + //generate the actual string we will use in signing the final request. + Aws::StringStream ss; + + ss << AWS_HMAC_SHA256 << NEWLINE << dateValue << NEWLINE << simpleDate << "/" << region << "/" + << serviceName << "/" << AWS4_REQUEST << NEWLINE << canonicalRequestHash; + + return ss.str(); +} + +Aws::Utils::ByteBuffer AWSAuthEventStreamV4Signer::ComputeHash(const Aws::String& secretKey, + const Aws::String& simpleDate, const Aws::String& region, const Aws::String& serviceName) const +{ + Aws::String signingKey(SIGNING_KEY); + signingKey.append(secretKey); + auto hashResult = m_HMAC.Calculate(ByteBuffer((unsigned char*)simpleDate.c_str(), simpleDate.length()), + ByteBuffer((unsigned char*)signingKey.c_str(), signingKey.length())); + + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to HMAC (SHA256) date string \"" << simpleDate << "\""); + return {}; + } + + auto kDate = hashResult.GetResult(); + hashResult = m_HMAC.Calculate(ByteBuffer((unsigned char*)region.c_str(), region.length()), kDate); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to HMAC (SHA256) region string \"" << region << "\""); + return {}; + } + + auto kRegion = hashResult.GetResult(); + hashResult = m_HMAC.Calculate(ByteBuffer((unsigned char*)serviceName.c_str(), serviceName.length()), kRegion); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Failed to HMAC (SHA256) service string \"" << m_serviceName << "\""); + return {}; + } + + auto kService = hashResult.GetResult(); + hashResult = m_HMAC.Calculate(ByteBuffer((unsigned char*)AWS4_REQUEST, strlen(AWS4_REQUEST)), kService); + if (!hashResult.IsSuccess()) + { + AWS_LOGSTREAM_ERROR(v4StreamingLogTag, "Unable to HMAC (SHA256) request string"); + AWS_LOGSTREAM_DEBUG(v4StreamingLogTag, "The request string is: \"" << AWS4_REQUEST << "\""); + return {}; + } + return hashResult.GetResult(); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSAuthSignerProvider.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSAuthSignerProvider.cpp new file mode 100644 index 0000000000..31fd6c006b --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSAuthSignerProvider.cpp @@ -0,0 +1,51 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/auth/AWSAuthSignerProvider.h> +#include <aws/core/auth/AWSAuthSigner.h> +#include <aws/core/auth/AWSCredentialsProvider.h> +#include <aws/core/utils/memory/stl/AWSAllocator.h> + +const char CLASS_TAG[] = "AuthSignerProvider"; + +using namespace Aws::Auth; + +DefaultAuthSignerProvider::DefaultAuthSignerProvider(const std::shared_ptr<AWSCredentialsProvider>& credentialsProvider, + const Aws::String& serviceName, const Aws::String& region) +{ + m_signers.emplace_back(Aws::MakeShared<Aws::Client::AWSAuthV4Signer>(CLASS_TAG, credentialsProvider, serviceName.c_str(), region)); + m_signers.emplace_back(Aws::MakeShared<Aws::Client::AWSAuthEventStreamV4Signer>(CLASS_TAG, credentialsProvider, serviceName.c_str(), region)); + m_signers.emplace_back(Aws::MakeShared<Aws::Client::AWSNullSigner>(CLASS_TAG)); +} + +DefaultAuthSignerProvider::DefaultAuthSignerProvider(const std::shared_ptr<Aws::Client::AWSAuthSigner>& signer) +{ + m_signers.emplace_back(Aws::MakeShared<Aws::Client::AWSNullSigner>(CLASS_TAG)); + if(signer) + { + m_signers.emplace_back(signer); + } +} + +std::shared_ptr<Aws::Client::AWSAuthSigner> DefaultAuthSignerProvider::GetSigner(const Aws::String& signerName) const +{ + for(const auto& signer : m_signers) + { + if(signer->GetName() == signerName) + { + return signer; + } + } + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Request's signer: '" << signerName << "' is not found in the signer's map."); + assert(false); + return nullptr; +} + +void DefaultAuthSignerProvider::AddSigner(std::shared_ptr<Aws::Client::AWSAuthSigner>& signer) +{ + assert(signer); + m_signers.emplace_back(signer); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp new file mode 100644 index 0000000000..31e28b996f --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp @@ -0,0 +1,466 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/auth/AWSCredentialsProvider.h> + +#include <aws/core/config/AWSProfileConfigLoader.h> +#include <aws/core/platform/Environment.h> +#include <aws/core/platform/FileSystem.h> +#include <aws/core/platform/OSVersionInfo.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/json/JsonSerializer.h> +#include <aws/core/utils/FileSystemUtils.h> +#include <aws/core/client/AWSError.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/xml/XmlSerializer.h> +#include <cstdlib> +#include <fstream> +#include <string.h> +#include <climits> + + +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; +using namespace Aws::Auth; +using namespace Aws::Internal; +using namespace Aws::FileSystem; +using namespace Aws::Utils::Xml; +using namespace Aws::Client; +using Aws::Utils::Threading::ReaderLockGuard; +using Aws::Utils::Threading::WriterLockGuard; + +static const char ACCESS_KEY_ENV_VAR[] = "AWS_ACCESS_KEY_ID"; +static const char SECRET_KEY_ENV_VAR[] = "AWS_SECRET_ACCESS_KEY"; +static const char SESSION_TOKEN_ENV_VAR[] = "AWS_SESSION_TOKEN"; +static const char DEFAULT_PROFILE[] = "default"; +static const char AWS_PROFILE_ENV_VAR[] = "AWS_PROFILE"; +static const char AWS_PROFILE_DEFAULT_ENV_VAR[] = "AWS_DEFAULT_PROFILE"; + +static const char AWS_CREDENTIALS_FILE[] = "AWS_SHARED_CREDENTIALS_FILE"; +extern const char AWS_CONFIG_FILE[] = "AWS_CONFIG_FILE"; + +extern const char PROFILE_DIRECTORY[] = ".aws"; +static const char DEFAULT_CREDENTIALS_FILE[] = "credentials"; +extern const char DEFAULT_CONFIG_FILE[] = "config"; + + +static const int EXPIRATION_GRACE_PERIOD = 5 * 1000; + +void AWSCredentialsProvider::Reload() +{ + m_lastLoadedMs = DateTime::Now().Millis(); +} + +bool AWSCredentialsProvider::IsTimeToRefresh(long reloadFrequency) +{ + if (DateTime::Now().Millis() - m_lastLoadedMs > reloadFrequency) + { + return true; + } + return false; +} + + +static const char* ENVIRONMENT_LOG_TAG = "EnvironmentAWSCredentialsProvider"; + + +AWSCredentials EnvironmentAWSCredentialsProvider::GetAWSCredentials() +{ + auto accessKey = Aws::Environment::GetEnv(ACCESS_KEY_ENV_VAR); + AWSCredentials credentials; + + if (!accessKey.empty()) + { + credentials.SetAWSAccessKeyId(accessKey); + + AWS_LOGSTREAM_DEBUG(ENVIRONMENT_LOG_TAG, "Found credential in environment with access key id " << accessKey); + auto secretKey = Aws::Environment::GetEnv(SECRET_KEY_ENV_VAR); + + if (!secretKey.empty()) + { + credentials.SetAWSSecretKey(secretKey); + AWS_LOGSTREAM_INFO(ENVIRONMENT_LOG_TAG, "Found secret key"); + } + + auto sessionToken = Aws::Environment::GetEnv(SESSION_TOKEN_ENV_VAR); + + if(!sessionToken.empty()) + { + credentials.SetSessionToken(sessionToken); + AWS_LOGSTREAM_INFO(ENVIRONMENT_LOG_TAG, "Found sessionToken"); + } + } + + return credentials; +} + +Aws::String Aws::Auth::GetConfigProfileFilename() +{ + auto configFileNameFromVar = Aws::Environment::GetEnv(AWS_CONFIG_FILE); + if (!configFileNameFromVar.empty()) + { + return configFileNameFromVar; + } + else + { + return Aws::FileSystem::GetHomeDirectory() + PROFILE_DIRECTORY + PATH_DELIM + DEFAULT_CONFIG_FILE; + } +} + +Aws::String Aws::Auth::GetConfigProfileName() +{ + auto profileFromVar = Aws::Environment::GetEnv(AWS_PROFILE_DEFAULT_ENV_VAR); + if (profileFromVar.empty()) + { + profileFromVar = Aws::Environment::GetEnv(AWS_PROFILE_ENV_VAR); + } + + if (profileFromVar.empty()) + { + return Aws::String(DEFAULT_PROFILE); + } + else + { + return profileFromVar; + } +} + +static const char* PROFILE_LOG_TAG = "ProfileConfigFileAWSCredentialsProvider"; + +Aws::String ProfileConfigFileAWSCredentialsProvider::GetCredentialsProfileFilename() +{ + auto credentialsFileNameFromVar = Aws::Environment::GetEnv(AWS_CREDENTIALS_FILE); + + if (credentialsFileNameFromVar.empty()) + { + return Aws::FileSystem::GetHomeDirectory() + PROFILE_DIRECTORY + PATH_DELIM + DEFAULT_CREDENTIALS_FILE; + } + else + { + return credentialsFileNameFromVar; + } +} + +Aws::String ProfileConfigFileAWSCredentialsProvider::GetProfileDirectory() +{ + Aws::String credentialsFileName = GetCredentialsProfileFilename(); + auto lastSeparator = credentialsFileName.find_last_of(PATH_DELIM); + if (lastSeparator != std::string::npos) + { + return credentialsFileName.substr(0, lastSeparator); + } + else + { + return {}; + } +} + +ProfileConfigFileAWSCredentialsProvider::ProfileConfigFileAWSCredentialsProvider(long refreshRateMs) : + m_profileToUse(Aws::Auth::GetConfigProfileName()), + m_credentialsFileLoader(GetCredentialsProfileFilename()), + m_loadFrequencyMs(refreshRateMs) +{ + AWS_LOGSTREAM_INFO(PROFILE_LOG_TAG, "Setting provider to read credentials from " << GetCredentialsProfileFilename() << " for credentials file" + << " and " << GetConfigProfileFilename() << " for the config file " + << ", for use with profile " << m_profileToUse); +} + +ProfileConfigFileAWSCredentialsProvider::ProfileConfigFileAWSCredentialsProvider(const char* profile, long refreshRateMs) : + m_profileToUse(profile), + m_credentialsFileLoader(GetCredentialsProfileFilename()), + m_loadFrequencyMs(refreshRateMs) +{ + AWS_LOGSTREAM_INFO(PROFILE_LOG_TAG, "Setting provider to read credentials from " << GetCredentialsProfileFilename() << " for credentials file" + << " and " << GetConfigProfileFilename() << " for the config file " + << ", for use with profile " << m_profileToUse); +} + +AWSCredentials ProfileConfigFileAWSCredentialsProvider::GetAWSCredentials() +{ + RefreshIfExpired(); + ReaderLockGuard guard(m_reloadLock); + auto credsFileProfileIter = m_credentialsFileLoader.GetProfiles().find(m_profileToUse); + + if(credsFileProfileIter != m_credentialsFileLoader.GetProfiles().end()) + { + return credsFileProfileIter->second.GetCredentials(); + } + + return AWSCredentials(); +} + + +void ProfileConfigFileAWSCredentialsProvider::Reload() +{ + m_credentialsFileLoader.Load(); + AWSCredentialsProvider::Reload(); +} + +void ProfileConfigFileAWSCredentialsProvider::RefreshIfExpired() +{ + ReaderLockGuard guard(m_reloadLock); + if (!IsTimeToRefresh(m_loadFrequencyMs)) + { + return; + } + + guard.UpgradeToWriterLock(); + if (!IsTimeToRefresh(m_loadFrequencyMs)) // double-checked lock to avoid refreshing twice + { + return; + } + + Reload(); +} + +static const char* INSTANCE_LOG_TAG = "InstanceProfileCredentialsProvider"; + +InstanceProfileCredentialsProvider::InstanceProfileCredentialsProvider(long refreshRateMs) : + m_ec2MetadataConfigLoader(Aws::MakeShared<Aws::Config::EC2InstanceProfileConfigLoader>(INSTANCE_LOG_TAG)), + m_loadFrequencyMs(refreshRateMs) +{ + AWS_LOGSTREAM_INFO(INSTANCE_LOG_TAG, "Creating Instance with default EC2MetadataClient and refresh rate " << refreshRateMs); +} + + +InstanceProfileCredentialsProvider::InstanceProfileCredentialsProvider(const std::shared_ptr<Aws::Config::EC2InstanceProfileConfigLoader>& loader, long refreshRateMs) : + m_ec2MetadataConfigLoader(loader), + m_loadFrequencyMs(refreshRateMs) +{ + AWS_LOGSTREAM_INFO(INSTANCE_LOG_TAG, "Creating Instance with injected EC2MetadataClient and refresh rate " << refreshRateMs); +} + + +AWSCredentials InstanceProfileCredentialsProvider::GetAWSCredentials() +{ + RefreshIfExpired(); + ReaderLockGuard guard(m_reloadLock); + auto profileIter = m_ec2MetadataConfigLoader->GetProfiles().find(Aws::Config::INSTANCE_PROFILE_KEY); + + if(profileIter != m_ec2MetadataConfigLoader->GetProfiles().end()) + { + return profileIter->second.GetCredentials(); + } + + return AWSCredentials(); +} + +void InstanceProfileCredentialsProvider::Reload() +{ + AWS_LOGSTREAM_INFO(INSTANCE_LOG_TAG, "Credentials have expired attempting to repull from EC2 Metadata Service."); + m_ec2MetadataConfigLoader->Load(); + AWSCredentialsProvider::Reload(); +} + +void InstanceProfileCredentialsProvider::RefreshIfExpired() +{ + AWS_LOGSTREAM_DEBUG(INSTANCE_LOG_TAG, "Checking if latest credential pull has expired."); + ReaderLockGuard guard(m_reloadLock); + if (!IsTimeToRefresh(m_loadFrequencyMs)) + { + return; + } + + guard.UpgradeToWriterLock(); + if (!IsTimeToRefresh(m_loadFrequencyMs)) // double-checked lock to avoid refreshing twice + { + return; + } + Reload(); +} + +static const char TASK_ROLE_LOG_TAG[] = "TaskRoleCredentialsProvider"; + +TaskRoleCredentialsProvider::TaskRoleCredentialsProvider(const char* URI, long refreshRateMs) : + m_ecsCredentialsClient(Aws::MakeShared<Aws::Internal::ECSCredentialsClient>(TASK_ROLE_LOG_TAG, URI)), + m_loadFrequencyMs(refreshRateMs) +{ + AWS_LOGSTREAM_INFO(TASK_ROLE_LOG_TAG, "Creating TaskRole with default ECSCredentialsClient and refresh rate " << refreshRateMs); +} + +TaskRoleCredentialsProvider::TaskRoleCredentialsProvider(const char* endpoint, const char* token, long refreshRateMs) : + m_ecsCredentialsClient(Aws::MakeShared<Aws::Internal::ECSCredentialsClient>(TASK_ROLE_LOG_TAG, ""/*resourcePath*/, endpoint, token)), + m_loadFrequencyMs(refreshRateMs) +{ + AWS_LOGSTREAM_INFO(TASK_ROLE_LOG_TAG, "Creating TaskRole with default ECSCredentialsClient and refresh rate " << refreshRateMs); +} + +TaskRoleCredentialsProvider::TaskRoleCredentialsProvider( + const std::shared_ptr<Aws::Internal::ECSCredentialsClient>& client, long refreshRateMs) : + m_ecsCredentialsClient(client), + m_loadFrequencyMs(refreshRateMs) +{ + AWS_LOGSTREAM_INFO(TASK_ROLE_LOG_TAG, "Creating TaskRole with default ECSCredentialsClient and refresh rate " << refreshRateMs); +} + +AWSCredentials TaskRoleCredentialsProvider::GetAWSCredentials() +{ + RefreshIfExpired(); + ReaderLockGuard guard(m_reloadLock); + return m_credentials; +} + +bool TaskRoleCredentialsProvider::ExpiresSoon() const +{ + return ((m_credentials.GetExpiration() - Aws::Utils::DateTime::Now()).count() < EXPIRATION_GRACE_PERIOD); +} + +void TaskRoleCredentialsProvider::Reload() +{ + AWS_LOGSTREAM_INFO(TASK_ROLE_LOG_TAG, "Credentials have expired or will expire, attempting to repull from ECS IAM Service."); + + auto credentialsStr = m_ecsCredentialsClient->GetECSCredentials(); + if (credentialsStr.empty()) return; + + Json::JsonValue credentialsDoc(credentialsStr); + if (!credentialsDoc.WasParseSuccessful()) + { + AWS_LOGSTREAM_ERROR(TASK_ROLE_LOG_TAG, "Failed to parse output from ECSCredentialService."); + return; + } + + Aws::String accessKey, secretKey, token; + Utils::Json::JsonView credentialsView(credentialsDoc); + accessKey = credentialsView.GetString("AccessKeyId"); + secretKey = credentialsView.GetString("SecretAccessKey"); + token = credentialsView.GetString("Token"); + AWS_LOGSTREAM_DEBUG(TASK_ROLE_LOG_TAG, "Successfully pulled credentials from metadata service with access key " << accessKey); + + m_credentials.SetAWSAccessKeyId(accessKey); + m_credentials.SetAWSSecretKey(secretKey); + m_credentials.SetSessionToken(token); + m_credentials.SetExpiration(Aws::Utils::DateTime(credentialsView.GetString("Expiration"), DateFormat::ISO_8601)); + AWSCredentialsProvider::Reload(); +} + +void TaskRoleCredentialsProvider::RefreshIfExpired() +{ + AWS_LOGSTREAM_DEBUG(TASK_ROLE_LOG_TAG, "Checking if latest credential pull has expired."); + ReaderLockGuard guard(m_reloadLock); + if (!m_credentials.IsEmpty() && !IsTimeToRefresh(m_loadFrequencyMs) && !ExpiresSoon()) + { + return; + } + + guard.UpgradeToWriterLock(); + + if (!m_credentials.IsEmpty() && !IsTimeToRefresh(m_loadFrequencyMs) && !ExpiresSoon()) + { + return; + } + + Reload(); +} + +static const char PROCESS_LOG_TAG[] = "ProcessCredentialsProvider"; +ProcessCredentialsProvider::ProcessCredentialsProvider() : + m_profileToUse(Aws::Auth::GetConfigProfileName()) +{ + AWS_LOGSTREAM_INFO(PROCESS_LOG_TAG, "Setting process credentials provider to read config from " << m_profileToUse); +} + +ProcessCredentialsProvider::ProcessCredentialsProvider(const Aws::String& profile) : + m_profileToUse(profile) +{ + AWS_LOGSTREAM_INFO(PROCESS_LOG_TAG, "Setting process credentials provider to read config from " << m_profileToUse); +} + +AWSCredentials ProcessCredentialsProvider::GetAWSCredentials() +{ + RefreshIfExpired(); + ReaderLockGuard guard(m_reloadLock); + return m_credentials; +} + + +void ProcessCredentialsProvider::Reload() +{ + auto profile = Aws::Config::GetCachedConfigProfile(m_profileToUse); + const Aws::String &command = profile.GetCredentialProcess(); + if (command.empty()) + { + AWS_LOGSTREAM_ERROR(PROCESS_LOG_TAG, "Failed to find credential process's profile: " << m_profileToUse); + return; + } + m_credentials = GetCredentialsFromProcess(command); +} + +void ProcessCredentialsProvider::RefreshIfExpired() +{ + ReaderLockGuard guard(m_reloadLock); + if (!m_credentials.IsExpiredOrEmpty()) + { + return; + } + + guard.UpgradeToWriterLock(); + if (!m_credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice + { + return; + } + + Reload(); +} + +AWSCredentials Aws::Auth::GetCredentialsFromProcess(const Aws::String& process) +{ + Aws::String command = process; + command.append(" 2>&1"); // redirect stderr to stdout + Aws::String result = Aws::Utils::StringUtils::Trim(Aws::OSVersionInfo::GetSysCommandOutput(command.c_str()).c_str()); + Json::JsonValue credentialsDoc(result); + if (!credentialsDoc.WasParseSuccessful()) + { + AWS_LOGSTREAM_ERROR(PROFILE_LOG_TAG, "Failed to load credential from running: " << command << " Error: " << result); + return {}; + } + + Aws::Utils::Json::JsonView credentialsView(credentialsDoc); + if (!credentialsView.KeyExists("Version") || credentialsView.GetInteger("Version") != 1) + { + AWS_LOGSTREAM_ERROR(PROFILE_LOG_TAG, "Encountered an unsupported process credentials payload version:" << credentialsView.GetInteger("Version")); + return {}; + } + + AWSCredentials credentials; + Aws::String accessKey, secretKey, token, expire; + if (credentialsView.KeyExists("AccessKeyId")) + { + credentials.SetAWSAccessKeyId(credentialsView.GetString("AccessKeyId")); + } + + if (credentialsView.KeyExists("SecretAccessKey")) + { + credentials.SetAWSSecretKey(credentialsView.GetString("SecretAccessKey")); + } + + if (credentialsView.KeyExists("SessionToken")) + { + credentials.SetSessionToken(credentialsView.GetString("SessionToken")); + } + + if (credentialsView.KeyExists("Expiration")) + { + const auto expiration = Aws::Utils::DateTime(credentialsView.GetString("Expiration"), DateFormat::ISO_8601); + if (expiration.WasParseSuccessful()) + { + credentials.SetExpiration(expiration); + } + else + { + AWS_LOGSTREAM_ERROR(PROFILE_LOG_TAG, "Failed to parse credential's expiration value as an ISO 8601 Date. Credentials will be marked expired."); + credentials.SetExpiration(Aws::Utils::DateTime::Now()); + } + } + else + { + credentials.SetExpiration((std::chrono::time_point<std::chrono::system_clock>::max)()); + } + + AWS_LOGSTREAM_DEBUG(PROFILE_LOG_TAG, "Successfully pulled credentials from process credential with AccessKey: " << accessKey << ", Expiration:" << credentialsView.GetString("Expiration")); + return credentials; +} + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp new file mode 100644 index 0000000000..373136d96a --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp @@ -0,0 +1,77 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/auth/AWSCredentialsProviderChain.h> +#include <aws/core/auth/STSCredentialsProvider.h> +#include <aws/core/platform/Environment.h> +#include <aws/core/utils/memory/AWSMemory.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/logging/LogMacros.h> + +using namespace Aws::Auth; + +static const char AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI[] = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"; +static const char AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI[] = "AWS_CONTAINER_CREDENTIALS_FULL_URI"; +static const char AWS_ECS_CONTAINER_AUTHORIZATION_TOKEN[] = "AWS_CONTAINER_AUTHORIZATION_TOKEN"; +static const char AWS_EC2_METADATA_DISABLED[] = "AWS_EC2_METADATA_DISABLED"; +static const char DefaultCredentialsProviderChainTag[] = "DefaultAWSCredentialsProviderChain"; + +AWSCredentials AWSCredentialsProviderChain::GetAWSCredentials() +{ + for (auto&& credentialsProvider : m_providerChain) + { + AWSCredentials credentials = credentialsProvider->GetAWSCredentials(); + if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty()) + { + return credentials; + } + } + + return AWSCredentials(); +} + +DefaultAWSCredentialsProviderChain::DefaultAWSCredentialsProviderChain() : AWSCredentialsProviderChain() +{ + AddProvider(Aws::MakeShared<EnvironmentAWSCredentialsProvider>(DefaultCredentialsProviderChainTag)); + AddProvider(Aws::MakeShared<ProfileConfigFileAWSCredentialsProvider>(DefaultCredentialsProviderChainTag)); + AddProvider(Aws::MakeShared<ProcessCredentialsProvider>(DefaultCredentialsProviderChainTag)); + AddProvider(Aws::MakeShared<STSAssumeRoleWebIdentityCredentialsProvider>(DefaultCredentialsProviderChainTag)); + + //ECS TaskRole Credentials only available when ENVIRONMENT VARIABLE is set + const auto relativeUri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI); + AWS_LOGSTREAM_DEBUG(DefaultCredentialsProviderChainTag, "The environment variable value " << AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI + << " is " << relativeUri); + + const auto absoluteUri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI); + AWS_LOGSTREAM_DEBUG(DefaultCredentialsProviderChainTag, "The environment variable value " << AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI + << " is " << absoluteUri); + + const auto ec2MetadataDisabled = Aws::Environment::GetEnv(AWS_EC2_METADATA_DISABLED); + AWS_LOGSTREAM_DEBUG(DefaultCredentialsProviderChainTag, "The environment variable value " << AWS_EC2_METADATA_DISABLED + << " is " << ec2MetadataDisabled); + + if (!relativeUri.empty()) + { + AddProvider(Aws::MakeShared<TaskRoleCredentialsProvider>(DefaultCredentialsProviderChainTag, relativeUri.c_str())); + AWS_LOGSTREAM_INFO(DefaultCredentialsProviderChainTag, "Added ECS metadata service credentials provider with relative path: [" + << relativeUri << "] to the provider chain."); + } + else if (!absoluteUri.empty()) + { + const auto token = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_AUTHORIZATION_TOKEN); + AddProvider(Aws::MakeShared<TaskRoleCredentialsProvider>(DefaultCredentialsProviderChainTag, + absoluteUri.c_str(), token.c_str())); + + //DO NOT log the value of the authorization token for security purposes. + AWS_LOGSTREAM_INFO(DefaultCredentialsProviderChainTag, "Added ECS credentials provider with URI: [" + << absoluteUri << "] to the provider chain with a" << (token.empty() ? "n empty " : " non-empty ") + << "authorization token."); + } + else if (Aws::Utils::StringUtils::ToLower(ec2MetadataDisabled.c_str()) != "true") + { + AddProvider(Aws::MakeShared<InstanceProfileCredentialsProvider>(DefaultCredentialsProviderChainTag)); + AWS_LOGSTREAM_INFO(DefaultCredentialsProviderChainTag, "Added EC2 metadata service credentials provider to the provider chain."); + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/STSCredentialsProvider.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/STSCredentialsProvider.cpp new file mode 100644 index 0000000000..3f48c9e0c7 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/auth/STSCredentialsProvider.cpp @@ -0,0 +1,163 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/auth/STSCredentialsProvider.h> +#include <aws/core/config/AWSProfileConfigLoader.h> +#include <aws/core/platform/Environment.h> +#include <aws/core/platform/FileSystem.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/FileSystemUtils.h> +#include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/UUID.h> +#include <cstdlib> +#include <fstream> +#include <string.h> +#include <climits> + + +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; +using namespace Aws::Auth; +using namespace Aws::Internal; +using namespace Aws::FileSystem; +using namespace Aws::Client; +using Aws::Utils::Threading::ReaderLockGuard; +using Aws::Utils::Threading::WriterLockGuard; + +static const char STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG[] = "STSAssumeRoleWithWebIdentityCredentialsProvider"; +STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentialsProvider() : + m_initialized(false) +{ + // check environment variables + Aws::String tmpRegion = Aws::Environment::GetEnv("AWS_DEFAULT_REGION"); + m_roleArn = Aws::Environment::GetEnv("AWS_ROLE_ARN"); + m_tokenFile = Aws::Environment::GetEnv("AWS_WEB_IDENTITY_TOKEN_FILE"); + m_sessionName = Aws::Environment::GetEnv("AWS_ROLE_SESSION_NAME"); + + // check profile_config if either m_roleArn or m_tokenFile is not loaded from environment variable + // region source is not enforced, but we need it to construct sts endpoint, if we can't find from environment, we should check if it's set in config file. + if (m_roleArn.empty() || m_tokenFile.empty() || tmpRegion.empty()) + { + auto profile = Aws::Config::GetCachedConfigProfile(Aws::Auth::GetConfigProfileName()); + if (tmpRegion.empty()) + { + tmpRegion = profile.GetRegion(); + } + // If either of these two were not found from environment, use whatever found for all three in config file + if (m_roleArn.empty() || m_tokenFile.empty()) + { + m_roleArn = profile.GetRoleArn(); + m_tokenFile = profile.GetValue("web_identity_token_file"); + m_sessionName = profile.GetValue("role_session_name"); + } + } + + if (m_tokenFile.empty()) + { + AWS_LOGSTREAM_WARN(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Token file must be specified to use STS AssumeRole web identity creds provider."); + return; // No need to do further constructing + } + else + { + AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved token_file from profile_config or environment variable to be " << m_tokenFile); + } + + if (m_roleArn.empty()) + { + AWS_LOGSTREAM_WARN(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "RoleArn must be specified to use STS AssumeRole web identity creds provider."); + return; // No need to do further constructing + } + else + { + AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved role_arn from profile_config or environment variable to be " << m_roleArn); + } + + if (tmpRegion.empty()) + { + tmpRegion = Aws::Region::US_EAST_1; + } + else + { + AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved region from profile_config or environment variable to be " << tmpRegion); + } + + if (m_sessionName.empty()) + { + m_sessionName = Aws::Utils::UUID::RandomUUID(); + } + else + { + AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved session_name from profile_config or environment variable to be " << m_sessionName); + } + + Aws::Client::ClientConfiguration config; + config.scheme = Aws::Http::Scheme::HTTPS; + config.region = tmpRegion; + + Aws::Vector<Aws::String> retryableErrors; + retryableErrors.push_back("IDPCommunicationError"); + retryableErrors.push_back("InvalidIdentityToken"); + + config.retryStrategy = Aws::MakeShared<SpecifiedRetryableErrorsRetryStrategy>(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, retryableErrors, 3/*maxRetries*/); + + m_client = Aws::MakeUnique<Aws::Internal::STSCredentialsClient>(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, config); + m_initialized = true; + AWS_LOGSTREAM_INFO(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Creating STS AssumeRole with web identity creds provider."); +} + +AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials() +{ + // A valid client means required information like role arn and token file were constructed correctly. + // We can use this provider to load creds, otherwise, we can just return empty creds. + if (!m_initialized) + { + return Aws::Auth::AWSCredentials(); + } + RefreshIfExpired(); + ReaderLockGuard guard(m_reloadLock); + return m_credentials; +} + +void STSAssumeRoleWebIdentityCredentialsProvider::Reload() +{ + AWS_LOGSTREAM_INFO(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Credentials have expired, attempting to renew from STS."); + + Aws::IFStream tokenFile(m_tokenFile.c_str()); + if(tokenFile) + { + Aws::String token((std::istreambuf_iterator<char>(tokenFile)), std::istreambuf_iterator<char>()); + m_token = token; + } + else + { + AWS_LOGSTREAM_ERROR(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Can't open token file: " << m_tokenFile); + return; + } + STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request {m_sessionName, m_roleArn, m_token}; + + auto result = m_client->GetAssumeRoleWithWebIdentityCredentials(request); + AWS_LOGSTREAM_TRACE(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Successfully retrieved credentials with AWS_ACCESS_KEY: " << result.creds.GetAWSAccessKeyId()); + m_credentials = result.creds; +} + +void STSAssumeRoleWebIdentityCredentialsProvider::RefreshIfExpired() +{ + ReaderLockGuard guard(m_reloadLock); + if (!m_credentials.IsExpiredOrEmpty()) + { + return; + } + + guard.UpgradeToWriterLock(); + if (!m_credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice + { + return; + } + + Reload(); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AWSClient.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AWSClient.cpp new file mode 100644 index 0000000000..e1ff064840 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AWSClient.cpp @@ -0,0 +1,1098 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/client/AWSClient.h> +#include <aws/core/AmazonWebServiceRequest.h> +#include <aws/core/auth/AWSAuthSigner.h> +#include <aws/core/auth/AWSAuthSignerProvider.h> +#include <aws/core/client/AWSError.h> +#include <aws/core/client/AWSErrorMarshaller.h> +#include <aws/core/client/ClientConfiguration.h> +#include <aws/core/client/CoreErrors.h> +#include <aws/core/client/RetryStrategy.h> +#include <aws/core/http/HttpClient.h> +#include <aws/core/http/HttpClientFactory.h> +#include <aws/core/http/HttpResponse.h> +#include <aws/core/http/standard/StandardHttpResponse.h> +#include <aws/core/http/URI.h> +#include <aws/core/utils/stream/ResponseStream.h> +#include <aws/core/utils/json/JsonSerializer.h> +#include <aws/core/utils/Outcome.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/xml/XmlSerializer.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/Globals.h> +#include <aws/core/utils/EnumParseOverflowContainer.h> +#include <aws/core/utils/crypto/MD5.h> +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/utils/crypto/Factories.h> +#include <aws/core/utils/event/EventStream.h> +#include <aws/core/utils/UUID.h> +#include <aws/core/monitoring/MonitoringManager.h> +#include <aws/core/Region.h> +#include <aws/core/utils/DNS.h> +#include <aws/core/Version.h> +#include <aws/core/platform/OSVersionInfo.h> + +#include <cstring> +#include <cassert> + +using namespace Aws; +using namespace Aws::Client; +using namespace Aws::Http; +using namespace Aws::Utils; +using namespace Aws::Utils::Json; +using namespace Aws::Utils::Xml; + +static const int SUCCESS_RESPONSE_MIN = 200; +static const int SUCCESS_RESPONSE_MAX = 299; + +static const char AWS_CLIENT_LOG_TAG[] = "AWSClient"; +//4 Minutes +static const std::chrono::milliseconds TIME_DIFF_MAX = std::chrono::minutes(4); +//-4 Minutes +static const std::chrono::milliseconds TIME_DIFF_MIN = std::chrono::minutes(-4); + +static CoreErrors GuessBodylessErrorType(Aws::Http::HttpResponseCode responseCode) +{ + switch (responseCode) + { + case HttpResponseCode::FORBIDDEN: + case HttpResponseCode::UNAUTHORIZED: + return CoreErrors::ACCESS_DENIED; + case HttpResponseCode::NOT_FOUND: + return CoreErrors::RESOURCE_NOT_FOUND; + default: + return CoreErrors::UNKNOWN; + } +} + +struct RequestInfo +{ + Aws::Utils::DateTime ttl; + long attempt; + long maxAttempts; + + operator String() + { + Aws::StringStream ss; + if (ttl.WasParseSuccessful() && ttl != DateTime()) + { + assert(attempt > 1); + ss << "ttl=" << ttl.ToGmtString(DateFormat::ISO_8601_BASIC) << "; "; + } + ss << "attempt=" << attempt; + if (maxAttempts > 0) + { + ss << "; max=" << maxAttempts; + } + return ss.str(); + } +}; + +AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration, + const std::shared_ptr<Aws::Client::AWSAuthSigner>& signer, + const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) : + m_region(configuration.region), + m_httpClient(CreateHttpClient(configuration)), + m_signerProvider(Aws::MakeUnique<Aws::Auth::DefaultAuthSignerProvider>(AWS_CLIENT_LOG_TAG, signer)), + m_errorMarshaller(errorMarshaller), + m_retryStrategy(configuration.retryStrategy), + m_writeRateLimiter(configuration.writeRateLimiter), + m_readRateLimiter(configuration.readRateLimiter), + m_userAgent(configuration.userAgent), + m_customizedUserAgent(!m_userAgent.empty()), + m_hash(Aws::Utils::Crypto::CreateMD5Implementation()), + m_requestTimeoutMs(configuration.requestTimeoutMs), + m_enableClockSkewAdjustment(configuration.enableClockSkewAdjustment) +{ + SetServiceClientName("AWSBaseClient"); +} + +AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration, + const std::shared_ptr<Aws::Auth::AWSAuthSignerProvider>& signerProvider, + const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) : + m_region(configuration.region), + m_httpClient(CreateHttpClient(configuration)), + m_signerProvider(signerProvider), + m_errorMarshaller(errorMarshaller), + m_retryStrategy(configuration.retryStrategy), + m_writeRateLimiter(configuration.writeRateLimiter), + m_readRateLimiter(configuration.readRateLimiter), + m_userAgent(configuration.userAgent), + m_customizedUserAgent(!m_userAgent.empty()), + m_hash(Aws::Utils::Crypto::CreateMD5Implementation()), + m_requestTimeoutMs(configuration.requestTimeoutMs), + m_enableClockSkewAdjustment(configuration.enableClockSkewAdjustment) +{ + SetServiceClientName("AWSBaseClient"); +} + +void AWSClient::SetServiceClientName(const Aws::String& name) +{ + m_serviceName = name; + if (!m_customizedUserAgent) + { + Aws::StringStream ss; + ss << "aws-sdk-cpp/" << Version::GetVersionString() << "/" << m_serviceName << "/" << Aws::OSVersionInfo::ComputeOSVersionString() + << " " << Version::GetCompilerVersionString(); + m_userAgent = ss.str(); + } +} + +void AWSClient::DisableRequestProcessing() +{ + m_httpClient->DisableRequestProcessing(); +} + +void AWSClient::EnableRequestProcessing() +{ + m_httpClient->EnableRequestProcessing(); +} + +Aws::Client::AWSAuthSigner* AWSClient::GetSignerByName(const char* name) const +{ + const auto& signer = m_signerProvider->GetSigner(name); + return signer ? signer.get() : nullptr; +} + +static DateTime GetServerTimeFromError(const AWSError<CoreErrors> error) +{ + const Http::HeaderValueCollection& headers = error.GetResponseHeaders(); + auto awsDateHeaderIter = headers.find(StringUtils::ToLower(Http::AWS_DATE_HEADER)); + auto dateHeaderIter = headers.find(StringUtils::ToLower(Http::DATE_HEADER)); + if (awsDateHeaderIter != headers.end()) + { + return DateTime(awsDateHeaderIter->second.c_str(), DateFormat::AutoDetect); + } + else if (dateHeaderIter != headers.end()) + { + return DateTime(dateHeaderIter->second.c_str(), DateFormat::AutoDetect); + } + else + { + return DateTime(); + } +} + +bool AWSClient::AdjustClockSkew(HttpResponseOutcome& outcome, const char* signerName) const +{ + if (m_enableClockSkewAdjustment) + { + auto signer = GetSignerByName(signerName); + //detect clock skew and try to correct. + AWS_LOGSTREAM_WARN(AWS_CLIENT_LOG_TAG, "If the signature check failed. This could be because of a time skew. Attempting to adjust the signer."); + + DateTime serverTime = GetServerTimeFromError(outcome.GetError()); + const auto signingTimestamp = signer->GetSigningTimestamp(); + if (!serverTime.WasParseSuccessful() || serverTime == DateTime()) + { + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Date header was not found in the response, can't attempt to detect clock skew"); + return false; + } + + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Server time is " << serverTime.ToGmtString(DateFormat::RFC822) << ", while client time is " << DateTime::Now().ToGmtString(DateFormat::RFC822)); + auto diff = DateTime::Diff(serverTime, signingTimestamp); + //only try again if clock skew was the cause of the error. + if (diff >= TIME_DIFF_MAX || diff <= TIME_DIFF_MIN) + { + diff = DateTime::Diff(serverTime, DateTime::Now()); + AWS_LOGSTREAM_INFO(AWS_CLIENT_LOG_TAG, "Computed time difference as " << diff.count() << " milliseconds. Adjusting signer with the skew."); + signer->SetClockSkew(diff); + AWSError<CoreErrors> newError( + outcome.GetError().GetErrorType(), outcome.GetError().GetExceptionName(), outcome.GetError().GetMessage(), true); + newError.SetResponseHeaders(outcome.GetError().GetResponseHeaders()); + newError.SetResponseCode(outcome.GetError().GetResponseCode()); + outcome = std::move(newError); + return true; + } + } + return false; +} + +HttpResponseOutcome AWSClient::AttemptExhaustively(const Aws::Http::URI& uri, + const Aws::AmazonWebServiceRequest& request, + HttpMethod method, + const char* signerName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + if (!Aws::Utils::IsValidHost(uri.GetAuthority())) + { + return HttpResponseOutcome(AWSError<CoreErrors>(CoreErrors::VALIDATION, "", "Invalid DNS Label found in URI host", false/*retryable*/)); + } + std::shared_ptr<HttpRequest> httpRequest(CreateHttpRequest(uri, method, request.GetResponseStreamFactory())); + HttpResponseOutcome outcome; + AWSError<CoreErrors> lastError; + Aws::Monitoring::CoreMetricsCollection coreMetrics; + auto contexts = Aws::Monitoring::OnRequestStarted(this->GetServiceClientName(), request.GetServiceRequestName(), httpRequest); + const char* signerRegion = signerRegionOverride; + Aws::String regionFromResponse; + + Aws::String invocationId = UUID::RandomUUID(); + RequestInfo requestInfo; + requestInfo.attempt = 1; + requestInfo.maxAttempts = 0; + httpRequest->SetHeaderValue(Http::SDK_INVOCATION_ID_HEADER, invocationId); + httpRequest->SetHeaderValue(Http::SDK_REQUEST_HEADER, requestInfo); + + for (long retries = 0;; retries++) + { + m_retryStrategy->GetSendToken(); + httpRequest->SetEventStreamRequest(request.IsEventStreamRequest()); + + outcome = AttemptOneRequest(httpRequest, request, signerName, signerRegion, signerServiceNameOverride); + if (retries == 0) + { + m_retryStrategy->RequestBookkeeping(outcome); + } + else + { + m_retryStrategy->RequestBookkeeping(outcome, lastError); + } + coreMetrics.httpClientMetrics = httpRequest->GetRequestMetrics(); + if (outcome.IsSuccess()) + { + Aws::Monitoring::OnRequestSucceeded(this->GetServiceClientName(), request.GetServiceRequestName(), httpRequest, outcome, coreMetrics, contexts); + AWS_LOGSTREAM_TRACE(AWS_CLIENT_LOG_TAG, "Request successful returning."); + break; + } + lastError = outcome.GetError(); + + DateTime serverTime = GetServerTimeFromError(outcome.GetError()); + auto clockSkew = DateTime::Diff(serverTime, DateTime::Now()); + + Aws::Monitoring::OnRequestFailed(this->GetServiceClientName(), request.GetServiceRequestName(), httpRequest, outcome, coreMetrics, contexts); + + if (!m_httpClient->IsRequestProcessingEnabled()) + { + AWS_LOGSTREAM_TRACE(AWS_CLIENT_LOG_TAG, "Request was cancelled externally."); + break; + } + + // Adjust region + bool retryWithCorrectRegion = false; + HttpResponseCode httpResponseCode = outcome.GetError().GetResponseCode(); + if (httpResponseCode == HttpResponseCode::MOVED_PERMANENTLY || // 301 + httpResponseCode == HttpResponseCode::TEMPORARY_REDIRECT || // 307 + httpResponseCode == HttpResponseCode::BAD_REQUEST || // 400 + httpResponseCode == HttpResponseCode::FORBIDDEN) // 403 + { + regionFromResponse = GetErrorMarshaller()->ExtractRegion(outcome.GetError()); + if (m_region == Aws::Region::AWS_GLOBAL && !regionFromResponse.empty() && regionFromResponse != signerRegion) + { + signerRegion = regionFromResponse.c_str(); + retryWithCorrectRegion = true; + } + } + + long sleepMillis = m_retryStrategy->CalculateDelayBeforeNextRetry(outcome.GetError(), retries); + //AdjustClockSkew returns true means clock skew was the problem and skew was adjusted, false otherwise. + //sleep if clock skew and region was NOT the problem. AdjustClockSkew may update error inside outcome. + bool shouldSleep = !AdjustClockSkew(outcome, signerName) && !retryWithCorrectRegion; + + if (!retryWithCorrectRegion && !m_retryStrategy->ShouldRetry(outcome.GetError(), retries)) + { + break; + } + + AWS_LOGSTREAM_WARN(AWS_CLIENT_LOG_TAG, "Request failed, now waiting " << sleepMillis << " ms before attempting again."); + if(request.GetBody()) + { + request.GetBody()->clear(); + request.GetBody()->seekg(0); + } + + if (request.GetRequestRetryHandler()) + { + request.GetRequestRetryHandler()(request); + } + + if (shouldSleep) + { + m_httpClient->RetryRequestSleep(std::chrono::milliseconds(sleepMillis)); + } + + Aws::Http::URI newUri = uri; + Aws::String newEndpoint = GetErrorMarshaller()->ExtractEndpoint(outcome.GetError()); + if (!newEndpoint.empty()) + { + newUri.SetAuthority(newEndpoint); + } + httpRequest = CreateHttpRequest(newUri, method, request.GetResponseStreamFactory()); + + httpRequest->SetHeaderValue(Http::SDK_INVOCATION_ID_HEADER, invocationId); + if (serverTime.WasParseSuccessful() && serverTime != DateTime()) + { + requestInfo.ttl = DateTime::Now() + clockSkew + std::chrono::milliseconds(m_requestTimeoutMs); + } + requestInfo.attempt ++; + requestInfo.maxAttempts = m_retryStrategy->GetMaxAttempts(); + httpRequest->SetHeaderValue(Http::SDK_REQUEST_HEADER, requestInfo); + Aws::Monitoring::OnRequestRetry(this->GetServiceClientName(), request.GetServiceRequestName(), httpRequest, contexts); + } + Aws::Monitoring::OnFinish(this->GetServiceClientName(), request.GetServiceRequestName(), httpRequest, contexts); + return outcome; +} + +HttpResponseOutcome AWSClient::AttemptExhaustively(const Aws::Http::URI& uri, + HttpMethod method, + const char* signerName, + const char* requestName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + if (!Aws::Utils::IsValidHost(uri.GetAuthority())) + { + return HttpResponseOutcome(AWSError<CoreErrors>(CoreErrors::VALIDATION, "", "Invalid DNS Label found in URI host", false/*retryable*/)); + } + + std::shared_ptr<HttpRequest> httpRequest(CreateHttpRequest(uri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + HttpResponseOutcome outcome; + AWSError<CoreErrors> lastError; + Aws::Monitoring::CoreMetricsCollection coreMetrics; + auto contexts = Aws::Monitoring::OnRequestStarted(this->GetServiceClientName(), requestName, httpRequest); + const char* signerRegion = signerRegionOverride; + Aws::String regionFromResponse; + + Aws::String invocationId = UUID::RandomUUID(); + RequestInfo requestInfo; + requestInfo.attempt = 1; + requestInfo.maxAttempts = 0; + httpRequest->SetHeaderValue(Http::SDK_INVOCATION_ID_HEADER, invocationId); + httpRequest->SetHeaderValue(Http::SDK_REQUEST_HEADER, requestInfo); + + for (long retries = 0;; retries++) + { + m_retryStrategy->GetSendToken(); + outcome = AttemptOneRequest(httpRequest, signerName, requestName, signerRegion, signerServiceNameOverride); + if (retries == 0) + { + m_retryStrategy->RequestBookkeeping(outcome); + } + else + { + m_retryStrategy->RequestBookkeeping(outcome, lastError); + } + coreMetrics.httpClientMetrics = httpRequest->GetRequestMetrics(); + if (outcome.IsSuccess()) + { + Aws::Monitoring::OnRequestSucceeded(this->GetServiceClientName(), requestName, httpRequest, outcome, coreMetrics, contexts); + AWS_LOGSTREAM_TRACE(AWS_CLIENT_LOG_TAG, "Request successful returning."); + break; + } + lastError = outcome.GetError(); + + DateTime serverTime = GetServerTimeFromError(outcome.GetError()); + auto clockSkew = DateTime::Diff(serverTime, DateTime::Now()); + + Aws::Monitoring::OnRequestFailed(this->GetServiceClientName(), requestName, httpRequest, outcome, coreMetrics, contexts); + + if (!m_httpClient->IsRequestProcessingEnabled()) + { + AWS_LOGSTREAM_TRACE(AWS_CLIENT_LOG_TAG, "Request was cancelled externally."); + break; + } + + // Adjust region + bool retryWithCorrectRegion = false; + HttpResponseCode httpResponseCode = outcome.GetError().GetResponseCode(); + if (httpResponseCode == HttpResponseCode::MOVED_PERMANENTLY || // 301 + httpResponseCode == HttpResponseCode::TEMPORARY_REDIRECT || // 307 + httpResponseCode == HttpResponseCode::BAD_REQUEST || // 400 + httpResponseCode == HttpResponseCode::FORBIDDEN) // 403 + { + regionFromResponse = GetErrorMarshaller()->ExtractRegion(outcome.GetError()); + if (m_region == Aws::Region::AWS_GLOBAL && !regionFromResponse.empty() && regionFromResponse != signerRegion) + { + signerRegion = regionFromResponse.c_str(); + retryWithCorrectRegion = true; + } + } + + long sleepMillis = m_retryStrategy->CalculateDelayBeforeNextRetry(outcome.GetError(), retries); + //AdjustClockSkew returns true means clock skew was the problem and skew was adjusted, false otherwise. + //sleep if clock skew and region was NOT the problem. AdjustClockSkew may update error inside outcome. + bool shouldSleep = !AdjustClockSkew(outcome, signerName) && !retryWithCorrectRegion; + + if (!retryWithCorrectRegion && !m_retryStrategy->ShouldRetry(outcome.GetError(), retries)) + { + break; + } + + AWS_LOGSTREAM_WARN(AWS_CLIENT_LOG_TAG, "Request failed, now waiting " << sleepMillis << " ms before attempting again."); + + if (shouldSleep) + { + m_httpClient->RetryRequestSleep(std::chrono::milliseconds(sleepMillis)); + } + + Aws::Http::URI newUri = uri; + Aws::String newEndpoint = GetErrorMarshaller()->ExtractEndpoint(outcome.GetError()); + if (!newEndpoint.empty()) + { + newUri.SetAuthority(newEndpoint); + } + httpRequest = CreateHttpRequest(newUri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + + httpRequest->SetHeaderValue(Http::SDK_INVOCATION_ID_HEADER, invocationId); + if (serverTime.WasParseSuccessful() && serverTime != DateTime()) + { + requestInfo.ttl = DateTime::Now() + clockSkew + std::chrono::milliseconds(m_requestTimeoutMs); + } + requestInfo.attempt ++; + requestInfo.maxAttempts = m_retryStrategy->GetMaxAttempts(); + httpRequest->SetHeaderValue(Http::SDK_REQUEST_HEADER, requestInfo); + Aws::Monitoring::OnRequestRetry(this->GetServiceClientName(), requestName, httpRequest, contexts); + } + Aws::Monitoring::OnFinish(this->GetServiceClientName(), requestName, httpRequest, contexts); + return outcome; +} + +static bool DoesResponseGenerateError(const std::shared_ptr<HttpResponse>& response) +{ + if (response->HasClientError()) return true; + + int responseCode = static_cast<int>(response->GetResponseCode()); + return responseCode < SUCCESS_RESPONSE_MIN || responseCode > SUCCESS_RESPONSE_MAX; + +} + +HttpResponseOutcome AWSClient::AttemptOneRequest(const std::shared_ptr<HttpRequest>& httpRequest, const Aws::AmazonWebServiceRequest& request, + const char* signerName, const char* signerRegionOverride, const char* signerServiceNameOverride) const +{ + BuildHttpRequest(request, httpRequest); + auto signer = GetSignerByName(signerName); + if (!signer->SignRequest(*httpRequest, signerRegionOverride, signerServiceNameOverride, request.SignBody())) + { + AWS_LOGSTREAM_ERROR(AWS_CLIENT_LOG_TAG, "Request signing failed. Returning error."); + return HttpResponseOutcome(AWSError<CoreErrors>(CoreErrors::CLIENT_SIGNING_FAILURE, "", "SDK failed to sign the request", false/*retryable*/)); + } + + if (request.GetRequestSignedHandler()) + { + request.GetRequestSignedHandler()(*httpRequest); + } + + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request Successfully signed"); + std::shared_ptr<HttpResponse> httpResponse( + m_httpClient->MakeRequest(httpRequest, m_readRateLimiter.get(), m_writeRateLimiter.get())); + + if (DoesResponseGenerateError(httpResponse)) + { + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request returned error. Attempting to generate appropriate error codes from response"); + auto error = BuildAWSError(httpResponse); + return HttpResponseOutcome(std::move(error)); + } + + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request returned successful response."); + + return HttpResponseOutcome(std::move(httpResponse)); +} + +HttpResponseOutcome AWSClient::AttemptOneRequest(const std::shared_ptr<HttpRequest>& httpRequest, + const char* signerName, const char* requestName, const char* signerRegionOverride, const char* signerServiceNameOverride) const +{ + AWS_UNREFERENCED_PARAM(requestName); + + auto signer = GetSignerByName(signerName); + if (!signer->SignRequest(*httpRequest, signerRegionOverride, signerServiceNameOverride, true)) + { + AWS_LOGSTREAM_ERROR(AWS_CLIENT_LOG_TAG, "Request signing failed. Returning error."); + return HttpResponseOutcome(AWSError<CoreErrors>(CoreErrors::CLIENT_SIGNING_FAILURE, "", "SDK failed to sign the request", false/*retryable*/)); + } + + //user agent and headers like that shouldn't be signed for the sake of compatibility with proxies which MAY mutate that header. + AddCommonHeaders(*httpRequest); + + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request Successfully signed"); + std::shared_ptr<HttpResponse> httpResponse( + m_httpClient->MakeRequest(httpRequest, m_readRateLimiter.get(), m_writeRateLimiter.get())); + + if (DoesResponseGenerateError(httpResponse)) + { + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request returned error. Attempting to generate appropriate error codes from response"); + auto error = BuildAWSError(httpResponse); + return HttpResponseOutcome(std::move(error)); + } + + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request returned successful response."); + + return HttpResponseOutcome(std::move(httpResponse)); +} + +StreamOutcome AWSClient::MakeRequestWithUnparsedResponse(const Aws::Http::URI& uri, + const Aws::AmazonWebServiceRequest& request, + Http::HttpMethod method, + const char* signerName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + HttpResponseOutcome httpResponseOutcome = AttemptExhaustively(uri, request, method, signerName, signerRegionOverride, signerServiceNameOverride); + if (httpResponseOutcome.IsSuccess()) + { + return StreamOutcome(AmazonWebServiceResult<Stream::ResponseStream>( + httpResponseOutcome.GetResult()->SwapResponseStreamOwnership(), + httpResponseOutcome.GetResult()->GetHeaders(), httpResponseOutcome.GetResult()->GetResponseCode())); + } + + return StreamOutcome(std::move(httpResponseOutcome)); +} + +StreamOutcome AWSClient::MakeRequestWithUnparsedResponse(const Aws::Http::URI& uri, + Http::HttpMethod method, + const char* signerName, + const char* requestName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + HttpResponseOutcome httpResponseOutcome = AttemptExhaustively(uri, method, signerName, requestName, signerRegionOverride, signerServiceNameOverride); + if (httpResponseOutcome.IsSuccess()) + { + return StreamOutcome(AmazonWebServiceResult<Stream::ResponseStream>( + httpResponseOutcome.GetResult()->SwapResponseStreamOwnership(), + httpResponseOutcome.GetResult()->GetHeaders(), httpResponseOutcome.GetResult()->GetResponseCode())); + } + + return StreamOutcome(std::move(httpResponseOutcome)); +} + +XmlOutcome AWSXMLClient::MakeRequestWithEventStream(const Aws::Http::URI& uri, + const Aws::AmazonWebServiceRequest& request, + Http::HttpMethod method, + const char* signerName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + HttpResponseOutcome httpOutcome = AttemptExhaustively(uri, request, method, signerName, signerRegionOverride, signerServiceNameOverride); + if (httpOutcome.IsSuccess()) + { + return XmlOutcome(AmazonWebServiceResult<XmlDocument>(XmlDocument(), httpOutcome.GetResult()->GetHeaders())); + } + + return XmlOutcome(std::move(httpOutcome)); +} + +XmlOutcome AWSXMLClient::MakeRequestWithEventStream(const Aws::Http::URI& uri, + Http::HttpMethod method, + const char* signerName, + const char* requestName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + HttpResponseOutcome httpOutcome = AttemptExhaustively(uri, method, signerName, requestName, signerRegionOverride, signerServiceNameOverride); + if (httpOutcome.IsSuccess()) + { + return XmlOutcome(AmazonWebServiceResult<XmlDocument>(XmlDocument(), httpOutcome.GetResult()->GetHeaders())); + } + + return XmlOutcome(std::move(httpOutcome)); +} + +void AWSClient::AddHeadersToRequest(const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest, + const Http::HeaderValueCollection& headerValues) const +{ + for (auto const& headerValue : headerValues) + { + httpRequest->SetHeaderValue(headerValue.first, headerValue.second); + } + + AddCommonHeaders(*httpRequest); +} + +void AWSClient::AddContentBodyToRequest(const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest, + const std::shared_ptr<Aws::IOStream>& body, bool needsContentMd5, bool isChunked) const +{ + httpRequest->AddContentBody(body); + + //If there is no body, we have a content length of 0 + //note: we also used to remove content-type, but S3 actually needs content-type on InitiateMultipartUpload and it isn't + //forbiden by the spec. If we start getting weird errors related to this, make sure it isn't caused by this removal. + if (!body) + { + AWS_LOGSTREAM_TRACE(AWS_CLIENT_LOG_TAG, "No content body, content-length headers"); + + if(httpRequest->GetMethod() == HttpMethod::HTTP_POST || httpRequest->GetMethod() == HttpMethod::HTTP_PUT) + { + httpRequest->SetHeaderValue(Http::CONTENT_LENGTH_HEADER, "0"); + } + else + { + httpRequest->DeleteHeader(Http::CONTENT_LENGTH_HEADER); + } + } + + //Add transfer-encoding:chunked to header + if (body && isChunked) + { + httpRequest->SetTransferEncoding(CHUNKED_VALUE); + } + //in the scenario where we are adding a content body as a stream, the request object likely already + //has a content-length header set and we don't want to seek the stream just to find this information. + else if (body && !httpRequest->HasHeader(Http::CONTENT_LENGTH_HEADER)) + { + if (!m_httpClient->SupportsChunkedTransferEncoding()) + { + AWS_LOGSTREAM_WARN(AWS_CLIENT_LOG_TAG, "This http client doesn't support transfer-encoding:chunked. " << + "The request may fail if it's not a seekable stream."); + } + AWS_LOGSTREAM_TRACE(AWS_CLIENT_LOG_TAG, "Found body, but content-length has not been set, attempting to compute content-length"); + body->seekg(0, body->end); + auto streamSize = body->tellg(); + body->seekg(0, body->beg); + Aws::StringStream ss; + ss << streamSize; + httpRequest->SetContentLength(ss.str()); + } + + if (needsContentMd5 && body && !httpRequest->HasHeader(Http::CONTENT_MD5_HEADER)) + { + AWS_LOGSTREAM_TRACE(AWS_CLIENT_LOG_TAG, "Found body, and content-md5 needs to be set" << + ", attempting to compute content-md5"); + + //changing the internal state of the hash computation is not a logical state + //change as far as constness goes for this class. Due to the platform specificness + //of hash computations, we can't control the fact that computing a hash mutates + //state on some platforms such as windows (but that isn't a concern of this class. + auto md5HashResult = const_cast<AWSClient*>(this)->m_hash->Calculate(*body); + body->clear(); + if (md5HashResult.IsSuccess()) + { + httpRequest->SetHeaderValue(Http::CONTENT_MD5_HEADER, HashingUtils::Base64Encode(md5HashResult.GetResult())); + } + } +} + +Aws::String Aws::Client::GetAuthorizationHeader(const Aws::Http::HttpRequest& httpRequest) +{ + // Extract the hex-encoded signature from the authorization header rather than recalculating it. + assert(httpRequest.HasAwsAuthorization()); + const auto& authHeader = httpRequest.GetAwsAuthorization(); + auto signaturePosition = authHeader.rfind(Aws::Auth::SIGNATURE); + // The auth header should end with 'Signature=<64 chars>' + // Make sure we found the word 'Signature' in the header and make sure it's the last item followed by its 64 hex chars + if (signaturePosition == Aws::String::npos || (signaturePosition + strlen(Aws::Auth::SIGNATURE) + 1/*'=' character*/ + 64/*hex chars*/) != authHeader.length()) + { + AWS_LOGSTREAM_ERROR(AWS_CLIENT_LOG_TAG, "Failed to extract signature from authorization header."); + return {}; + } + return authHeader.substr(signaturePosition + strlen(Aws::Auth::SIGNATURE) + 1); +} + +void AWSClient::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, + const std::shared_ptr<HttpRequest>& httpRequest) const +{ + //do headers first since the request likely will set content-length as it's own header. + AddHeadersToRequest(httpRequest, request.GetHeaders()); + + if (request.IsEventStreamRequest()) + { + httpRequest->AddContentBody(request.GetBody()); + } + else + { + AddContentBodyToRequest(httpRequest, request.GetBody(), request.ShouldComputeContentMd5(), request.IsStreaming() && request.IsChunked() && m_httpClient->SupportsChunkedTransferEncoding()); + } + + // Pass along handlers for processing data sent/received in bytes + httpRequest->SetDataReceivedEventHandler(request.GetDataReceivedEventHandler()); + httpRequest->SetDataSentEventHandler(request.GetDataSentEventHandler()); + httpRequest->SetContinueRequestHandle(request.GetContinueRequestHandler()); + + request.AddQueryStringParameters(httpRequest->GetUri()); +} + +void AWSClient::AddCommonHeaders(HttpRequest& httpRequest) const +{ + httpRequest.SetUserAgent(m_userAgent); +} + +Aws::String AWSClient::GeneratePresignedUrl(URI& uri, HttpMethod method, long long expirationInSeconds) +{ + std::shared_ptr<HttpRequest> request = CreateHttpRequest(uri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*request, expirationInSeconds)) + { + return request->GetURIString(); + } + + return {}; +} + +Aws::String AWSClient::GeneratePresignedUrl(URI& uri, HttpMethod method, const Aws::Http::HeaderValueCollection& customizedHeaders, long long expirationInSeconds) +{ + std::shared_ptr<HttpRequest> request = CreateHttpRequest(uri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + for (const auto& it: customizedHeaders) + { + request->SetHeaderValue(it.first.c_str(), it.second); + } + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*request, expirationInSeconds)) + { + return request->GetURIString(); + } + + return {}; +} + +Aws::String AWSClient::GeneratePresignedUrl(URI& uri, HttpMethod method, const char* region, long long expirationInSeconds) const +{ + std::shared_ptr<HttpRequest> request = CreateHttpRequest(uri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*request, region, expirationInSeconds)) + { + return request->GetURIString(); + } + + return {}; +} + +Aws::String AWSClient::GeneratePresignedUrl(URI& uri, HttpMethod method, const char* region, const Aws::Http::HeaderValueCollection& customizedHeaders, long long expirationInSeconds) +{ + std::shared_ptr<HttpRequest> request = CreateHttpRequest(uri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + for (const auto& it: customizedHeaders) + { + request->SetHeaderValue(it.first.c_str(), it.second); + } + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*request, region, expirationInSeconds)) + { + return request->GetURIString(); + } + + return {}; +} + +Aws::String AWSClient::GeneratePresignedUrl(Aws::Http::URI& uri, Aws::Http::HttpMethod method, const char* region, const char* serviceName, long long expirationInSeconds) const +{ + std::shared_ptr<HttpRequest> request = CreateHttpRequest(uri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*request, region, serviceName, expirationInSeconds)) + { + return request->GetURIString(); + } + + return {}; +} + +Aws::String AWSClient::GeneratePresignedUrl(Aws::Http::URI& uri, Aws::Http::HttpMethod method, const char* region, const char* serviceName, const Aws::Http::HeaderValueCollection& customizedHeaders, long long expirationInSeconds) +{ + std::shared_ptr<HttpRequest> request = CreateHttpRequest(uri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + for (const auto& it: customizedHeaders) + { + request->SetHeaderValue(it.first.c_str(), it.second); + } + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*request, region, serviceName, expirationInSeconds)) + { + return request->GetURIString(); + } + + return {}; +} + +Aws::String AWSClient::GeneratePresignedUrl(const Aws::AmazonWebServiceRequest& request, Aws::Http::URI& uri, Aws::Http::HttpMethod method, const char* region, + const Aws::Http::QueryStringParameterCollection& extraParams, long long expirationInSeconds) const +{ + std::shared_ptr<HttpRequest> httpRequest = + ConvertToRequestForPresigning(request, uri, method, extraParams); + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*httpRequest, region, expirationInSeconds)) + { + return httpRequest->GetURIString(); + } + + return {}; +} + +Aws::String AWSClient::GeneratePresignedUrl(const Aws::AmazonWebServiceRequest& request, Aws::Http::URI& uri, Aws::Http::HttpMethod method, const char* region, const char* serviceName, +const Aws::Http::QueryStringParameterCollection& extraParams, long long expirationInSeconds) const +{ + std::shared_ptr<HttpRequest> httpRequest = + ConvertToRequestForPresigning(request, uri, method, extraParams); + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*httpRequest, region, serviceName, expirationInSeconds)) + { + return httpRequest->GetURIString(); + } + + return {}; +} + +Aws::String AWSClient::GeneratePresignedUrl(const Aws::AmazonWebServiceRequest& request, Aws::Http::URI& uri, Aws::Http::HttpMethod method, + const Aws::Http::QueryStringParameterCollection& extraParams, long long expirationInSeconds) const +{ + std::shared_ptr<HttpRequest> httpRequest = + ConvertToRequestForPresigning(request, uri, method, extraParams); + auto signer = GetSignerByName(Aws::Auth::SIGV4_SIGNER); + if (signer->PresignRequest(*httpRequest, expirationInSeconds)) + { + return httpRequest->GetURIString(); + } + + return {}; +} + +std::shared_ptr<Aws::Http::HttpRequest> AWSClient::ConvertToRequestForPresigning(const Aws::AmazonWebServiceRequest& request, Aws::Http::URI& uri, + Aws::Http::HttpMethod method, const Aws::Http::QueryStringParameterCollection& extraParams) const +{ + request.PutToPresignedUrl(uri); + std::shared_ptr<HttpRequest> httpRequest = CreateHttpRequest(uri, method, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + + for (auto& param : extraParams) + { + httpRequest->AddQueryStringParameter(param.first.c_str(), param.second); + } + + return httpRequest; +} + +std::shared_ptr<Aws::Http::HttpResponse> AWSClient::MakeHttpRequest(std::shared_ptr<Aws::Http::HttpRequest>& request) const +{ + return m_httpClient->MakeRequest(request, m_readRateLimiter.get(), m_writeRateLimiter.get()); +} + + +//////////////////////////////////////////////////////////////////////////// +AWSJsonClient::AWSJsonClient(const Aws::Client::ClientConfiguration& configuration, + const std::shared_ptr<Aws::Client::AWSAuthSigner>& signer, + const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) : + BASECLASS(configuration, signer, errorMarshaller) +{ +} + +AWSJsonClient::AWSJsonClient(const Aws::Client::ClientConfiguration& configuration, + const std::shared_ptr<Aws::Auth::AWSAuthSignerProvider>& signerProvider, + const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) : + BASECLASS(configuration, signerProvider, errorMarshaller) +{ +} + + +JsonOutcome AWSJsonClient::MakeRequest(const Aws::Http::URI& uri, + const Aws::AmazonWebServiceRequest& request, + Http::HttpMethod method, + const char* signerName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + HttpResponseOutcome httpOutcome(BASECLASS::AttemptExhaustively(uri, request, method, signerName, signerRegionOverride, signerServiceNameOverride)); + if (!httpOutcome.IsSuccess()) + { + return JsonOutcome(std::move(httpOutcome)); + } + + if (httpOutcome.GetResult()->GetResponseBody().tellp() > 0) + //this is stupid, but gcc doesn't pick up the covariant on the dereference so we have to give it a little hint. + return JsonOutcome(AmazonWebServiceResult<JsonValue>(JsonValue(httpOutcome.GetResult()->GetResponseBody()), + httpOutcome.GetResult()->GetHeaders(), + httpOutcome.GetResult()->GetResponseCode())); + + else + return JsonOutcome(AmazonWebServiceResult<JsonValue>(JsonValue(), httpOutcome.GetResult()->GetHeaders())); +} + +JsonOutcome AWSJsonClient::MakeRequest(const Aws::Http::URI& uri, + Http::HttpMethod method, + const char* signerName, + const char* requestName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + HttpResponseOutcome httpOutcome(BASECLASS::AttemptExhaustively(uri, method, signerName, requestName, signerRegionOverride, signerServiceNameOverride)); + if (!httpOutcome.IsSuccess()) + { + return JsonOutcome(std::move(httpOutcome)); + } + + if (httpOutcome.GetResult()->GetResponseBody().tellp() > 0) + { + JsonValue jsonValue(httpOutcome.GetResult()->GetResponseBody()); + if (!jsonValue.WasParseSuccessful()) + { + return JsonOutcome(AWSError<CoreErrors>(CoreErrors::UNKNOWN, "Json Parser Error", jsonValue.GetErrorMessage(), false)); + } + + //this is stupid, but gcc doesn't pick up the covariant on the dereference so we have to give it a little hint. + return JsonOutcome(AmazonWebServiceResult<JsonValue>(std::move(jsonValue), + httpOutcome.GetResult()->GetHeaders(), + httpOutcome.GetResult()->GetResponseCode())); + } + + return JsonOutcome(AmazonWebServiceResult<JsonValue>(JsonValue(), httpOutcome.GetResult()->GetHeaders())); +} + +JsonOutcome AWSJsonClient::MakeEventStreamRequest(std::shared_ptr<Aws::Http::HttpRequest>& request) const +{ + // request is assumed to be signed + std::shared_ptr<HttpResponse> httpResponse = MakeHttpRequest(request); + + if (DoesResponseGenerateError(httpResponse)) + { + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request returned error. Attempting to generate appropriate error codes from response"); + auto error = BuildAWSError(httpResponse); + return JsonOutcome(std::move(error)); + } + + AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request returned successful response."); + + HttpResponseOutcome httpOutcome(std::move(httpResponse)); + + if (httpOutcome.GetResult()->GetResponseBody().tellp() > 0) + { + JsonValue jsonValue(httpOutcome.GetResult()->GetResponseBody()); + if (!jsonValue.WasParseSuccessful()) + { + return JsonOutcome(AWSError<CoreErrors>(CoreErrors::UNKNOWN, "Json Parser Error", jsonValue.GetErrorMessage(), false)); + } + + //this is stupid, but gcc doesn't pick up the covariant on the dereference so we have to give it a little hint. + return JsonOutcome(AmazonWebServiceResult<JsonValue>(std::move(jsonValue), + httpOutcome.GetResult()->GetHeaders(), + httpOutcome.GetResult()->GetResponseCode())); + } + + return JsonOutcome(AmazonWebServiceResult<JsonValue>(JsonValue(), httpOutcome.GetResult()->GetHeaders())); +} + +AWSError<CoreErrors> AWSJsonClient::BuildAWSError( + const std::shared_ptr<Aws::Http::HttpResponse>& httpResponse) const +{ + AWSError<CoreErrors> error; + if (httpResponse->HasClientError()) + { + bool retryable = httpResponse->GetClientErrorType() == CoreErrors::NETWORK_CONNECTION ? true : false; + error = AWSError<CoreErrors>(httpResponse->GetClientErrorType(), "", httpResponse->GetClientErrorMessage(), retryable); + } + else if (!httpResponse->GetResponseBody() || httpResponse->GetResponseBody().tellp() < 1) + { + auto responseCode = httpResponse->GetResponseCode(); + auto errorCode = GuessBodylessErrorType(responseCode); + + Aws::StringStream ss; + ss << "No response body."; + error = AWSError<CoreErrors>(errorCode, "", ss.str(), + IsRetryableHttpResponseCode(responseCode)); + } + else + { + assert(httpResponse->GetResponseCode() != HttpResponseCode::OK); + error = GetErrorMarshaller()->Marshall(*httpResponse); + } + + error.SetResponseHeaders(httpResponse->GetHeaders()); + error.SetResponseCode(httpResponse->GetResponseCode()); + error.SetRemoteHostIpAddress(httpResponse->GetOriginatingRequest().GetResolvedRemoteHost()); + AWS_LOGSTREAM_ERROR(AWS_CLIENT_LOG_TAG, error); + return error; +} + +///////////////////////////////////////////////////////////////////////////////////////// +AWSXMLClient::AWSXMLClient(const Aws::Client::ClientConfiguration& configuration, + const std::shared_ptr<Aws::Client::AWSAuthSigner>& signer, + const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) : + BASECLASS(configuration, signer, errorMarshaller) +{ +} + +AWSXMLClient::AWSXMLClient(const Aws::Client::ClientConfiguration& configuration, + const std::shared_ptr<Aws::Auth::AWSAuthSignerProvider>& signerProvider, + const std::shared_ptr<AWSErrorMarshaller>& errorMarshaller) : + BASECLASS(configuration, signerProvider, errorMarshaller) +{ +} + +XmlOutcome AWSXMLClient::MakeRequest(const Aws::Http::URI& uri, + const Aws::AmazonWebServiceRequest& request, + Http::HttpMethod method, + const char* signerName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + HttpResponseOutcome httpOutcome(BASECLASS::AttemptExhaustively(uri, request, method, signerName, signerRegionOverride, signerServiceNameOverride)); + if (!httpOutcome.IsSuccess()) + { + return XmlOutcome(std::move(httpOutcome)); + } + + if (httpOutcome.GetResult()->GetResponseBody().tellp() > 0) + { + XmlDocument xmlDoc = XmlDocument::CreateFromXmlStream(httpOutcome.GetResult()->GetResponseBody()); + + if (!xmlDoc.WasParseSuccessful()) + { + AWS_LOGSTREAM_ERROR(AWS_CLIENT_LOG_TAG, "Xml parsing for error failed with message " << xmlDoc.GetErrorMessage().c_str()); + return AWSError<CoreErrors>(CoreErrors::UNKNOWN, "Xml Parse Error", xmlDoc.GetErrorMessage(), false); + } + + return XmlOutcome(AmazonWebServiceResult<XmlDocument>(std::move(xmlDoc), + httpOutcome.GetResult()->GetHeaders(), httpOutcome.GetResult()->GetResponseCode())); + } + + return XmlOutcome(AmazonWebServiceResult<XmlDocument>(XmlDocument(), httpOutcome.GetResult()->GetHeaders())); +} + +XmlOutcome AWSXMLClient::MakeRequest(const Aws::Http::URI& uri, + Http::HttpMethod method, + const char* signerName, + const char* requestName, + const char* signerRegionOverride, + const char* signerServiceNameOverride) const +{ + HttpResponseOutcome httpOutcome(BASECLASS::AttemptExhaustively(uri, method, signerName, requestName, signerRegionOverride, signerServiceNameOverride)); + if (!httpOutcome.IsSuccess()) + { + return XmlOutcome(std::move(httpOutcome)); + } + + if (httpOutcome.GetResult()->GetResponseBody().tellp() > 0) + { + return XmlOutcome(AmazonWebServiceResult<XmlDocument>( + XmlDocument::CreateFromXmlStream(httpOutcome.GetResult()->GetResponseBody()), + httpOutcome.GetResult()->GetHeaders(), httpOutcome.GetResult()->GetResponseCode())); + } + + return XmlOutcome(AmazonWebServiceResult<XmlDocument>(XmlDocument(), httpOutcome.GetResult()->GetHeaders())); +} + +AWSError<CoreErrors> AWSXMLClient::BuildAWSError(const std::shared_ptr<Http::HttpResponse>& httpResponse) const +{ + AWSError<CoreErrors> error; + if (httpResponse->HasClientError()) + { + bool retryable = httpResponse->GetClientErrorType() == CoreErrors::NETWORK_CONNECTION ? true : false; + error = AWSError<CoreErrors>(httpResponse->GetClientErrorType(), "", httpResponse->GetClientErrorMessage(), retryable); + } + else if (!httpResponse->GetResponseBody() || httpResponse->GetResponseBody().tellp() < 1) + { + auto responseCode = httpResponse->GetResponseCode(); + auto errorCode = GuessBodylessErrorType(responseCode); + + Aws::StringStream ss; + ss << "No response body."; + error = AWSError<CoreErrors>(errorCode, "", ss.str(), IsRetryableHttpResponseCode(responseCode)); + } + else + { + assert(httpResponse->GetResponseCode() != HttpResponseCode::OK); + + // When trying to build an AWS Error from a response which is an FStream, we need to rewind the + // file pointer back to the beginning in order to correctly read the input using the XML string iterator + if ((httpResponse->GetResponseBody().tellp() > 0) + && (httpResponse->GetResponseBody().tellg() > 0)) + { + httpResponse->GetResponseBody().seekg(0); + } + + error = GetErrorMarshaller()->Marshall(*httpResponse); + } + + error.SetResponseHeaders(httpResponse->GetHeaders()); + error.SetResponseCode(httpResponse->GetResponseCode()); + error.SetRemoteHostIpAddress(httpResponse->GetOriginatingRequest().GetResolvedRemoteHost()); + AWS_LOGSTREAM_ERROR(AWS_CLIENT_LOG_TAG, error); + return error; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AWSErrorMarshaller.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AWSErrorMarshaller.cpp new file mode 100644 index 0000000000..f5fa676f98 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AWSErrorMarshaller.cpp @@ -0,0 +1,180 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/client/AWSErrorMarshaller.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/json/JsonSerializer.h> +#include <aws/core/utils/xml/XmlSerializer.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/client/AWSError.h> +#include <aws/core/client/CoreErrors.h> + +using namespace Aws::Utils::Logging; +using namespace Aws::Utils::Json; +using namespace Aws::Utils::Xml; +using namespace Aws::Http; +using namespace Aws::Utils; +using namespace Aws::Client; + +static const char AWS_ERROR_MARSHALLER_LOG_TAG[] = "AWSErrorMarshaller"; +AWS_CORE_API extern const char MESSAGE_LOWER_CASE[] = "message"; +AWS_CORE_API extern const char MESSAGE_CAMEL_CASE[] = "Message"; +AWS_CORE_API extern const char ERROR_TYPE_HEADER[] = "x-amzn-ErrorType"; +AWS_CORE_API extern const char REQUEST_ID_HEADER[] = "x-amzn-RequestId"; +AWS_CORE_API extern const char TYPE[] = "__type"; + +AWSError<CoreErrors> JsonErrorMarshaller::Marshall(const Aws::Http::HttpResponse& httpResponse) const +{ + JsonValue exceptionPayload(httpResponse.GetResponseBody()); + JsonView payloadView(exceptionPayload); + AWSError<CoreErrors> error; + if (exceptionPayload.WasParseSuccessful()) + { + AWS_LOGSTREAM_TRACE(AWS_ERROR_MARSHALLER_LOG_TAG, "Error response is " << payloadView.WriteReadable()); + + Aws::String message(payloadView.ValueExists(MESSAGE_CAMEL_CASE) ? payloadView.GetString(MESSAGE_CAMEL_CASE) : + payloadView.ValueExists(MESSAGE_LOWER_CASE) ? payloadView.GetString(MESSAGE_LOWER_CASE) : ""); + + if (httpResponse.HasHeader(ERROR_TYPE_HEADER)) + { + error = Marshall(httpResponse.GetHeader(ERROR_TYPE_HEADER), message); + } + else if (payloadView.ValueExists(TYPE)) + { + error = Marshall(payloadView.GetString(TYPE), message); + } + else + { + error = FindErrorByHttpResponseCode(httpResponse.GetResponseCode()); + error.SetMessage(message); + } + } + else + { + error = AWSError<CoreErrors>(CoreErrors::UNKNOWN, "", "Failed to parse error payload", false); + } + + error.SetRequestId(httpResponse.HasHeader(REQUEST_ID_HEADER) ? httpResponse.GetHeader(REQUEST_ID_HEADER) : ""); + error.SetJsonPayload(std::move(exceptionPayload)); + return error; +} + +const JsonValue& JsonErrorMarshaller::GetJsonPayloadFromError(const AWSError<CoreErrors>& error) const +{ + return error.GetJsonPayload(); +} + +AWSError<CoreErrors> XmlErrorMarshaller::Marshall(const Aws::Http::HttpResponse& httpResponse) const +{ + XmlDocument doc = XmlDocument::CreateFromXmlStream(httpResponse.GetResponseBody()); + AWS_LOGSTREAM_TRACE(AWS_ERROR_MARSHALLER_LOG_TAG, "Error response is " << doc.ConvertToString()); + bool errorParsed = false; + AWSError<CoreErrors> error; + if (doc.WasParseSuccessful() && !doc.GetRootElement().IsNull()) + { + XmlNode errorNode = doc.GetRootElement(); + + Aws::String requestId(!errorNode.FirstChild("RequestId").IsNull() ? errorNode.FirstChild("RequestId").GetText() : + !errorNode.FirstChild("RequestID").IsNull() ? errorNode.FirstChild("RequestID").GetText() : ""); + + if (errorNode.GetName() != "Error") + { + errorNode = doc.GetRootElement().FirstChild("Error"); + } + if (errorNode.IsNull()) + { + errorNode = doc.GetRootElement().FirstChild("Errors"); + if(!errorNode.IsNull()) + { + errorNode = errorNode.FirstChild("Error"); + } + } + + if (!errorNode.IsNull()) + { + requestId = !requestId.empty() ? requestId : !errorNode.FirstChild("RequestId").IsNull() ? errorNode.FirstChild("RequestId").GetText() : + !errorNode.FirstChild("RequestID").IsNull() ? errorNode.FirstChild("RequestID").GetText() : ""; + + XmlNode codeNode = errorNode.FirstChild("Code"); + XmlNode messageNode = errorNode.FirstChild("Message"); + + if (!codeNode.IsNull()) + { + error = Marshall(StringUtils::Trim(codeNode.GetText().c_str()), + StringUtils::Trim(messageNode.GetText().c_str())); + errorParsed = true; + } + } + + error.SetRequestId(requestId); + } + + if(!errorParsed) + { + // An error occurred attempting to parse the httpResponse as an XML stream, so we're just + // going to dump the XML parsing error and the http response code as a string + AWS_LOGSTREAM_WARN(AWS_ERROR_MARSHALLER_LOG_TAG, "Unable to generate a proper httpResponse from the response " + "stream. Response code: " << static_cast< uint32_t >(httpResponse.GetResponseCode())); + error = FindErrorByHttpResponseCode(httpResponse.GetResponseCode()); + } + + error.SetXmlPayload(std::move(doc)); + return error; +} + +const XmlDocument& XmlErrorMarshaller::GetXmlPayloadFromError(const AWSError<CoreErrors>& error) const +{ + return error.GetXmlPayload(); +} + +AWSError<CoreErrors> AWSErrorMarshaller::Marshall(const Aws::String& exceptionName, const Aws::String& message) const +{ + if(exceptionName.empty()) + { + return AWSError<CoreErrors>(CoreErrors::UNKNOWN, "", message, false); + } + + auto locationOfPound = exceptionName.find_first_of('#'); + auto locationOfColon = exceptionName.find_first_of(':'); + Aws::String formalExceptionName; + + if (locationOfPound != Aws::String::npos) + { + formalExceptionName = exceptionName.substr(locationOfPound + 1); + } + else if (locationOfColon != Aws::String::npos) + { + formalExceptionName = exceptionName.substr(0, locationOfColon); + } + else + { + formalExceptionName = exceptionName; + } + + AWSError<CoreErrors> error = FindErrorByName(formalExceptionName.c_str()); + if (error.GetErrorType() != CoreErrors::UNKNOWN) + { + AWS_LOGSTREAM_WARN(AWS_ERROR_MARSHALLER_LOG_TAG, "Encountered AWSError '" << formalExceptionName.c_str() << + "': " << message.c_str()); + error.SetExceptionName(formalExceptionName); + error.SetMessage(message); + return error; + } + + AWS_LOGSTREAM_WARN(AWS_ERROR_MARSHALLER_LOG_TAG, "Encountered Unknown AWSError '" << exceptionName.c_str() << + "': " << message.c_str()); + + return AWSError<CoreErrors>(CoreErrors::UNKNOWN, exceptionName, "Unable to parse ExceptionName: " + exceptionName + " Message: " + message, false); +} + +AWSError<CoreErrors> AWSErrorMarshaller::FindErrorByName(const char* errorName) const +{ + return CoreErrorsMapper::GetErrorForName(errorName); +} + +AWSError<CoreErrors> AWSErrorMarshaller::FindErrorByHttpResponseCode(Aws::Http::HttpResponseCode code) const +{ + return CoreErrorsMapper::GetErrorForHttpResponseCode(code); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AsyncCallerContext.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AsyncCallerContext.cpp new file mode 100644 index 0000000000..4f9abdc9e4 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/AsyncCallerContext.cpp @@ -0,0 +1,16 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/client/AsyncCallerContext.h> +#include <aws/core/utils/UUID.h> + +namespace Aws +{ + namespace Client + { + AsyncCallerContext::AsyncCallerContext() : m_uuid(Aws::Utils::UUID::RandomUUID()) + {} + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp new file mode 100644 index 0000000000..e517379a77 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp @@ -0,0 +1,160 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/client/ClientConfiguration.h> +#include <aws/core/auth/AWSCredentialsProvider.h> +#include <aws/core/client/DefaultRetryStrategy.h> +#include <aws/core/platform/Environment.h> +#include <aws/core/platform/OSVersionInfo.h> +#include <aws/core/utils/memory/AWSMemory.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/threading/Executor.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <aws/core/Version.h> +#include <aws/core/config/AWSProfileConfigLoader.h> +#include <aws/core/utils/logging/LogMacros.h> + +namespace Aws +{ +namespace Auth +{ + AWS_CORE_API Aws::String GetConfigProfileFilename(); +} +namespace Client +{ + +static const char* CLIENT_CONFIG_TAG = "ClientConfiguration"; + +AWS_CORE_API Aws::String ComputeUserAgentString() +{ + Aws::StringStream ss; + ss << "aws-sdk-cpp/" << Version::GetVersionString() << " " << Aws::OSVersionInfo::ComputeOSVersionString() + << " " << Version::GetCompilerVersionString(); + return ss.str(); +} + +ClientConfiguration::ClientConfiguration() : + scheme(Aws::Http::Scheme::HTTPS), + useDualStack(false), + maxConnections(25), + httpRequestTimeoutMs(0), + requestTimeoutMs(3000), + connectTimeoutMs(1000), + enableTcpKeepAlive(true), + tcpKeepAliveIntervalMs(30000), + lowSpeedLimit(1), + proxyScheme(Aws::Http::Scheme::HTTP), + proxyPort(0), + executor(Aws::MakeShared<Aws::Utils::Threading::DefaultExecutor>(CLIENT_CONFIG_TAG)), + verifySSL(true), + writeRateLimiter(nullptr), + readRateLimiter(nullptr), + httpLibOverride(Aws::Http::TransferLibType::DEFAULT_CLIENT), + followRedirects(FollowRedirectsPolicy::DEFAULT), + disableExpectHeader(false), + enableClockSkewAdjustment(true), + enableHostPrefixInjection(true), + enableEndpointDiscovery(false), + profileName(Aws::Auth::GetConfigProfileName()) +{ + AWS_LOGSTREAM_DEBUG(CLIENT_CONFIG_TAG, "ClientConfiguration will use SDK Auto Resolved profile: [" << profileName << "] if not specified by users."); + + // Initialize Retry Strategy + int maxAttempts; + Aws::String maxAttemptsString = Aws::Environment::GetEnv("AWS_MAX_ATTEMPTS"); + if (maxAttemptsString.empty()) + { + maxAttemptsString = Aws::Config::GetCachedConfigValue("max_attempts"); + } + // In case users specify 0 explicitly to disable retry. + if (maxAttemptsString == "0") + { + maxAttempts = 0; + } + else + { + maxAttempts = static_cast<int>(Aws::Utils::StringUtils::ConvertToInt32(maxAttemptsString.c_str())); + if (maxAttempts == 0) + { + AWS_LOGSTREAM_WARN(CLIENT_CONFIG_TAG, "Retry Strategy will use the default max attempts."); + maxAttempts = -1; + } + } + + Aws::String retryMode = Aws::Environment::GetEnv("AWS_RETRY_MODE"); + if (retryMode.empty()) + { + retryMode = Aws::Config::GetCachedConfigValue("retry_mode"); + } + if (retryMode == "standard") + { + if (maxAttempts < 0) + { + retryStrategy = Aws::MakeShared<StandardRetryStrategy>(CLIENT_CONFIG_TAG); + } + else + { + retryStrategy = Aws::MakeShared<StandardRetryStrategy>(CLIENT_CONFIG_TAG, maxAttempts); + } + } + else + { + retryStrategy = Aws::MakeShared<DefaultRetryStrategy>(CLIENT_CONFIG_TAG); + } + + // Automatically determine the AWS region from environment variables, configuration file and EC2 metadata. + region = Aws::Environment::GetEnv("AWS_DEFAULT_REGION"); + if (!region.empty()) + { + return; + } + + region = Aws::Environment::GetEnv("AWS_REGION"); + if (!region.empty()) + { + return; + } + + region = Aws::Config::GetCachedConfigValue("region"); + if (!region.empty()) + { + return; + } + + if (Aws::Utils::StringUtils::ToLower(Aws::Environment::GetEnv("AWS_EC2_METADATA_DISABLED").c_str()) != "true") + { + auto client = Aws::Internal::GetEC2MetadataClient(); + if (client) + { + region = client->GetCurrentRegion(); + } + } + + if (!region.empty()) + { + return; + } + + region = Aws::String(Aws::Region::US_EAST_1); +} + +ClientConfiguration::ClientConfiguration(const char* profile) : ClientConfiguration() +{ + if (profile && Aws::Config::HasCachedConfigProfile(profile)) + { + this->profileName = Aws::String(profile); + AWS_LOGSTREAM_DEBUG(CLIENT_CONFIG_TAG, "Use user specified profile: [" << this->profileName << "] for ClientConfiguration."); + auto tmpRegion = Aws::Config::GetCachedConfigProfile(this->profileName).GetRegion(); + if (!tmpRegion.empty()) + { + region = tmpRegion; + } + return; + } + AWS_LOGSTREAM_WARN(CLIENT_CONFIG_TAG, "User specified profile: [" << profile << "] is not found, will use the SDK resolved one."); +} + +} // namespace Client +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/CoreErrors.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/CoreErrors.cpp new file mode 100644 index 0000000000..8c2c288dcd --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/CoreErrors.cpp @@ -0,0 +1,151 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/client/AWSError.h> +#include <aws/core/client/CoreErrors.h> +#include <aws/core/utils/memory/stl/AWSMap.h> +#include <aws/core/utils/HashingUtils.h> + +using namespace Aws::Client; +using namespace Aws::Utils; +using namespace Aws::Http; + +#ifdef _MSC_VER +#pragma warning(push) +// VS2015 compiler's bug, warning s_CoreErrorsMapper: symbol will be dynamically initialized (implementation limitation) +#pragma warning(disable : 4592) +#endif + +static Aws::UniquePtr<Aws::Map<Aws::String, AWSError<CoreErrors> > > s_CoreErrorsMapper(nullptr); + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +void CoreErrorsMapper::InitCoreErrorsMapper() +{ + if (s_CoreErrorsMapper) + { + return; + } + s_CoreErrorsMapper = Aws::MakeUnique<Aws::Map<Aws::String, AWSError<CoreErrors> > >("InitCoreErrorsMapper"); + + s_CoreErrorsMapper->emplace("IncompleteSignature", AWSError<CoreErrors>(CoreErrors::INCOMPLETE_SIGNATURE, false)); + s_CoreErrorsMapper->emplace("IncompleteSignatureException", AWSError<CoreErrors>(CoreErrors::INCOMPLETE_SIGNATURE, false)); + s_CoreErrorsMapper->emplace("InvalidSignatureException", AWSError<CoreErrors>(CoreErrors::INVALID_SIGNATURE, false)); + s_CoreErrorsMapper->emplace("InvalidSignature", AWSError<CoreErrors>(CoreErrors::INVALID_SIGNATURE, false)); + s_CoreErrorsMapper->emplace("InternalFailureException", AWSError<CoreErrors>(CoreErrors::INTERNAL_FAILURE, true)); + s_CoreErrorsMapper->emplace("InternalFailure", AWSError<CoreErrors>(CoreErrors::INTERNAL_FAILURE, true)); + s_CoreErrorsMapper->emplace("InternalServerError", AWSError<CoreErrors>(CoreErrors::INTERNAL_FAILURE, true)); + s_CoreErrorsMapper->emplace("InternalError", AWSError<CoreErrors>(CoreErrors::INTERNAL_FAILURE, true)); + s_CoreErrorsMapper->emplace("InvalidActionException", AWSError<CoreErrors>(CoreErrors::INVALID_ACTION, false)); + s_CoreErrorsMapper->emplace("InvalidAction", AWSError<CoreErrors>(CoreErrors::INVALID_ACTION, false)); + s_CoreErrorsMapper->emplace("InvalidClientTokenIdException", AWSError<CoreErrors>(CoreErrors::INVALID_CLIENT_TOKEN_ID, false)); + s_CoreErrorsMapper->emplace("InvalidClientTokenId", AWSError<CoreErrors>(CoreErrors::INVALID_CLIENT_TOKEN_ID, false)); + s_CoreErrorsMapper->emplace("InvalidParameterCombinationException", AWSError<CoreErrors>(CoreErrors::INVALID_PARAMETER_COMBINATION, false)); + s_CoreErrorsMapper->emplace("InvalidParameterCombination", AWSError<CoreErrors>(CoreErrors::INVALID_PARAMETER_COMBINATION, false)); + s_CoreErrorsMapper->emplace("InvalidParameterValueException", AWSError<CoreErrors>(CoreErrors::INVALID_PARAMETER_VALUE, false)); + s_CoreErrorsMapper->emplace("InvalidParameterValue", AWSError<CoreErrors>(CoreErrors::INVALID_PARAMETER_VALUE, false)); + s_CoreErrorsMapper->emplace("InvalidQueryParameterException", AWSError<CoreErrors>(CoreErrors::INVALID_QUERY_PARAMETER, false)); + s_CoreErrorsMapper->emplace("InvalidQueryParameter", AWSError<CoreErrors>(CoreErrors::INVALID_QUERY_PARAMETER, false)); + s_CoreErrorsMapper->emplace("MalformedQueryStringException", AWSError<CoreErrors>(CoreErrors::MALFORMED_QUERY_STRING, false)); + s_CoreErrorsMapper->emplace("MalformedQueryString", AWSError<CoreErrors>(CoreErrors::MALFORMED_QUERY_STRING, false)); + s_CoreErrorsMapper->emplace("MissingActionException", AWSError<CoreErrors>(CoreErrors::MISSING_ACTION, false)); + s_CoreErrorsMapper->emplace("MissingAction", AWSError<CoreErrors>(CoreErrors::MISSING_ACTION, false)); + s_CoreErrorsMapper->emplace("MissingAuthenticationTokenException", AWSError<CoreErrors>(CoreErrors::MISSING_AUTHENTICATION_TOKEN, false)); + s_CoreErrorsMapper->emplace("MissingAuthenticationToken", AWSError<CoreErrors>(CoreErrors::MISSING_AUTHENTICATION_TOKEN, false)); + s_CoreErrorsMapper->emplace("MissingParameterException", AWSError<CoreErrors>(CoreErrors::MISSING_PARAMETER, false)); + s_CoreErrorsMapper->emplace("MissingParameter", AWSError<CoreErrors>(CoreErrors::MISSING_PARAMETER, false)); + s_CoreErrorsMapper->emplace("OptInRequired", AWSError<CoreErrors>(CoreErrors::OPT_IN_REQUIRED, false)); + s_CoreErrorsMapper->emplace("RequestExpiredException", AWSError<CoreErrors>(CoreErrors::REQUEST_EXPIRED, true)); + s_CoreErrorsMapper->emplace("RequestExpired", AWSError<CoreErrors>(CoreErrors::REQUEST_EXPIRED, true)); + s_CoreErrorsMapper->emplace("ServiceUnavailableException", AWSError<CoreErrors>(CoreErrors::SERVICE_UNAVAILABLE, true)); + s_CoreErrorsMapper->emplace("ServiceUnavailableError", AWSError<CoreErrors>(CoreErrors::SERVICE_UNAVAILABLE, true)); + s_CoreErrorsMapper->emplace("ServiceUnavailable", AWSError<CoreErrors>(CoreErrors::SERVICE_UNAVAILABLE, true)); + s_CoreErrorsMapper->emplace("RequestThrottledException", AWSError<CoreErrors>(CoreErrors::THROTTLING, true)); + s_CoreErrorsMapper->emplace("RequestThrottled", AWSError<CoreErrors>(CoreErrors::THROTTLING, true)); + s_CoreErrorsMapper->emplace("ThrottlingException", AWSError<CoreErrors>(CoreErrors::THROTTLING, true)); + s_CoreErrorsMapper->emplace("ThrottledException", AWSError<CoreErrors>(CoreErrors::THROTTLING, true)); + s_CoreErrorsMapper->emplace("Throttling", AWSError<CoreErrors>(CoreErrors::THROTTLING, true)); + s_CoreErrorsMapper->emplace("ValidationErrorException", AWSError<CoreErrors>(CoreErrors::VALIDATION, false)); + s_CoreErrorsMapper->emplace("ValidationException", AWSError<CoreErrors>(CoreErrors::VALIDATION, false)); + s_CoreErrorsMapper->emplace("ValidationError", AWSError<CoreErrors>(CoreErrors::VALIDATION, false)); + s_CoreErrorsMapper->emplace("AccessDeniedException", AWSError<CoreErrors>(CoreErrors::ACCESS_DENIED, false)); + s_CoreErrorsMapper->emplace("AccessDenied", AWSError<CoreErrors>(CoreErrors::ACCESS_DENIED, false)); + s_CoreErrorsMapper->emplace("ResourceNotFoundException", AWSError<CoreErrors>(CoreErrors::RESOURCE_NOT_FOUND, false)); + s_CoreErrorsMapper->emplace("ResourceNotFound", AWSError<CoreErrors>(CoreErrors::RESOURCE_NOT_FOUND, false)); + s_CoreErrorsMapper->emplace("UnrecognizedClientException", AWSError<CoreErrors>(CoreErrors::UNRECOGNIZED_CLIENT, false)); + s_CoreErrorsMapper->emplace("UnrecognizedClient", AWSError<CoreErrors>(CoreErrors::UNRECOGNIZED_CLIENT, false)); + s_CoreErrorsMapper->emplace("SlowDownException", AWSError<CoreErrors>(CoreErrors::SLOW_DOWN, true)); + s_CoreErrorsMapper->emplace("SlowDown", AWSError<CoreErrors>(CoreErrors::SLOW_DOWN, true)); + s_CoreErrorsMapper->emplace("SignatureDoesNotMatchException", AWSError<CoreErrors>(CoreErrors::SIGNATURE_DOES_NOT_MATCH, false)); + s_CoreErrorsMapper->emplace("SignatureDoesNotMatch", AWSError<CoreErrors>(CoreErrors::SIGNATURE_DOES_NOT_MATCH, false)); + s_CoreErrorsMapper->emplace("InvalidAccessKeyIdException", AWSError<CoreErrors>(CoreErrors::INVALID_ACCESS_KEY_ID, false)); + s_CoreErrorsMapper->emplace("InvalidAccessKeyId", AWSError<CoreErrors>(CoreErrors::INVALID_ACCESS_KEY_ID, false)); + s_CoreErrorsMapper->emplace("RequestTimeTooSkewedException", AWSError<CoreErrors>(CoreErrors::REQUEST_TIME_TOO_SKEWED, true)); + s_CoreErrorsMapper->emplace("RequestTimeTooSkewed", AWSError<CoreErrors>(CoreErrors::REQUEST_TIME_TOO_SKEWED, true)); + s_CoreErrorsMapper->emplace("RequestTimeoutException", AWSError<CoreErrors>(CoreErrors::REQUEST_TIMEOUT, true)); + s_CoreErrorsMapper->emplace("RequestTimeout", AWSError<CoreErrors>(CoreErrors::REQUEST_TIMEOUT, true)); +} + +void CoreErrorsMapper::CleanupCoreErrorsMapper() +{ + if (s_CoreErrorsMapper) + { + s_CoreErrorsMapper = nullptr; + } +} + +AWSError<CoreErrors> CoreErrorsMapper::GetErrorForName(const char* errorName) +{ + auto iter = s_CoreErrorsMapper->find(errorName); + if (iter != s_CoreErrorsMapper->end()) + { + return iter->second; + } + return AWSError<CoreErrors>(CoreErrors::UNKNOWN, false); +} + +AWS_CORE_API AWSError<CoreErrors> CoreErrorsMapper::GetErrorForHttpResponseCode(HttpResponseCode code) +{ + // best effort attempt to map HTTP response codes to CoreErrors + bool retryable = IsRetryableHttpResponseCode(code); + AWSError<CoreErrors> error; + switch (code) + { + case HttpResponseCode::UNAUTHORIZED: + case HttpResponseCode::FORBIDDEN: + error = AWSError<CoreErrors>(CoreErrors::ACCESS_DENIED, retryable); + break; + case HttpResponseCode::NOT_FOUND: + error = AWSError<CoreErrors>(CoreErrors::RESOURCE_NOT_FOUND, retryable); + break; + case HttpResponseCode::TOO_MANY_REQUESTS: + error = AWSError<CoreErrors>(CoreErrors::SLOW_DOWN, retryable); + break; + case HttpResponseCode::INTERNAL_SERVER_ERROR: + error = AWSError<CoreErrors>(CoreErrors::INTERNAL_FAILURE, retryable); + break; + case HttpResponseCode::BANDWIDTH_LIMIT_EXCEEDED: + error = AWSError<CoreErrors>(CoreErrors::THROTTLING, retryable); + break; + case HttpResponseCode::SERVICE_UNAVAILABLE: + error = AWSError<CoreErrors>(CoreErrors::SERVICE_UNAVAILABLE, retryable); + break; + case HttpResponseCode::REQUEST_TIMEOUT: + case HttpResponseCode::AUTHENTICATION_TIMEOUT: + case HttpResponseCode::LOGIN_TIMEOUT: + case HttpResponseCode::GATEWAY_TIMEOUT: + case HttpResponseCode::NETWORK_READ_TIMEOUT: + case HttpResponseCode::NETWORK_CONNECT_TIMEOUT: + error = AWSError<CoreErrors>(CoreErrors::REQUEST_TIMEOUT, retryable); + break; + default: + int codeValue = static_cast<int>(code); + error = AWSError<CoreErrors>(CoreErrors::UNKNOWN, codeValue >= 500 && codeValue < 600); + } + error.SetResponseCode(code); + return error; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/DefaultRetryStrategy.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/DefaultRetryStrategy.cpp new file mode 100644 index 0000000000..7e57c79ffc --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/DefaultRetryStrategy.cpp @@ -0,0 +1,32 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/client/DefaultRetryStrategy.h> + +#include <aws/core/client/AWSError.h> +#include <aws/core/utils/UnreferencedParam.h> + +using namespace Aws; +using namespace Aws::Client; + +bool DefaultRetryStrategy::ShouldRetry(const AWSError<CoreErrors>& error, long attemptedRetries) const +{ + if (attemptedRetries >= m_maxRetries) + return false; + + return error.ShouldRetry(); +} + +long DefaultRetryStrategy::CalculateDelayBeforeNextRetry(const AWSError<CoreErrors>& error, long attemptedRetries) const +{ + AWS_UNREFERENCED_PARAM(error); + + if (attemptedRetries == 0) + { + return 0; + } + + return (1 << attemptedRetries) * m_scaleFactor; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/RetryStrategy.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/RetryStrategy.cpp new file mode 100644 index 0000000000..b439b7ca99 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/RetryStrategy.cpp @@ -0,0 +1,102 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/client/RetryStrategy.h> + +#include <aws/core/client/AWSError.h> +#include <aws/core/client/CoreErrors.h> +#include <aws/core/utils/Outcome.h> + +using namespace Aws::Utils::Threading; + +namespace Aws +{ + namespace Client + { + static const int INITIAL_RETRY_TOKENS = 500; + static const int RETRY_COST = 5; + static const int NO_RETRY_INCREMENT = 1; + static const int TIMEOUT_RETRY_COST = 10; + + StandardRetryStrategy::StandardRetryStrategy(long maxAttempts) : + m_retryQuotaContainer(Aws::MakeShared<DefaultRetryQuotaContainer>("StandardRetryStrategy")), + m_maxAttempts(maxAttempts) + {} + + StandardRetryStrategy::StandardRetryStrategy(std::shared_ptr<RetryQuotaContainer> retryQuotaContainer, long maxAttempts) : + m_retryQuotaContainer(retryQuotaContainer), + m_maxAttempts(maxAttempts) + {} + + void StandardRetryStrategy::RequestBookkeeping(const HttpResponseOutcome& httpResponseOutcome) + { + if (httpResponseOutcome.IsSuccess()) + { + m_retryQuotaContainer->ReleaseRetryQuota(NO_RETRY_INCREMENT); + } + } + + void StandardRetryStrategy::RequestBookkeeping(const HttpResponseOutcome& httpResponseOutcome, const AWSError<CoreErrors>& lastError) + { + if (httpResponseOutcome.IsSuccess()) + { + m_retryQuotaContainer->ReleaseRetryQuota(lastError); + } + } + + bool StandardRetryStrategy::ShouldRetry(const AWSError<CoreErrors>& error, long attemptedRetries) const + { + if (!error.ShouldRetry()) + return false; + + if (attemptedRetries + 1 >= m_maxAttempts) + return false; + + return m_retryQuotaContainer->AcquireRetryQuota(error); + } + + long StandardRetryStrategy::CalculateDelayBeforeNextRetry(const AWSError<CoreErrors>& error, long attemptedRetries) const + { + AWS_UNREFERENCED_PARAM(error); + return (std::min)(rand() % 1000 * (1 << attemptedRetries), 20000); + } + + DefaultRetryQuotaContainer::DefaultRetryQuotaContainer() : m_retryQuota(INITIAL_RETRY_TOKENS) + {} + + bool DefaultRetryQuotaContainer::AcquireRetryQuota(int capacityAmount) + { + WriterLockGuard guard(m_retryQuotaLock); + + if (capacityAmount > m_retryQuota) + { + return false; + } + else + { + m_retryQuota -= capacityAmount; + return true; + } + } + + bool DefaultRetryQuotaContainer::AcquireRetryQuota(const AWSError<CoreErrors>& error) + { + int capacityAmount = error.GetErrorType() == CoreErrors::REQUEST_TIMEOUT ? TIMEOUT_RETRY_COST : RETRY_COST; + return AcquireRetryQuota(capacityAmount); + } + + void DefaultRetryQuotaContainer::ReleaseRetryQuota(int capacityAmount) + { + WriterLockGuard guard(m_retryQuotaLock); + m_retryQuota = (std::min)(m_retryQuota + capacityAmount, INITIAL_RETRY_TOKENS); + } + + void DefaultRetryQuotaContainer::ReleaseRetryQuota(const AWSError<CoreErrors>& error) + { + int capacityAmount = error.GetErrorType() == CoreErrors::REQUEST_TIMEOUT ? TIMEOUT_RETRY_COST : RETRY_COST; + ReleaseRetryQuota(capacityAmount); + } + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/SpecifiedRetryableErrorsRetryStrategy.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/SpecifiedRetryableErrorsRetryStrategy.cpp new file mode 100644 index 0000000000..ec4e373304 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/client/SpecifiedRetryableErrorsRetryStrategy.cpp @@ -0,0 +1,28 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h> + +#include <aws/core/client/AWSError.h> + +using namespace Aws; +using namespace Aws::Client; + +bool SpecifiedRetryableErrorsRetryStrategy::ShouldRetry(const AWSError<CoreErrors>& error, long attemptedRetries) const +{ + if (attemptedRetries >= m_maxRetries) + { + return false; + } + for (const auto& err: m_specifiedRetryableErrors) + { + if (error.GetExceptionName() == err) + { + return true; + } + } + + return error.ShouldRetry(); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/config/AWSProfileConfigLoader.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/config/AWSProfileConfigLoader.cpp new file mode 100644 index 0000000000..9ec2e54f55 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/config/AWSProfileConfigLoader.cpp @@ -0,0 +1,540 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/config/AWSProfileConfigLoader.h> +#include <aws/core/internal/AWSHttpResourceClient.h> +#include <aws/core/auth/AWSCredentialsProvider.h> +#include <aws/core/utils/memory/stl/AWSList.h> +#include <aws/core/utils/memory/stl/AWSStreamFwd.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/json/JsonSerializer.h> +#include <fstream> + +namespace Aws +{ + namespace Config + { + using namespace Aws::Utils; + using namespace Aws::Auth; + + static const char* const CONFIG_LOADER_TAG = "Aws::Config::AWSProfileConfigLoader"; + #ifdef _MSC_VER + // VS2015 compiler's bug, warning s_CoreErrorsMapper: symbol will be dynamically initialized (implementation limitation) + AWS_SUPPRESS_WARNING(4592, + static Aws::UniquePtr<ConfigAndCredentialsCacheManager> s_configManager(nullptr); + ) + #else + static Aws::UniquePtr<ConfigAndCredentialsCacheManager> s_configManager(nullptr); + #endif + + static const char CONFIG_CREDENTIALS_CACHE_MANAGER_TAG[] = "ConfigAndCredentialsCacheManager"; + + bool AWSProfileConfigLoader::Load() + { + if(LoadInternal()) + { + AWS_LOGSTREAM_INFO(CONFIG_LOADER_TAG, "Successfully reloaded configuration."); + m_lastLoadTime = DateTime::Now(); + AWS_LOGSTREAM_TRACE(CONFIG_LOADER_TAG, "reloaded config at " + << m_lastLoadTime.ToGmtString(DateFormat::ISO_8601)); + return true; + } + + AWS_LOGSTREAM_INFO(CONFIG_LOADER_TAG, "Failed to reload configuration."); + return false; + } + + bool AWSProfileConfigLoader::PersistProfiles(const Aws::Map<Aws::String, Profile>& profiles) + { + if(PersistInternal(profiles)) + { + AWS_LOGSTREAM_INFO(CONFIG_LOADER_TAG, "Successfully persisted configuration."); + m_profiles = profiles; + m_lastLoadTime = DateTime::Now(); + AWS_LOGSTREAM_TRACE(CONFIG_LOADER_TAG, "persisted config at " + << m_lastLoadTime.ToGmtString(DateFormat::ISO_8601)); + return true; + } + + AWS_LOGSTREAM_WARN(CONFIG_LOADER_TAG, "Failed to persist configuration."); + return false; + } + + static const char REGION_KEY[] = "region"; + static const char ACCESS_KEY_ID_KEY[] = "aws_access_key_id"; + static const char SECRET_KEY_KEY[] = "aws_secret_access_key"; + static const char SESSION_TOKEN_KEY[] = "aws_session_token"; + static const char ROLE_ARN_KEY[] = "role_arn"; + static const char EXTERNAL_ID_KEY[] = "external_id"; + static const char CREDENTIAL_PROCESS_COMMAND[] = "credential_process"; + static const char SOURCE_PROFILE_KEY[] = "source_profile"; + static const char PROFILE_PREFIX[] = "profile "; + static const char EQ = '='; + static const char LEFT_BRACKET = '['; + static const char RIGHT_BRACKET = ']'; + static const char PARSER_TAG[] = "Aws::Config::ConfigFileProfileFSM"; + + class ConfigFileProfileFSM + { + public: + ConfigFileProfileFSM() : m_parserState(START) {} + + const Aws::Map<String, Profile>& GetProfiles() const { return m_foundProfiles; } + + void ParseStream(Aws::IStream& stream) + { + static const size_t ASSUME_EMPTY_LEN = 3; + + Aws::String line; + while(std::getline(stream, line) && m_parserState != FAILURE) + { + if (line.empty() || line.length() < ASSUME_EMPTY_LEN) + { + continue; + } + + auto openPos = line.find(LEFT_BRACKET); + auto closePos = line.find(RIGHT_BRACKET); + + switch(m_parserState) + { + + case START: + if(openPos != std::string::npos && closePos != std::string::npos) + { + FlushProfileAndReset(line, openPos, closePos); + m_parserState = PROFILE_FOUND; + } + break; + + //fallthrough here is intentional to reduce duplicate logic + case PROFILE_KEY_VALUE_FOUND: + if(openPos != std::string::npos && closePos != std::string::npos) + { + m_parserState = PROFILE_FOUND; + FlushProfileAndReset(line, openPos, closePos); + break; + } + // fall through + case PROFILE_FOUND: + { + auto equalsPos = line.find(EQ); + if (equalsPos != std::string::npos) + { + auto key = line.substr(0, equalsPos); + auto value = line.substr(equalsPos + 1); + m_profileKeyValuePairs[StringUtils::Trim(key.c_str())] = + StringUtils::Trim(value.c_str()); + m_parserState = PROFILE_KEY_VALUE_FOUND; + } + + break; + } + default: + m_parserState = FAILURE; + break; + } + } + + FlushProfileAndReset(line, std::string::npos, std::string::npos); + } + + private: + + void FlushProfileAndReset(Aws::String& line, size_t openPos, size_t closePos) + { + if(!m_currentWorkingProfile.empty() && !m_profileKeyValuePairs.empty()) + { + Profile profile; + profile.SetName(m_currentWorkingProfile); + + auto regionIter = m_profileKeyValuePairs.find(REGION_KEY); + if (regionIter != m_profileKeyValuePairs.end()) + { + AWS_LOGSTREAM_DEBUG(PARSER_TAG, "found region " << regionIter->second); + profile.SetRegion(regionIter->second); + } + + auto accessKeyIdIter = m_profileKeyValuePairs.find(ACCESS_KEY_ID_KEY); + Aws::String accessKey, secretKey, sessionToken; + if (accessKeyIdIter != m_profileKeyValuePairs.end()) + { + accessKey = accessKeyIdIter->second; + AWS_LOGSTREAM_DEBUG(PARSER_TAG, "found access key " << accessKey); + + auto secretAccessKeyIter = m_profileKeyValuePairs.find(SECRET_KEY_KEY); + auto sessionTokenIter = m_profileKeyValuePairs.find(SESSION_TOKEN_KEY); + if (secretAccessKeyIter != m_profileKeyValuePairs.end()) + { + secretKey = secretAccessKeyIter->second; + } + else + { + AWS_LOGSTREAM_ERROR(PARSER_TAG, "No secret access key found even though an access key was specified. This will cause all signed AWS calls to fail."); + } + + if (sessionTokenIter != m_profileKeyValuePairs.end()) + { + sessionToken = sessionTokenIter->second; + } + + profile.SetCredentials(Aws::Auth::AWSCredentials(accessKey, secretKey, sessionToken)); + } + + auto assumeRoleArnIter = m_profileKeyValuePairs.find(ROLE_ARN_KEY); + if (assumeRoleArnIter != m_profileKeyValuePairs.end()) + { + AWS_LOGSTREAM_DEBUG(PARSER_TAG, "found role arn " << assumeRoleArnIter->second); + profile.SetRoleArn(assumeRoleArnIter->second); + } + + auto externalIdIter = m_profileKeyValuePairs.find(EXTERNAL_ID_KEY); + if (externalIdIter != m_profileKeyValuePairs.end()) + { + AWS_LOGSTREAM_DEBUG(PARSER_TAG, "found external id " << externalIdIter->second); + profile.SetExternalId(externalIdIter->second); + } + + auto sourceProfileIter = m_profileKeyValuePairs.find(SOURCE_PROFILE_KEY); + if (sourceProfileIter != m_profileKeyValuePairs.end()) + { + AWS_LOGSTREAM_DEBUG(PARSER_TAG, "found source profile " << sourceProfileIter->second); + profile.SetSourceProfile(sourceProfileIter->second); + } + + auto credentialProcessIter = m_profileKeyValuePairs.find(CREDENTIAL_PROCESS_COMMAND); + if (credentialProcessIter != m_profileKeyValuePairs.end()) + { + AWS_LOGSTREAM_DEBUG(PARSER_TAG, "found credential process " << credentialProcessIter->second); + profile.SetCredentialProcess(credentialProcessIter->second); + } + profile.SetAllKeyValPairs(m_profileKeyValuePairs); + + m_foundProfiles[profile.GetName()] = std::move(profile); + m_currentWorkingProfile.clear(); + m_profileKeyValuePairs.clear(); + } + + if(!line.empty() && openPos != std::string::npos && closePos != std::string::npos) + { + m_currentWorkingProfile = StringUtils::Trim(line.substr(openPos + 1, closePos - openPos - 1).c_str()); + StringUtils::Replace(m_currentWorkingProfile, PROFILE_PREFIX, ""); + AWS_LOGSTREAM_DEBUG(PARSER_TAG, "found profile " << m_currentWorkingProfile); + } + } + + enum State + { + START = 0, + PROFILE_FOUND, + PROFILE_KEY_VALUE_FOUND, + FAILURE + }; + + Aws::String m_currentWorkingProfile; + Aws::Map<String, String> m_profileKeyValuePairs; + State m_parserState; + Aws::Map<String, Profile> m_foundProfiles; + }; + + static const char* const CONFIG_FILE_LOADER = "Aws::Config::AWSConfigFileProfileConfigLoader"; + + AWSConfigFileProfileConfigLoader::AWSConfigFileProfileConfigLoader(const Aws::String& fileName, bool useProfilePrefix) : + m_fileName(fileName), m_useProfilePrefix(useProfilePrefix) + { + AWS_LOGSTREAM_INFO(CONFIG_FILE_LOADER, "Initializing config loader against fileName " + << fileName << " and using profilePrefix = " << useProfilePrefix); + } + + bool AWSConfigFileProfileConfigLoader::LoadInternal() + { + m_profiles.clear(); + + Aws::IFStream inputFile(m_fileName.c_str()); + if(inputFile) + { + ConfigFileProfileFSM parser; + parser.ParseStream(inputFile); + m_profiles = parser.GetProfiles(); + return m_profiles.size() > 0; + } + + AWS_LOGSTREAM_INFO(CONFIG_FILE_LOADER, "Unable to open config file " << m_fileName << " for reading."); + + return false; + } + + bool AWSConfigFileProfileConfigLoader::PersistInternal(const Aws::Map<Aws::String, Profile>& profiles) + { + Aws::OFStream outputFile(m_fileName.c_str(), std::ios_base::out | std::ios_base::trunc); + if(outputFile) + { + for(auto& profile : profiles) + { + Aws::String prefix = m_useProfilePrefix ? PROFILE_PREFIX : ""; + + AWS_LOGSTREAM_DEBUG(CONFIG_FILE_LOADER, "Writing profile " << profile.first << " to disk."); + + outputFile << LEFT_BRACKET << prefix << profile.second.GetName() << RIGHT_BRACKET << std::endl; + const Aws::Auth::AWSCredentials& credentials = profile.second.GetCredentials(); + outputFile << ACCESS_KEY_ID_KEY << EQ << credentials.GetAWSAccessKeyId() << std::endl; + outputFile << SECRET_KEY_KEY << EQ << credentials.GetAWSSecretKey() << std::endl; + + if(!credentials.GetSessionToken().empty()) + { + outputFile << SESSION_TOKEN_KEY << EQ << credentials.GetSessionToken() << std::endl; + } + + if(!profile.second.GetRegion().empty()) + { + outputFile << REGION_KEY << EQ << profile.second.GetRegion() << std::endl; + } + + if(!profile.second.GetRoleArn().empty()) + { + outputFile << ROLE_ARN_KEY << EQ << profile.second.GetRoleArn() << std::endl; + } + + if(!profile.second.GetSourceProfile().empty()) + { + outputFile << SOURCE_PROFILE_KEY << EQ << profile.second.GetSourceProfile() << std::endl; + } + + outputFile << std::endl; + } + + AWS_LOGSTREAM_INFO(CONFIG_FILE_LOADER, "Profiles written to config file " << m_fileName); + + return true; + } + + AWS_LOGSTREAM_WARN(CONFIG_FILE_LOADER, "Unable to open config file " << m_fileName << " for writing."); + + return false; + } + + static const char* const EC2_INSTANCE_PROFILE_LOG_TAG = "Aws::Config::EC2InstanceProfileConfigLoader"; + + EC2InstanceProfileConfigLoader::EC2InstanceProfileConfigLoader(const std::shared_ptr<Aws::Internal::EC2MetadataClient>& client) + : m_ec2metadataClient(client == nullptr ? Aws::MakeShared<Aws::Internal::EC2MetadataClient>(EC2_INSTANCE_PROFILE_LOG_TAG) : client) + { + } + + bool EC2InstanceProfileConfigLoader::LoadInternal() + { + auto credentialsStr = m_ec2metadataClient->GetDefaultCredentialsSecurely(); + if(credentialsStr.empty()) return false; + + Json::JsonValue credentialsDoc(credentialsStr); + if (!credentialsDoc.WasParseSuccessful()) + { + AWS_LOGSTREAM_ERROR(EC2_INSTANCE_PROFILE_LOG_TAG, + "Failed to parse output from EC2MetadataService."); + return false; + } + const char* accessKeyId = "AccessKeyId"; + const char* secretAccessKey = "SecretAccessKey"; + Aws::String accessKey, secretKey, token; + + auto credentialsView = credentialsDoc.View(); + accessKey = credentialsView.GetString(accessKeyId); + AWS_LOGSTREAM_INFO(EC2_INSTANCE_PROFILE_LOG_TAG, + "Successfully pulled credentials from metadata service with access key " << accessKey); + + secretKey = credentialsView.GetString(secretAccessKey); + token = credentialsView.GetString("Token"); + + auto region = m_ec2metadataClient->GetCurrentRegion(); + + Profile profile; + profile.SetCredentials(AWSCredentials(accessKey, secretKey, token)); + profile.SetRegion(region); + profile.SetName(INSTANCE_PROFILE_KEY); + + m_profiles[INSTANCE_PROFILE_KEY] = profile; + + return true; + } + + ConfigAndCredentialsCacheManager::ConfigAndCredentialsCacheManager() : + m_credentialsFileLoader(Aws::Auth::ProfileConfigFileAWSCredentialsProvider::GetCredentialsProfileFilename()), + m_configFileLoader(Aws::Auth::GetConfigProfileFilename(), true/*use profile prefix*/) + { + ReloadCredentialsFile(); + ReloadConfigFile(); + } + + void ConfigAndCredentialsCacheManager::ReloadConfigFile() + { + Aws::Utils::Threading::WriterLockGuard guard(m_configLock); + m_configFileLoader.SetFileName(Aws::Auth::GetConfigProfileFilename()); + m_configFileLoader.Load(); + } + + void ConfigAndCredentialsCacheManager::ReloadCredentialsFile() + { + Aws::Utils::Threading::WriterLockGuard guard(m_credentialsLock); + m_credentialsFileLoader.SetFileName(Aws::Auth::ProfileConfigFileAWSCredentialsProvider::GetCredentialsProfileFilename()); + m_credentialsFileLoader.Load(); + } + + bool ConfigAndCredentialsCacheManager::HasConfigProfile(const Aws::String& profileName) const + { + Aws::Utils::Threading::ReaderLockGuard guard(m_configLock); + return (m_configFileLoader.GetProfiles().count(profileName) == 1); + } + + Aws::Config::Profile ConfigAndCredentialsCacheManager::GetConfigProfile(const Aws::String& profileName) const + { + Aws::Utils::Threading::ReaderLockGuard guard(m_configLock); + const auto& profiles = m_configFileLoader.GetProfiles(); + const auto &iter = profiles.find(profileName); + if (iter == profiles.end()) + { + return {}; + } + return iter->second; + } + + Aws::Map<Aws::String, Aws::Config::Profile> ConfigAndCredentialsCacheManager::GetConfigProfiles() const + { + Aws::Utils::Threading::ReaderLockGuard guard(m_configLock); + return m_configFileLoader.GetProfiles(); + } + + Aws::String ConfigAndCredentialsCacheManager::GetConfig(const Aws::String& profileName, const Aws::String& key) const + { + Aws::Utils::Threading::ReaderLockGuard guard(m_configLock); + const auto& profiles = m_configFileLoader.GetProfiles(); + const auto &iter = profiles.find(profileName); + if (iter == profiles.end()) + { + return {}; + } + return iter->second.GetValue(key); + } + + bool ConfigAndCredentialsCacheManager::HasCredentialsProfile(const Aws::String& profileName) const + { + Aws::Utils::Threading::ReaderLockGuard guard(m_credentialsLock); + return (m_credentialsFileLoader.GetProfiles().count(profileName) == 1); + } + + Aws::Config::Profile ConfigAndCredentialsCacheManager::GetCredentialsProfile(const Aws::String& profileName) const + { + Aws::Utils::Threading::ReaderLockGuard guard(m_credentialsLock); + const auto &profiles = m_credentialsFileLoader.GetProfiles(); + const auto &iter = profiles.find(profileName); + if (iter == profiles.end()) + { + return {}; + } + return iter->second; + } + + Aws::Map<Aws::String, Aws::Config::Profile> ConfigAndCredentialsCacheManager::GetCredentialsProfiles() const + { + Aws::Utils::Threading::ReaderLockGuard guard(m_credentialsLock); + return m_credentialsFileLoader.GetProfiles(); + } + + Aws::Auth::AWSCredentials ConfigAndCredentialsCacheManager::GetCredentials(const Aws::String& profileName) const + { + Aws::Utils::Threading::ReaderLockGuard guard(m_credentialsLock); + const auto& profiles = m_credentialsFileLoader.GetProfiles(); + const auto &iter = profiles.find(profileName); + if (iter == profiles.end()) + { + return {}; + } + return iter->second.GetCredentials(); + } + + void InitConfigAndCredentialsCacheManager() + { + if (s_configManager) + { + return; + } + s_configManager = Aws::MakeUnique<ConfigAndCredentialsCacheManager>(CONFIG_CREDENTIALS_CACHE_MANAGER_TAG); + } + + void CleanupConfigAndCredentialsCacheManager() + { + if (!s_configManager) + { + return; + } + s_configManager = nullptr; + } + + void ReloadCachedConfigFile() + { + assert(s_configManager); + s_configManager->ReloadConfigFile(); + } + + void ReloadCachedCredentialsFile() + { + assert(s_configManager); + s_configManager->ReloadCredentialsFile(); + } + + bool HasCachedConfigProfile(const Aws::String& profileName) + { + assert(s_configManager); + return s_configManager->HasConfigProfile(profileName); + } + + Aws::Config::Profile GetCachedConfigProfile(const Aws::String& profileName) + { + assert(s_configManager); + return s_configManager->GetConfigProfile(profileName); + } + + Aws::Map<Aws::String, Aws::Config::Profile> GetCachedConfigProfiles() + { + assert(s_configManager); + return s_configManager->GetConfigProfiles(); + } + + Aws::String GetCachedConfigValue(const Aws::String &profileName, const Aws::String &key) + { + assert(s_configManager); + return s_configManager->GetConfig(profileName, key); + } + + Aws::String GetCachedConfigValue(const Aws::String &key) + { + assert(s_configManager); + return s_configManager->GetConfig(Aws::Auth::GetConfigProfileName(), key); + } + + bool HasCachedCredentialsProfile(const Aws::String& profileName) + { + assert(s_configManager); + return s_configManager->HasCredentialsProfile(profileName); + } + + Aws::Config::Profile GetCachedCredentialsProfile(const Aws::String &profileName) + { + assert(s_configManager); + return s_configManager->GetCredentialsProfile(profileName); + } + + Aws::Map<Aws::String, Aws::Config::Profile> GetCachedCredentialsProfiles() + { + assert(s_configManager); + return s_configManager->GetCredentialsProfiles(); + } + + Aws::Auth::AWSCredentials GetCachedCredentials(const Aws::String &profileName) + { + assert(s_configManager); + return s_configManager->GetCredentials(profileName); + } + } // Config namespace +} // Aws namespace diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/external/cjson/cJSON.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/external/cjson/cJSON.cpp new file mode 100644 index 0000000000..2525976334 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/external/cjson/cJSON.cpp @@ -0,0 +1,2983 @@ +/* + Copyright (c) 2009-2017 Dave Gamble and cJSON contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. +*/ + +/* cJSON */ +/* JSON parser in C. */ + +/* disable warnings about old C89 functions in MSVC */ +#if !defined(_CRT_SECURE_NO_DEPRECATE) && defined(_MSC_VER) +#define _CRT_SECURE_NO_DEPRECATE +#endif + +#ifdef __GNUC__ +#pragma GCC visibility push(default) +#endif +#if defined(_MSC_VER) +#pragma warning (push) +/* disable warning about single line comments in system headers */ +#pragma warning (disable : 4001) +#endif + +#include <string.h> +#include <stdio.h> +#include <math.h> +#include <stdlib.h> +#include <limits.h> +#include <ctype.h> + +#ifdef ENABLE_LOCALES +#include <locale.h> +#endif + +#if defined(_MSC_VER) +#pragma warning (pop) +#endif +#ifdef __GNUC__ +#pragma GCC visibility pop +#endif + +#include <aws/core/external/cjson/cJSON.h> + +/* define our own boolean type */ +// #define true ((cJSON_bool)1) +// #define false ((cJSON_bool)0) + +typedef struct { + const unsigned char *json; + size_t position; +} error; +static error global_error = { NULL, 0 }; + +CJSON_PUBLIC(const char *) cJSON_GetErrorPtr(void) +{ + return (const char*) (global_error.json + global_error.position); +} + +CJSON_PUBLIC(char *) cJSON_GetStringValue(cJSON *item) { + if (!cJSON_IsString(item)) { + return NULL; + } + + return item->valuestring; +} + +/* This is a safeguard to prevent copy-pasters from using incompatible C and header files */ +#if (CJSON_VERSION_MAJOR != 1) || (CJSON_VERSION_MINOR != 7) || (CJSON_VERSION_PATCH != 7) + #error cJSON.h and cJSON.c have different versions. Make sure that both have the same. +#endif + +CJSON_PUBLIC(const char*) cJSON_Version(void) +{ + static char version[15]; + sprintf(version, "%i.%i.%i", CJSON_VERSION_MAJOR, CJSON_VERSION_MINOR, CJSON_VERSION_PATCH); + + return version; +} + +/* Case insensitive string comparison, doesn't consider two NULL pointers equal though */ +static int case_insensitive_strcmp(const unsigned char *string1, const unsigned char *string2) +{ + if ((string1 == NULL) || (string2 == NULL)) + { + return 1; + } + + if (string1 == string2) + { + return 0; + } + + for(; tolower(*string1) == tolower(*string2); (void)string1++, string2++) + { + if (*string1 == '\0') + { + return 0; + } + } + + return tolower(*string1) - tolower(*string2); +} + +typedef struct internal_hooks +{ + void *(*allocate)(size_t size); + void (*deallocate)(void *pointer); + void *(*reallocate)(void *pointer, size_t size); +} internal_hooks; + +#if defined(_MSC_VER) +/* work around MSVC error C2322: '...' address of dillimport '...' is not static */ +static void *internal_malloc(size_t size) +{ + return malloc(size); +} +static void internal_free(void *pointer) +{ + free(pointer); +} +static void *internal_realloc(void *pointer, size_t size) +{ + return realloc(pointer, size); +} +#else +#define internal_malloc malloc +#define internal_free free +#define internal_realloc realloc +#endif + +static internal_hooks global_hooks = { internal_malloc, internal_free, internal_realloc }; + +static unsigned char* cJSON_strdup(const unsigned char* string, const internal_hooks * const hooks) +{ + size_t length = 0; + unsigned char *copy = NULL; + + if (string == NULL) + { + return NULL; + } + + length = strlen((const char*)string) + sizeof(""); + copy = (unsigned char*)hooks->allocate(length); + if (copy == NULL) + { + return NULL; + } + memcpy(copy, string, length); + + return copy; +} + +CJSON_PUBLIC(void) cJSON_InitHooks(cJSON_Hooks* hooks) +{ + if (hooks == NULL) + { + /* Reset hooks */ + global_hooks.allocate = malloc; + global_hooks.deallocate = free; + global_hooks.reallocate = realloc; + return; + } + + global_hooks.allocate = malloc; + if (hooks->malloc_fn != NULL) + { + global_hooks.allocate = hooks->malloc_fn; + } + + global_hooks.deallocate = free; + if (hooks->free_fn != NULL) + { + global_hooks.deallocate = hooks->free_fn; + } + + /* use realloc only if both free and malloc are used */ + global_hooks.reallocate = NULL; + if ((global_hooks.allocate == malloc) && (global_hooks.deallocate == free)) + { + global_hooks.reallocate = realloc; + } +} + +/* Internal constructor. */ +static cJSON *cJSON_New_Item(const internal_hooks * const hooks) +{ + cJSON* node = (cJSON*)hooks->allocate(sizeof(cJSON)); + if (node) + { + memset(node, '\0', sizeof(cJSON)); + } + + return node; +} + +/* Delete a cJSON structure. */ +CJSON_PUBLIC(void) cJSON_Delete(cJSON *item) +{ + cJSON *next = NULL; + while (item != NULL) + { + next = item->next; + if (!(item->type & cJSON_IsReference) && (item->child != NULL)) + { + cJSON_Delete(item->child); + } + if (!(item->type & cJSON_IsReference) && (item->valuestring != NULL)) + { + global_hooks.deallocate(item->valuestring); + } + if (!(item->type & cJSON_StringIsConst) && (item->string != NULL)) + { + global_hooks.deallocate(item->string); + } + global_hooks.deallocate(item); + item = next; + } +} + +/* get the decimal point character of the current locale */ +static unsigned char get_decimal_point(void) +{ +#ifdef ENABLE_LOCALES + struct lconv *lconv = localeconv(); + return (unsigned char) lconv->decimal_point[0]; +#else + return '.'; +#endif +} + +typedef struct +{ + const unsigned char *content; + size_t length; + size_t offset; + size_t depth; /* How deeply nested (in arrays/objects) is the input at the current offset. */ + internal_hooks hooks; +} parse_buffer; + +/* check if the given size is left to read in a given parse buffer (starting with 1) */ +#define can_read(buffer, size) ((buffer != NULL) && (((buffer)->offset + size) <= (buffer)->length)) +/* check if the buffer can be accessed at the given index (starting with 0) */ +#define can_access_at_index(buffer, index) ((buffer != NULL) && (((buffer)->offset + index) < (buffer)->length)) +#define cannot_access_at_index(buffer, index) (!can_access_at_index(buffer, index)) +/* get a pointer to the buffer at the position */ +#define buffer_at_offset(buffer) ((buffer)->content + (buffer)->offset) + +/* Parse the input text to generate a number, and populate the result into item. */ +static cJSON_bool parse_number(cJSON * const item, parse_buffer * const input_buffer) +{ + double number = 0; + unsigned char *after_end = NULL; + unsigned char number_c_string[64]; + unsigned char decimal_point = get_decimal_point(); + bool isInteger = true; + size_t i = 0; + + if ((input_buffer == NULL) || (input_buffer->content == NULL)) + { + return false; + } + + /* copy the number into a temporary buffer and replace '.' with the decimal point + * of the current locale (for strtod) + * This also takes care of '\0' not necessarily being available for marking the end of the input */ + for (i = 0; (i < (sizeof(number_c_string) - 1)) && can_access_at_index(input_buffer, i); i++) + { + switch (buffer_at_offset(input_buffer)[i]) + { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '+': + case '-': + number_c_string[i] = buffer_at_offset(input_buffer)[i]; + break; + case 'e': + case 'E': + number_c_string[i] = buffer_at_offset(input_buffer)[i]; + isInteger = false; + break; + + case '.': + number_c_string[i] = decimal_point; + isInteger = false; + break; + + default: + goto loop_end; + } + } +loop_end: + number_c_string[i] = '\0'; + + number = strtod((const char*)number_c_string, (char**)&after_end); + if (number_c_string == after_end) + { + return false; /* parse_error */ + } + + item->valuedouble = number; + // For integer which is out of the range of [INT_MIN, INT_MAX], it may lose precision if we cast it to double. + // Instead, we keep the integer literal as a string. + if (isInteger && (number > INT_MAX || number < INT_MIN)) + { + item->valuestring = (char*)cJSON_strdup(number_c_string, &global_hooks); + } + + /* use saturation in case of overflow */ + if (number >= INT_MAX) + { + item->valueint = INT_MAX; + } + else if (number <= INT_MIN) + { + item->valueint = INT_MIN; + } + else + { + item->valueint = (int)number; + } + + item->type = cJSON_Number; + + input_buffer->offset += (size_t)(after_end - number_c_string); + return true; +} + +/* don't ask me, but the original cJSON_SetNumberValue returns an integer or double */ +CJSON_PUBLIC(double) cJSON_SetNumberHelper(cJSON *object, double number) +{ + if (number >= INT_MAX) + { + object->valueint = INT_MAX; + } + else if (number <= INT_MIN) + { + object->valueint = INT_MIN; + } + else + { + object->valueint = (int)number; + } + + return object->valuedouble = number; +} + +typedef struct +{ + unsigned char *buffer; + size_t length; + size_t offset; + size_t depth; /* current nesting depth (for formatted printing) */ + cJSON_bool noalloc; + cJSON_bool format; /* is this print a formatted print */ + internal_hooks hooks; +} printbuffer; + +/* realloc printbuffer if necessary to have at least "needed" bytes more */ +static unsigned char* ensure(printbuffer * const p, size_t needed) +{ + unsigned char *newbuffer = NULL; + size_t newsize = 0; + + if ((p == NULL) || (p->buffer == NULL)) + { + return NULL; + } + + if ((p->length > 0) && (p->offset >= p->length)) + { + /* make sure that offset is valid */ + return NULL; + } + + if (needed > INT_MAX) + { + /* sizes bigger than INT_MAX are currently not supported */ + return NULL; + } + + needed += p->offset + 1; + if (needed <= p->length) + { + return p->buffer + p->offset; + } + + if (p->noalloc) { + return NULL; + } + + /* calculate new buffer size */ + if (needed > (INT_MAX / 2)) + { + /* overflow of int, use INT_MAX if possible */ + if (needed <= INT_MAX) + { + newsize = INT_MAX; + } + else + { + return NULL; + } + } + else + { + newsize = needed * 2; + } + + if (p->hooks.reallocate != NULL) + { + /* reallocate with realloc if available */ + newbuffer = (unsigned char*)p->hooks.reallocate(p->buffer, newsize); + if (newbuffer == NULL) + { + p->hooks.deallocate(p->buffer); + p->length = 0; + p->buffer = NULL; + + return NULL; + } + } + else + { + /* otherwise reallocate manually */ + newbuffer = (unsigned char*)p->hooks.allocate(newsize); + if (!newbuffer) + { + p->hooks.deallocate(p->buffer); + p->length = 0; + p->buffer = NULL; + + return NULL; + } + if (newbuffer) + { + memcpy(newbuffer, p->buffer, p->offset + 1); + } + p->hooks.deallocate(p->buffer); + } + p->length = newsize; + p->buffer = newbuffer; + + return newbuffer + p->offset; +} + +/* calculate the new length of the string in a printbuffer and update the offset */ +static void update_offset(printbuffer * const buffer) +{ + const unsigned char *buffer_pointer = NULL; + if ((buffer == NULL) || (buffer->buffer == NULL)) + { + return; + } + buffer_pointer = buffer->buffer + buffer->offset; + + buffer->offset += strlen((const char*)buffer_pointer); +} + +/* Render the number nicely from the given item into a string. */ +static cJSON_bool print_number(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + double d = item->valuedouble; + int length = 0; + size_t i = 0; + unsigned char number_buffer[26]; /* temporary buffer to print the number into */ + unsigned char decimal_point = get_decimal_point(); + double test; + + if (output_buffer == NULL) + { + return false; + } + + /* For integer which is out of the range of [INT_MIN, INT_MAX], valuestring is an integer literal. */ + if (item->valuestring) + { + length = sprintf((char*)number_buffer, "%s", item->valuestring); + } + /* This checks for NaN and Infinity */ + else if ((d * 0) != 0) + { + length = sprintf((char*)number_buffer, "null"); + } + else + { + /* Try 15 decimal places of precision to avoid nonsignificant nonzero digits */ + length = sprintf((char*)number_buffer, "%1.15g", d); + + /* Check whether the original double can be recovered */ + if ((sscanf((char*)number_buffer, "%lg", &test) != 1) || ((double)test != d)) + { + /* If not, print with 17 decimal places of precision */ + length = sprintf((char*)number_buffer, "%1.17g", d); + } + } + + /* sprintf failed or buffer overrun occurred */ + if ((length < 0) || (length > (int)(sizeof(number_buffer) - 1))) + { + return false; + } + + /* reserve appropriate space in the output */ + output_pointer = ensure(output_buffer, (size_t)length + sizeof("")); + if (output_pointer == NULL) + { + return false; + } + + /* copy the printed number to the output and replace locale + * dependent decimal point with '.' */ + for (i = 0; i < ((size_t)length); i++) + { + if (number_buffer[i] == decimal_point) + { + output_pointer[i] = '.'; + continue; + } + + output_pointer[i] = number_buffer[i]; + } + output_pointer[i] = '\0'; + + output_buffer->offset += (size_t)length; + + return true; +} + +/* parse 4 digit hexadecimal number */ +static unsigned parse_hex4(const unsigned char * const input) +{ + unsigned int h = 0; + size_t i = 0; + + for (i = 0; i < 4; i++) + { + /* parse digit */ + if ((input[i] >= '0') && (input[i] <= '9')) + { + h += (unsigned int) input[i] - '0'; + } + else if ((input[i] >= 'A') && (input[i] <= 'F')) + { + h += (unsigned int) 10 + input[i] - 'A'; + } + else if ((input[i] >= 'a') && (input[i] <= 'f')) + { + h += (unsigned int) 10 + input[i] - 'a'; + } + else /* invalid */ + { + return 0; + } + + if (i < 3) + { + /* shift left to make place for the next nibble */ + h = h << 4; + } + } + + return h; +} + +/* converts a UTF-16 literal to UTF-8 + * A literal can be one or two sequences of the form \uXXXX */ +static unsigned char utf16_literal_to_utf8(const unsigned char * const input_pointer, const unsigned char * const input_end, unsigned char **output_pointer) +{ + long unsigned int codepoint = 0; + unsigned int first_code = 0; + const unsigned char *first_sequence = input_pointer; + unsigned char utf8_length = 0; + unsigned char utf8_position = 0; + unsigned char sequence_length = 0; + unsigned char first_byte_mark = 0; + + if ((input_end - first_sequence) < 6) + { + /* input ends unexpectedly */ + goto fail; + } + + /* get the first utf16 sequence */ + first_code = parse_hex4(first_sequence + 2); + + /* check that the code is valid */ + if (((first_code >= 0xDC00) && (first_code <= 0xDFFF))) + { + goto fail; + } + + /* UTF16 surrogate pair */ + if ((first_code >= 0xD800) && (first_code <= 0xDBFF)) + { + const unsigned char *second_sequence = first_sequence + 6; + unsigned int second_code = 0; + sequence_length = 12; /* \uXXXX\uXXXX */ + + if ((input_end - second_sequence) < 6) + { + /* input ends unexpectedly */ + goto fail; + } + + if ((second_sequence[0] != '\\') || (second_sequence[1] != 'u')) + { + /* missing second half of the surrogate pair */ + goto fail; + } + + /* get the second utf16 sequence */ + second_code = parse_hex4(second_sequence + 2); + /* check that the code is valid */ + if ((second_code < 0xDC00) || (second_code > 0xDFFF)) + { + /* invalid second half of the surrogate pair */ + goto fail; + } + + + /* calculate the unicode codepoint from the surrogate pair */ + codepoint = 0x10000 + (((first_code & 0x3FF) << 10) | (second_code & 0x3FF)); + } + else + { + sequence_length = 6; /* \uXXXX */ + codepoint = first_code; + } + + /* encode as UTF-8 + * takes at maximum 4 bytes to encode: + * 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx */ + if (codepoint < 0x80) + { + /* normal ascii, encoding 0xxxxxxx */ + utf8_length = 1; + } + else if (codepoint < 0x800) + { + /* two bytes, encoding 110xxxxx 10xxxxxx */ + utf8_length = 2; + first_byte_mark = 0xC0; /* 11000000 */ + } + else if (codepoint < 0x10000) + { + /* three bytes, encoding 1110xxxx 10xxxxxx 10xxxxxx */ + utf8_length = 3; + first_byte_mark = 0xE0; /* 11100000 */ + } + else if (codepoint <= 0x10FFFF) + { + /* four bytes, encoding 1110xxxx 10xxxxxx 10xxxxxx 10xxxxxx */ + utf8_length = 4; + first_byte_mark = 0xF0; /* 11110000 */ + } + else + { + /* invalid unicode codepoint */ + goto fail; + } + + /* encode as utf8 */ + for (utf8_position = (unsigned char)(utf8_length - 1); utf8_position > 0; utf8_position--) + { + /* 10xxxxxx */ + (*output_pointer)[utf8_position] = (unsigned char)((codepoint | 0x80) & 0xBF); + codepoint >>= 6; + } + /* encode first byte */ + if (utf8_length > 1) + { + (*output_pointer)[0] = (unsigned char)((codepoint | first_byte_mark) & 0xFF); + } + else + { + (*output_pointer)[0] = (unsigned char)(codepoint & 0x7F); + } + + *output_pointer += utf8_length; + + return sequence_length; + +fail: + return 0; +} + +/* Parse the input text into an unescaped cinput, and populate item. */ +static cJSON_bool parse_string(cJSON * const item, parse_buffer * const input_buffer) +{ + const unsigned char *input_pointer = buffer_at_offset(input_buffer) + 1; + const unsigned char *input_end = buffer_at_offset(input_buffer) + 1; + unsigned char *output_pointer = NULL; + unsigned char *output = NULL; + + /* not a string */ + if (buffer_at_offset(input_buffer)[0] != '\"') + { + goto fail; + } + + { + /* calculate approximate size of the output (overestimate) */ + size_t allocation_length = 0; + size_t skipped_bytes = 0; + while (((size_t)(input_end - input_buffer->content) < input_buffer->length) && (*input_end != '\"')) + { + /* is escape sequence */ + if (input_end[0] == '\\') + { + if ((size_t)(input_end + 1 - input_buffer->content) >= input_buffer->length) + { + /* prevent buffer overflow when last input character is a backslash */ + goto fail; + } + skipped_bytes++; + input_end++; + } + input_end++; + } + if (((size_t)(input_end - input_buffer->content) >= input_buffer->length) || (*input_end != '\"')) + { + goto fail; /* string ended unexpectedly */ + } + + /* This is at most how much we need for the output */ + allocation_length = (size_t) (input_end - buffer_at_offset(input_buffer)) - skipped_bytes; + output = (unsigned char*)input_buffer->hooks.allocate(allocation_length + sizeof("")); + if (output == NULL) + { + goto fail; /* allocation failure */ + } + } + + output_pointer = output; + /* loop through the string literal */ + while (input_pointer < input_end) + { + if (*input_pointer != '\\') + { + *output_pointer++ = *input_pointer++; + } + /* escape sequence */ + else + { + unsigned char sequence_length = 2; + if ((input_end - input_pointer) < 1) + { + goto fail; + } + + switch (input_pointer[1]) + { + case 'b': + *output_pointer++ = '\b'; + break; + case 'f': + *output_pointer++ = '\f'; + break; + case 'n': + *output_pointer++ = '\n'; + break; + case 'r': + *output_pointer++ = '\r'; + break; + case 't': + *output_pointer++ = '\t'; + break; + case '\"': + case '\\': + case '/': + *output_pointer++ = input_pointer[1]; + break; + + /* UTF-16 literal */ + case 'u': + sequence_length = utf16_literal_to_utf8(input_pointer, input_end, &output_pointer); + if (sequence_length == 0) + { + /* failed to convert UTF16-literal to UTF-8 */ + goto fail; + } + break; + + default: + goto fail; + } + input_pointer += sequence_length; + } + } + + /* zero terminate the output */ + *output_pointer = '\0'; + + item->type = cJSON_String; + item->valuestring = (char*)output; + + input_buffer->offset = (size_t) (input_end - input_buffer->content); + input_buffer->offset++; + + return true; + +fail: + if (output != NULL) + { + input_buffer->hooks.deallocate(output); + } + + if (input_pointer != NULL) + { + input_buffer->offset = (size_t)(input_pointer - input_buffer->content); + } + + return false; +} + +/* Render the cstring provided to an escaped version that can be printed. */ +static cJSON_bool print_string_ptr(const unsigned char * const input, printbuffer * const output_buffer) +{ + const unsigned char *input_pointer = NULL; + unsigned char *output = NULL; + unsigned char *output_pointer = NULL; + size_t output_length = 0; + /* numbers of additional characters needed for escaping */ + size_t escape_characters = 0; + + if (output_buffer == NULL) + { + return false; + } + + /* empty string */ + if (input == NULL) + { + output = ensure(output_buffer, sizeof("\"\"")); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "\"\""); + + return true; + } + + /* set "flag" to 1 if something needs to be escaped */ + for (input_pointer = input; *input_pointer; input_pointer++) + { + switch (*input_pointer) + { + case '\"': + case '\\': + case '\b': + case '\f': + case '\n': + case '\r': + case '\t': + /* one character escape sequence */ + escape_characters++; + break; + default: + if (*input_pointer < 32) + { + /* UTF-16 escape sequence uXXXX */ + escape_characters += 5; + } + break; + } + } + output_length = (size_t)(input_pointer - input) + escape_characters; + + output = ensure(output_buffer, output_length + sizeof("\"\"")); + if (output == NULL) + { + return false; + } + + /* no characters have to be escaped */ + if (escape_characters == 0) + { + output[0] = '\"'; + memcpy(output + 1, input, output_length); + output[output_length + 1] = '\"'; + output[output_length + 2] = '\0'; + + return true; + } + + output[0] = '\"'; + output_pointer = output + 1; + /* copy the string */ + for (input_pointer = input; *input_pointer != '\0'; (void)input_pointer++, output_pointer++) + { + if ((*input_pointer > 31) && (*input_pointer != '\"') && (*input_pointer != '\\')) + { + /* normal character, copy */ + *output_pointer = *input_pointer; + } + else + { + /* character needs to be escaped */ + *output_pointer++ = '\\'; + switch (*input_pointer) + { + case '\\': + *output_pointer = '\\'; + break; + case '\"': + *output_pointer = '\"'; + break; + case '\b': + *output_pointer = 'b'; + break; + case '\f': + *output_pointer = 'f'; + break; + case '\n': + *output_pointer = 'n'; + break; + case '\r': + *output_pointer = 'r'; + break; + case '\t': + *output_pointer = 't'; + break; + default: + /* escape and print as unicode codepoint */ + sprintf((char*)output_pointer, "u%04x", *input_pointer); + output_pointer += 4; + break; + } + } + } + output[output_length + 1] = '\"'; + output[output_length + 2] = '\0'; + + return true; +} + +/* Invoke print_string_ptr (which is useful) on an item. */ +static cJSON_bool print_string(const cJSON * const item, printbuffer * const p) +{ + return print_string_ptr((unsigned char*)item->valuestring, p); +} + +/* Predeclare these prototypes. */ +static cJSON_bool parse_value(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_value(const cJSON * const item, printbuffer * const output_buffer); +static cJSON_bool parse_array(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_array(const cJSON * const item, printbuffer * const output_buffer); +static cJSON_bool parse_object(cJSON * const item, parse_buffer * const input_buffer); +static cJSON_bool print_object(const cJSON * const item, printbuffer * const output_buffer); + +/* Utility to jump whitespace and cr/lf */ +static parse_buffer *buffer_skip_whitespace(parse_buffer * const buffer) +{ + if ((buffer == NULL) || (buffer->content == NULL)) + { + return NULL; + } + + while (can_access_at_index(buffer, 0) && (buffer_at_offset(buffer)[0] <= 32)) + { + buffer->offset++; + } + + if (buffer->offset == buffer->length) + { + buffer->offset--; + } + + return buffer; +} + +/* skip the UTF-8 BOM (byte order mark) if it is at the beginning of a buffer */ +static parse_buffer *skip_utf8_bom(parse_buffer * const buffer) +{ + if ((buffer == NULL) || (buffer->content == NULL) || (buffer->offset != 0)) + { + return NULL; + } + + if (can_access_at_index(buffer, 4) && (strncmp((const char*)buffer_at_offset(buffer), "\xEF\xBB\xBF", 3) == 0)) + { + buffer->offset += 3; + } + + return buffer; +} + +/* Parse an object - create a new root, and populate. */ +CJSON_PUBLIC(cJSON *) cJSON_ParseWithOpts(const char *value, const char **return_parse_end, cJSON_bool require_null_terminated) +{ + parse_buffer buffer = { 0, 0, 0, 0, { 0, 0, 0 } }; + cJSON *item = NULL; + + /* reset error position */ + global_error.json = NULL; + global_error.position = 0; + + if (value == NULL) + { + goto fail; + } + + buffer.content = (const unsigned char*)value; + buffer.length = strlen((const char*)value) + sizeof(""); + buffer.offset = 0; + buffer.hooks = global_hooks; + + item = cJSON_New_Item(&global_hooks); + if (item == NULL) /* memory fail */ + { + goto fail; + } + + if (!parse_value(item, buffer_skip_whitespace(skip_utf8_bom(&buffer)))) + { + /* parse failure. ep is set. */ + goto fail; + } + + /* if we require null-terminated JSON without appended garbage, skip and then check for a null terminator */ + if (require_null_terminated) + { + buffer_skip_whitespace(&buffer); + if ((buffer.offset >= buffer.length) || buffer_at_offset(&buffer)[0] != '\0') + { + goto fail; + } + } + if (return_parse_end) + { + *return_parse_end = (const char*)buffer_at_offset(&buffer); + } + + return item; + +fail: + if (item != NULL) + { + cJSON_Delete(item); + } + + if (value != NULL) + { + error local_error; + local_error.json = (const unsigned char*)value; + local_error.position = 0; + + if (buffer.offset < buffer.length) + { + local_error.position = buffer.offset; + } + else if (buffer.length > 0) + { + local_error.position = buffer.length - 1; + } + + if (return_parse_end != NULL) + { + *return_parse_end = (const char*)local_error.json + local_error.position; + } + + global_error = local_error; + } + + return NULL; +} + +/* Default options for cJSON_Parse */ +CJSON_PUBLIC(cJSON *) cJSON_Parse(const char *value) +{ + return cJSON_ParseWithOpts(value, 0, 0); +} + +#define cjson_min(a, b) ((a < b) ? a : b) + +static unsigned char *print(const cJSON * const item, cJSON_bool format, const internal_hooks * const hooks) +{ + static const size_t default_buffer_size = 256; + printbuffer buffer[1]; + unsigned char *printed = NULL; + + memset(buffer, 0, sizeof(buffer)); + + /* create buffer */ + buffer->buffer = (unsigned char*) hooks->allocate(default_buffer_size); + buffer->length = default_buffer_size; + buffer->format = format; + buffer->hooks = *hooks; + if (buffer->buffer == NULL) + { + goto fail; + } + + /* print the value */ + if (!print_value(item, buffer)) + { + goto fail; + } + update_offset(buffer); + + /* check if reallocate is available */ + if (hooks->reallocate != NULL) + { + printed = (unsigned char*) hooks->reallocate(buffer->buffer, buffer->offset + 1); + if (printed == NULL) { + goto fail; + } + buffer->buffer = NULL; + } + else /* otherwise copy the JSON over to a new buffer */ + { + printed = (unsigned char*) hooks->allocate(buffer->offset + 1); + if (printed == NULL) + { + goto fail; + } + memcpy(printed, buffer->buffer, cjson_min(buffer->length, buffer->offset + 1)); + printed[buffer->offset] = '\0'; /* just to be sure */ + + /* free the buffer */ + hooks->deallocate(buffer->buffer); + } + + return printed; + +fail: + if (buffer->buffer != NULL) + { + hooks->deallocate(buffer->buffer); + } + + if (printed != NULL) + { + hooks->deallocate(printed); + } + + return NULL; +} + +/* Render a cJSON item/entity/structure to text. */ +CJSON_PUBLIC(char *) cJSON_Print(const cJSON *item) +{ + return (char*)print(item, true, &global_hooks); +} + +CJSON_PUBLIC(char *) cJSON_PrintUnformatted(const cJSON *item) +{ + return (char*)print(item, false, &global_hooks); +} + +CJSON_PUBLIC(char *) cJSON_PrintBuffered(const cJSON *item, int prebuffer, cJSON_bool fmt) +{ + printbuffer p = { 0, 0, 0, 0, 0, 0, { 0, 0, 0 } }; + + if (prebuffer < 0) + { + return NULL; + } + + p.buffer = (unsigned char*)global_hooks.allocate((size_t)prebuffer); + if (!p.buffer) + { + return NULL; + } + + p.length = (size_t)prebuffer; + p.offset = 0; + p.noalloc = false; + p.format = fmt; + p.hooks = global_hooks; + + if (!print_value(item, &p)) + { + global_hooks.deallocate(p.buffer); + return NULL; + } + + return (char*)p.buffer; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_PrintPreallocated(cJSON *item, char *buf, const int len, const cJSON_bool fmt) +{ + printbuffer p = { 0, 0, 0, 0, 0, 0, { 0, 0, 0 } }; + + if ((len < 0) || (buf == NULL)) + { + return false; + } + + p.buffer = (unsigned char*)buf; + p.length = (size_t)len; + p.offset = 0; + p.noalloc = true; + p.format = fmt; + p.hooks = global_hooks; + + return print_value(item, &p); +} + +/* Parser core - when encountering text, process appropriately. */ +static cJSON_bool parse_value(cJSON * const item, parse_buffer * const input_buffer) +{ + if ((input_buffer == NULL) || (input_buffer->content == NULL)) + { + return false; /* no input */ + } + + /* parse the different types of values */ + /* null */ + if (can_read(input_buffer, 4) && (strncmp((const char*)buffer_at_offset(input_buffer), "null", 4) == 0)) + { + item->type = cJSON_NULL; + input_buffer->offset += 4; + return true; + } + /* false */ + if (can_read(input_buffer, 5) && (strncmp((const char*)buffer_at_offset(input_buffer), "false", 5) == 0)) + { + item->type = cJSON_False; + input_buffer->offset += 5; + return true; + } + /* true */ + if (can_read(input_buffer, 4) && (strncmp((const char*)buffer_at_offset(input_buffer), "true", 4) == 0)) + { + item->type = cJSON_True; + item->valueint = 1; + input_buffer->offset += 4; + return true; + } + /* string */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '\"')) + { + return parse_string(item, input_buffer); + } + /* number */ + if (can_access_at_index(input_buffer, 0) && ((buffer_at_offset(input_buffer)[0] == '-') || ((buffer_at_offset(input_buffer)[0] >= '0') && (buffer_at_offset(input_buffer)[0] <= '9')))) + { + return parse_number(item, input_buffer); + } + /* array */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '[')) + { + return parse_array(item, input_buffer); + } + /* object */ + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '{')) + { + return parse_object(item, input_buffer); + } + + return false; +} + +/* Render a value to text. */ +static cJSON_bool print_value(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output = NULL; + + if ((item == NULL) || (output_buffer == NULL)) + { + return false; + } + + switch ((item->type) & 0xFF) + { + case cJSON_NULL: + output = ensure(output_buffer, 5); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "null"); + return true; + + case cJSON_False: + output = ensure(output_buffer, 6); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "false"); + return true; + + case cJSON_True: + output = ensure(output_buffer, 5); + if (output == NULL) + { + return false; + } + strcpy((char*)output, "true"); + return true; + + case cJSON_Number: + return print_number(item, output_buffer); + + case cJSON_Raw: + { + size_t raw_length = 0; + if (item->valuestring == NULL) + { + return false; + } + + raw_length = strlen(item->valuestring) + sizeof(""); + output = ensure(output_buffer, raw_length); + if (output == NULL) + { + return false; + } + memcpy(output, item->valuestring, raw_length); + return true; + } + + case cJSON_String: + return print_string(item, output_buffer); + + case cJSON_Array: + return print_array(item, output_buffer); + + case cJSON_Object: + return print_object(item, output_buffer); + + default: + return false; + } +} + +/* Build an array from input text. */ +static cJSON_bool parse_array(cJSON * const item, parse_buffer * const input_buffer) +{ + cJSON *head = NULL; /* head of the linked list */ + cJSON *current_item = NULL; + + if (input_buffer->depth >= CJSON_NESTING_LIMIT) + { + return false; /* to deeply nested */ + } + input_buffer->depth++; + + if (buffer_at_offset(input_buffer)[0] != '[') + { + /* not an array */ + goto fail; + } + + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ']')) + { + /* empty array */ + goto success; + } + + /* check if we skipped to the end of the buffer */ + if (cannot_access_at_index(input_buffer, 0)) + { + input_buffer->offset--; + goto fail; + } + + /* step back to character in front of the first element */ + input_buffer->offset--; + /* loop through the comma separated array elements */ + do + { + /* allocate next item */ + cJSON *new_item = cJSON_New_Item(&(input_buffer->hooks)); + if (new_item == NULL) + { + goto fail; /* allocation failure */ + } + + /* attach next item to list */ + if (head == NULL) + { + /* start the linked list */ + current_item = head = new_item; + } + else + { + /* add to the end and advance */ + current_item->next = new_item; + new_item->prev = current_item; + current_item = new_item; + } + + /* parse next value */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_value(current_item, input_buffer)) + { + goto fail; /* failed to parse value */ + } + buffer_skip_whitespace(input_buffer); + } + while (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ',')); + + if (cannot_access_at_index(input_buffer, 0) || buffer_at_offset(input_buffer)[0] != ']') + { + goto fail; /* expected end of array */ + } + +success: + input_buffer->depth--; + + item->type = cJSON_Array; + item->child = head; + + input_buffer->offset++; + + return true; + +fail: + if (head != NULL) + { + cJSON_Delete(head); + } + + return false; +} + +/* Render an array to text */ +static cJSON_bool print_array(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + size_t length = 0; + cJSON *current_element = item->child; + + if (output_buffer == NULL) + { + return false; + } + + /* Compose the output array. */ + /* opening square bracket */ + output_pointer = ensure(output_buffer, 1); + if (output_pointer == NULL) + { + return false; + } + + *output_pointer = '['; + output_buffer->offset++; + output_buffer->depth++; + + while (current_element != NULL) + { + if (!print_value(current_element, output_buffer)) + { + return false; + } + update_offset(output_buffer); + if (current_element->next) + { + length = (size_t) (output_buffer->format ? 2 : 1); + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ','; + if(output_buffer->format) + { + *output_pointer++ = ' '; + } + *output_pointer = '\0'; + output_buffer->offset += length; + } + current_element = current_element->next; + } + + output_pointer = ensure(output_buffer, 2); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ']'; + *output_pointer = '\0'; + output_buffer->depth--; + + return true; +} + +/* Build an object from the text. */ +static cJSON_bool parse_object(cJSON * const item, parse_buffer * const input_buffer) +{ + cJSON *head = NULL; /* linked list head */ + cJSON *current_item = NULL; + + if (input_buffer->depth >= CJSON_NESTING_LIMIT) + { + return false; /* to deeply nested */ + } + input_buffer->depth++; + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != '{')) + { + goto fail; /* not an object */ + } + + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == '}')) + { + goto success; /* empty object */ + } + + /* check if we skipped to the end of the buffer */ + if (cannot_access_at_index(input_buffer, 0)) + { + input_buffer->offset--; + goto fail; + } + + /* step back to character in front of the first element */ + input_buffer->offset--; + /* loop through the comma separated array elements */ + do + { + /* allocate next item */ + cJSON *new_item = cJSON_New_Item(&(input_buffer->hooks)); + if (new_item == NULL) + { + goto fail; /* allocation failure */ + } + + /* attach next item to list */ + if (head == NULL) + { + /* start the linked list */ + current_item = head = new_item; + } + else + { + /* add to the end and advance */ + current_item->next = new_item; + new_item->prev = current_item; + current_item = new_item; + } + + /* parse the name of the child */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_string(current_item, input_buffer)) + { + goto fail; /* failed to parse name */ + } + buffer_skip_whitespace(input_buffer); + + /* swap valuestring and string, because we parsed the name */ + current_item->string = current_item->valuestring; + current_item->valuestring = NULL; + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != ':')) + { + goto fail; /* invalid object */ + } + + /* parse the value */ + input_buffer->offset++; + buffer_skip_whitespace(input_buffer); + if (!parse_value(current_item, input_buffer)) + { + goto fail; /* failed to parse value */ + } + buffer_skip_whitespace(input_buffer); + } + while (can_access_at_index(input_buffer, 0) && (buffer_at_offset(input_buffer)[0] == ',')); + + if (cannot_access_at_index(input_buffer, 0) || (buffer_at_offset(input_buffer)[0] != '}')) + { + goto fail; /* expected end of object */ + } + +success: + input_buffer->depth--; + + item->type = cJSON_Object; + item->child = head; + + input_buffer->offset++; + return true; + +fail: + if (head != NULL) + { + cJSON_Delete(head); + } + + return false; +} + +/* Render an object to text. */ +static cJSON_bool print_object(const cJSON * const item, printbuffer * const output_buffer) +{ + unsigned char *output_pointer = NULL; + size_t length = 0; + cJSON *current_item = item->child; + + if (output_buffer == NULL) + { + return false; + } + + /* Compose the output: */ + length = (size_t) (output_buffer->format ? 2 : 1); /* fmt: {\n */ + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + + *output_pointer++ = '{'; + output_buffer->depth++; + if (output_buffer->format) + { + *output_pointer++ = '\n'; + } + output_buffer->offset += length; + + while (current_item) + { + if (output_buffer->format) + { + size_t i; + output_pointer = ensure(output_buffer, output_buffer->depth); + if (output_pointer == NULL) + { + return false; + } + for (i = 0; i < output_buffer->depth; i++) + { + *output_pointer++ = '\t'; + } + output_buffer->offset += output_buffer->depth; + } + + /* print key */ + if (!print_string_ptr((unsigned char*)current_item->string, output_buffer)) + { + return false; + } + update_offset(output_buffer); + + length = (size_t) (output_buffer->format ? 2 : 1); + output_pointer = ensure(output_buffer, length); + if (output_pointer == NULL) + { + return false; + } + *output_pointer++ = ':'; + if (output_buffer->format) + { + *output_pointer++ = '\t'; + } + output_buffer->offset += length; + + /* print value */ + if (!print_value(current_item, output_buffer)) + { + return false; + } + update_offset(output_buffer); + + /* print comma if not last */ + length = (size_t) ((output_buffer->format ? 1 : 0) + (current_item->next ? 1 : 0)); + output_pointer = ensure(output_buffer, length + 1); + if (output_pointer == NULL) + { + return false; + } + if (current_item->next) + { + *output_pointer++ = ','; + } + + if (output_buffer->format) + { + *output_pointer++ = '\n'; + } + *output_pointer = '\0'; + output_buffer->offset += length; + + current_item = current_item->next; + } + + output_pointer = ensure(output_buffer, output_buffer->format ? (output_buffer->depth + 1) : 2); + if (output_pointer == NULL) + { + return false; + } + if (output_buffer->format) + { + size_t i; + for (i = 0; i < (output_buffer->depth - 1); i++) + { + *output_pointer++ = '\t'; + } + } + *output_pointer++ = '}'; + *output_pointer = '\0'; + output_buffer->depth--; + + return true; +} + +/* Get Array size/item / object item. */ +CJSON_PUBLIC(int) cJSON_GetArraySize(const cJSON *array) +{ + cJSON *child = NULL; + size_t size = 0; + + if (array == NULL) + { + return 0; + } + + child = array->child; + + while(child != NULL) + { + size++; + child = child->next; + } + + /* FIXME: Can overflow here. Cannot be fixed without breaking the API */ + + return (int)size; +} + +static cJSON* get_array_item(const cJSON *array, size_t index) +{ + cJSON *current_child = NULL; + + if (array == NULL) + { + return NULL; + } + + current_child = array->child; + while ((current_child != NULL) && (index > 0)) + { + index--; + current_child = current_child->next; + } + + return current_child; +} + +CJSON_PUBLIC(cJSON *) cJSON_GetArrayItem(const cJSON *array, int index) +{ + if (index < 0) + { + return NULL; + } + + return get_array_item(array, (size_t)index); +} + +static cJSON *get_object_item(const cJSON * const object, const char * const name, const cJSON_bool case_sensitive) +{ + cJSON *current_element = NULL; + + if ((object == NULL) || (name == NULL)) + { + return NULL; + } + + current_element = object->child; + if (case_sensitive) + { + while ((current_element != NULL) && (strcmp(name, current_element->string) != 0)) + { + current_element = current_element->next; + } + } + else + { + while ((current_element != NULL) && (case_insensitive_strcmp((const unsigned char*)name, (const unsigned char*)(current_element->string)) != 0)) + { + current_element = current_element->next; + } + } + + return current_element; +} + +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItem(const cJSON * const object, const char * const string) +{ + return get_object_item(object, string, false); +} + +CJSON_PUBLIC(cJSON *) cJSON_GetObjectItemCaseSensitive(const cJSON * const object, const char * const string) +{ + return get_object_item(object, string, true); +} + +CJSON_PUBLIC(cJSON_bool) cJSON_HasObjectItem(const cJSON *object, const char *string) +{ + return cJSON_GetObjectItem(object, string) ? 1 : 0; +} + +/* Utility for array list handling. */ +static void suffix_object(cJSON *prev, cJSON *item) +{ + prev->next = item; + item->prev = prev; +} + +/* Utility for handling references. */ +static cJSON *create_reference(const cJSON *item, const internal_hooks * const hooks) +{ + cJSON *reference = NULL; + if (item == NULL) + { + return NULL; + } + + reference = cJSON_New_Item(hooks); + if (reference == NULL) + { + return NULL; + } + + memcpy(reference, item, sizeof(cJSON)); + reference->string = NULL; + reference->type |= cJSON_IsReference; + reference->next = reference->prev = NULL; + return reference; +} + +static cJSON_bool add_item_to_array(cJSON *array, cJSON *item) +{ + cJSON *child = NULL; + + if ((item == NULL) || (array == NULL)) + { + return false; + } + + child = array->child; + + if (child == NULL) + { + /* list is empty, start new one */ + array->child = item; + } + else + { + /* append to the end */ + while (child->next) + { + child = child->next; + } + suffix_object(child, item); + } + + return true; +} + +/* Add item to array/object. */ +CJSON_PUBLIC(void) cJSON_AddItemToArray(cJSON *array, cJSON *item) +{ + add_item_to_array(array, item); +} + +#if defined(__clang__) || (defined(__GNUC__) && ((__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ > 5)))) + #pragma GCC diagnostic push +#endif +#ifdef __GNUC__ +#pragma GCC diagnostic ignored "-Wcast-qual" +#endif +/* helper function to cast away const */ +static void* cast_away_const(const void* string) +{ + return (void*)string; +} +#if defined(__clang__) || (defined(__GNUC__) && ((__GNUC__ > 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ > 5)))) + #pragma GCC diagnostic pop +#endif + + +static cJSON_bool add_item_to_object(cJSON * const object, const char * const string, cJSON * const item, const internal_hooks * const hooks, const cJSON_bool constant_key) +{ + char *new_key = NULL; + int new_type = cJSON_Invalid; + + if ((object == NULL) || (string == NULL) || (item == NULL)) + { + return false; + } + + if (constant_key) + { + new_key = (char*)cast_away_const(string); + new_type = item->type | cJSON_StringIsConst; + } + else + { + new_key = (char*)cJSON_strdup((const unsigned char*)string, hooks); + if (new_key == NULL) + { + return false; + } + + new_type = item->type & ~cJSON_StringIsConst; + } + + if (!(item->type & cJSON_StringIsConst) && (item->string != NULL)) + { + hooks->deallocate(item->string); + } + + item->string = new_key; + item->type = new_type; + + return add_item_to_array(object, item); +} + +CJSON_PUBLIC(void) cJSON_AddItemToObject(cJSON *object, const char *string, cJSON *item) +{ + add_item_to_object(object, string, item, &global_hooks, false); +} + +/* Add an item to an object with constant string as key */ +CJSON_PUBLIC(void) cJSON_AddItemToObjectCS(cJSON *object, const char *string, cJSON *item) +{ + add_item_to_object(object, string, item, &global_hooks, true); +} + +CJSON_PUBLIC(void) cJSON_AddItemReferenceToArray(cJSON *array, cJSON *item) +{ + if (array == NULL) + { + return; + } + + add_item_to_array(array, create_reference(item, &global_hooks)); +} + +CJSON_PUBLIC(void) cJSON_AddItemReferenceToObject(cJSON *object, const char *string, cJSON *item) +{ + if ((object == NULL) || (string == NULL)) + { + return; + } + + add_item_to_object(object, string, create_reference(item, &global_hooks), &global_hooks, false); +} + +CJSON_PUBLIC(cJSON*) cJSON_AddNullToObject(cJSON * const object, const char * const name) +{ + cJSON *null = cJSON_CreateNull(); + if (add_item_to_object(object, name, null, &global_hooks, false)) + { + return null; + } + + cJSON_Delete(null); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddTrueToObject(cJSON * const object, const char * const name) +{ + cJSON *true_item = cJSON_CreateTrue(); + if (add_item_to_object(object, name, true_item, &global_hooks, false)) + { + return true_item; + } + + cJSON_Delete(true_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddFalseToObject(cJSON * const object, const char * const name) +{ + cJSON *false_item = cJSON_CreateFalse(); + if (add_item_to_object(object, name, false_item, &global_hooks, false)) + { + return false_item; + } + + cJSON_Delete(false_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddBoolToObject(cJSON * const object, const char * const name, const cJSON_bool boolean) +{ + cJSON *bool_item = cJSON_CreateBool(boolean); + if (add_item_to_object(object, name, bool_item, &global_hooks, false)) + { + return bool_item; + } + + cJSON_Delete(bool_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddNumberToObject(cJSON * const object, const char * const name, const double number) +{ + cJSON *number_item = cJSON_CreateNumber(number); + if (add_item_to_object(object, name, number_item, &global_hooks, false)) + { + return number_item; + } + + cJSON_Delete(number_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddStringToObject(cJSON * const object, const char * const name, const char * const string) +{ + cJSON *string_item = cJSON_CreateString(string); + if (add_item_to_object(object, name, string_item, &global_hooks, false)) + { + return string_item; + } + + cJSON_Delete(string_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddRawToObject(cJSON * const object, const char * const name, const char * const raw) +{ + cJSON *raw_item = cJSON_CreateRaw(raw); + if (add_item_to_object(object, name, raw_item, &global_hooks, false)) + { + return raw_item; + } + + cJSON_Delete(raw_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddObjectToObject(cJSON * const object, const char * const name) +{ + cJSON *object_item = cJSON_CreateObject(); + if (add_item_to_object(object, name, object_item, &global_hooks, false)) + { + return object_item; + } + + cJSON_Delete(object_item); + return NULL; +} + +CJSON_PUBLIC(cJSON*) cJSON_AddArrayToObject(cJSON * const object, const char * const name) +{ + cJSON *array = cJSON_CreateArray(); + if (add_item_to_object(object, name, array, &global_hooks, false)) + { + return array; + } + + cJSON_Delete(array); + return NULL; +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemViaPointer(cJSON *parent, cJSON * const item) +{ + if ((parent == NULL) || (item == NULL)) + { + return NULL; + } + + if (item->prev != NULL) + { + /* not the first element */ + item->prev->next = item->next; + } + if (item->next != NULL) + { + /* not the last element */ + item->next->prev = item->prev; + } + + if (item == parent->child) + { + /* first element */ + parent->child = item->next; + } + /* make sure the detached item doesn't point anywhere anymore */ + item->prev = NULL; + item->next = NULL; + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromArray(cJSON *array, int which) +{ + if (which < 0) + { + return NULL; + } + + return cJSON_DetachItemViaPointer(array, get_array_item(array, (size_t)which)); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromArray(cJSON *array, int which) +{ + cJSON_Delete(cJSON_DetachItemFromArray(array, which)); +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObject(cJSON *object, const char *string) +{ + cJSON *to_detach = cJSON_GetObjectItem(object, string); + + return cJSON_DetachItemViaPointer(object, to_detach); +} + +CJSON_PUBLIC(cJSON *) cJSON_DetachItemFromObjectCaseSensitive(cJSON *object, const char *string) +{ + cJSON *to_detach = cJSON_GetObjectItemCaseSensitive(object, string); + + return cJSON_DetachItemViaPointer(object, to_detach); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromObject(cJSON *object, const char *string) +{ + cJSON_Delete(cJSON_DetachItemFromObject(object, string)); +} + +CJSON_PUBLIC(void) cJSON_DeleteItemFromObjectCaseSensitive(cJSON *object, const char *string) +{ + cJSON_Delete(cJSON_DetachItemFromObjectCaseSensitive(object, string)); +} + +/* Replace array/object items with new ones. */ +CJSON_PUBLIC(void) cJSON_InsertItemInArray(cJSON *array, int which, cJSON *newitem) +{ + cJSON *after_inserted = NULL; + + if (which < 0) + { + return; + } + + after_inserted = get_array_item(array, (size_t)which); + if (after_inserted == NULL) + { + add_item_to_array(array, newitem); + return; + } + + newitem->next = after_inserted; + newitem->prev = after_inserted->prev; + after_inserted->prev = newitem; + if (after_inserted == array->child) + { + array->child = newitem; + } + else + { + newitem->prev->next = newitem; + } +} + +CJSON_PUBLIC(cJSON_bool) cJSON_ReplaceItemViaPointer(cJSON * const parent, cJSON * const item, cJSON * replacement) +{ + if ((parent == NULL) || (replacement == NULL) || (item == NULL)) + { + return false; + } + + if (replacement == item) + { + return true; + } + + replacement->next = item->next; + replacement->prev = item->prev; + + if (replacement->next != NULL) + { + replacement->next->prev = replacement; + } + if (replacement->prev != NULL) + { + replacement->prev->next = replacement; + } + if (parent->child == item) + { + parent->child = replacement; + } + + item->next = NULL; + item->prev = NULL; + cJSON_Delete(item); + + return true; +} + +CJSON_PUBLIC(void) cJSON_ReplaceItemInArray(cJSON *array, int which, cJSON *newitem) +{ + if (which < 0) + { + return; + } + + cJSON_ReplaceItemViaPointer(array, get_array_item(array, (size_t)which), newitem); +} + +static cJSON_bool replace_item_in_object(cJSON *object, const char *string, cJSON *replacement, cJSON_bool case_sensitive) +{ + if ((replacement == NULL) || (string == NULL)) + { + return false; + } + + /* replace the name in the replacement */ + if (!(replacement->type & cJSON_StringIsConst) && (replacement->string != NULL)) + { + cJSON_free(replacement->string); + } + replacement->string = (char*)cJSON_strdup((const unsigned char*)string, &global_hooks); + replacement->type &= ~cJSON_StringIsConst; + + cJSON_ReplaceItemViaPointer(object, get_object_item(object, string, case_sensitive), replacement); + + return true; +} + +CJSON_PUBLIC(void) cJSON_ReplaceItemInObject(cJSON *object, const char *string, cJSON *newitem) +{ + replace_item_in_object(object, string, newitem, false); +} + +CJSON_PUBLIC(void) cJSON_ReplaceItemInObjectCaseSensitive(cJSON *object, const char *string, cJSON *newitem) +{ + replace_item_in_object(object, string, newitem, true); +} + +/* Create basic types: */ +CJSON_PUBLIC(cJSON *) cJSON_CreateNull(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_NULL; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateTrue(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_True; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateFalse(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_False; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateBool(cJSON_bool b) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = b ? cJSON_True : cJSON_False; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateNumber(double num) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_Number; + item->valuedouble = num; + + /* use saturation in case of overflow */ + if (num >= INT_MAX) + { + item->valueint = INT_MAX; + } + else if (num <= INT_MIN) + { + item->valueint = INT_MIN; + } + else + { + item->valueint = (int)num; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateInt64(long long num) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_Number; + item->valuedouble = static_cast<double>(num); + + // For integer which is out of the range of [INT_MIN, INT_MAX], it may lose precision if we cast it to double. + // Instead, we keep the integer literal as a string. + if (num > INT_MAX || num < INT_MIN) + { + char buf[21]; + sprintf(buf, "%lld", num); + item->valuestring = (char*)cJSON_strdup((const unsigned char*)buf, &global_hooks); + } + + /* use saturation in case of overflow */ + if (num >= INT_MAX) + { + item->valueint = INT_MAX; + } + else if (num <= INT_MIN) + { + item->valueint = INT_MIN; + } + else + { + item->valueint = (int)num; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateString(const char *string) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_String; + item->valuestring = (char*)cJSON_strdup((const unsigned char*)string, &global_hooks); + if(!item->valuestring) + { + cJSON_Delete(item); + return NULL; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateStringReference(const char *string) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) + { + item->type = cJSON_String | cJSON_IsReference; + item->valuestring = (char*)cast_away_const(string); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateObjectReference(const cJSON *child) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) { + item->type = cJSON_Object | cJSON_IsReference; + item->child = (cJSON*)cast_away_const(child); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateArrayReference(const cJSON *child) { + cJSON *item = cJSON_New_Item(&global_hooks); + if (item != NULL) { + item->type = cJSON_Array | cJSON_IsReference; + item->child = (cJSON*)cast_away_const(child); + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateRaw(const char *raw) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type = cJSON_Raw; + item->valuestring = (char*)cJSON_strdup((const unsigned char*)raw, &global_hooks); + if(!item->valuestring) + { + cJSON_Delete(item); + return NULL; + } + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateArray(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if(item) + { + item->type=cJSON_Array; + } + + return item; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateObject(void) +{ + cJSON *item = cJSON_New_Item(&global_hooks); + if (item) + { + item->type = cJSON_Object; + } + + return item; +} + +/* Create Arrays: */ +CJSON_PUBLIC(cJSON *) cJSON_CreateIntArray(const int *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + for(i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber(numbers[i]); + if (!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateFloatArray(const float *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for(i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber((double)numbers[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateDoubleArray(const double *numbers, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (numbers == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for(i = 0;a && (i < (size_t)count); i++) + { + n = cJSON_CreateNumber(numbers[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p, n); + } + p = n; + } + + return a; +} + +CJSON_PUBLIC(cJSON *) cJSON_CreateStringArray(const char **strings, int count) +{ + size_t i = 0; + cJSON *n = NULL; + cJSON *p = NULL; + cJSON *a = NULL; + + if ((count < 0) || (strings == NULL)) + { + return NULL; + } + + a = cJSON_CreateArray(); + + for (i = 0; a && (i < (size_t)count); i++) + { + n = cJSON_CreateString(strings[i]); + if(!n) + { + cJSON_Delete(a); + return NULL; + } + if(!i) + { + a->child = n; + } + else + { + suffix_object(p,n); + } + p = n; + } + + return a; +} + +/* Duplication */ +CJSON_PUBLIC(cJSON *) cJSON_Duplicate(const cJSON *item, cJSON_bool recurse) +{ + cJSON *newitem = NULL; + cJSON *child = NULL; + cJSON *next = NULL; + cJSON *newchild = NULL; + + /* Bail on bad ptr */ + if (!item) + { + goto fail; + } + /* Create new item */ + newitem = cJSON_New_Item(&global_hooks); + if (!newitem) + { + goto fail; + } + /* Copy over all vars */ + newitem->type = item->type & (~cJSON_IsReference); + newitem->valueint = item->valueint; + newitem->valuedouble = item->valuedouble; + if (item->valuestring) + { + newitem->valuestring = (char*)cJSON_strdup((unsigned char*)item->valuestring, &global_hooks); + if (!newitem->valuestring) + { + goto fail; + } + } + if (item->string) + { + newitem->string = (item->type&cJSON_StringIsConst) ? item->string : (char*)cJSON_strdup((unsigned char*)item->string, &global_hooks); + if (!newitem->string) + { + goto fail; + } + } + /* If non-recursive, then we're done! */ + if (!recurse) + { + return newitem; + } + /* Walk the ->next chain for the child. */ + child = item->child; + while (child != NULL) + { + newchild = cJSON_Duplicate(child, true); /* Duplicate (with recurse) each item in the ->next chain */ + if (!newchild) + { + goto fail; + } + if (next != NULL) + { + /* If newitem->child already set, then crosswire ->prev and ->next and move on */ + next->next = newchild; + newchild->prev = next; + next = newchild; + } + else + { + /* Set newitem->child and move to it */ + newitem->child = newchild; + next = newchild; + } + child = child->next; + } + + return newitem; + +fail: + if (newitem != NULL) + { + cJSON_Delete(newitem); + } + + return NULL; +} + +CJSON_PUBLIC(void) cJSON_Minify(char *json) +{ + unsigned char *into = (unsigned char*)json; + + if (json == NULL) + { + return; + } + + while (*json) + { + if (*json == ' ') + { + json++; + } + else if (*json == '\t') + { + /* Whitespace characters. */ + json++; + } + else if (*json == '\r') + { + json++; + } + else if (*json=='\n') + { + json++; + } + else if ((*json == '/') && (json[1] == '/')) + { + /* double-slash comments, to end of line. */ + while (*json && (*json != '\n')) + { + json++; + } + } + else if ((*json == '/') && (json[1] == '*')) + { + /* multiline comments. */ + while (*json && !((*json == '*') && (json[1] == '/'))) + { + json++; + } + json += 2; + } + else if (*json == '\"') + { + /* string literals, which are \" sensitive. */ + *into++ = (unsigned char)*json++; + while (*json && (*json != '\"')) + { + if (*json == '\\') + { + *into++ = (unsigned char)*json++; + } + *into++ = (unsigned char)*json++; + } + *into++ = (unsigned char)*json++; + } + else + { + /* All other characters. */ + *into++ = (unsigned char)*json++; + } + } + + /* and null-terminate. */ + *into = '\0'; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsInvalid(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Invalid; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsFalse(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_False; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsTrue(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xff) == cJSON_True; +} + + +CJSON_PUBLIC(cJSON_bool) cJSON_IsBool(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & (cJSON_True | cJSON_False)) != 0; +} +CJSON_PUBLIC(cJSON_bool) cJSON_IsNull(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_NULL; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsNumber(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Number; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsString(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_String; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsArray(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Array; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsObject(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Object; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_IsRaw(const cJSON * const item) +{ + if (item == NULL) + { + return false; + } + + return (item->type & 0xFF) == cJSON_Raw; +} + +CJSON_PUBLIC(cJSON_bool) cJSON_Compare(const cJSON * const a, const cJSON * const b, const cJSON_bool case_sensitive) +{ + if ((a == NULL) || (b == NULL) || ((a->type & 0xFF) != (b->type & 0xFF)) || cJSON_IsInvalid(a)) + { + return false; + } + + /* check if type is valid */ + switch (a->type & 0xFF) + { + case cJSON_False: + case cJSON_True: + case cJSON_NULL: + case cJSON_Number: + case cJSON_String: + case cJSON_Raw: + case cJSON_Array: + case cJSON_Object: + break; + + default: + return false; + } + + /* identical objects are equal */ + if (a == b) + { + return true; + } + + switch (a->type & 0xFF) + { + /* in these cases and equal type is enough */ + case cJSON_False: + case cJSON_True: + case cJSON_NULL: + return true; + + case cJSON_Number: + if (a->valuedouble == b->valuedouble) + { + return true; + } + return false; + + case cJSON_String: + case cJSON_Raw: + if ((a->valuestring == NULL) || (b->valuestring == NULL)) + { + return false; + } + if (strcmp(a->valuestring, b->valuestring) == 0) + { + return true; + } + + return false; + + case cJSON_Array: + { + cJSON *a_element = a->child; + cJSON *b_element = b->child; + + for (; (a_element != NULL) && (b_element != NULL);) + { + if (!cJSON_Compare(a_element, b_element, case_sensitive)) + { + return false; + } + + a_element = a_element->next; + b_element = b_element->next; + } + + /* one of the arrays is longer than the other */ + if (a_element != b_element) { + return false; + } + + return true; + } + + case cJSON_Object: + { + cJSON *a_element = NULL; + cJSON *b_element = NULL; + cJSON_ArrayForEach(a_element, a) + { + /* TODO This has O(n^2) runtime, which is horrible! */ + b_element = get_object_item(b, a_element->string, case_sensitive); + if (b_element == NULL) + { + return false; + } + + if (!cJSON_Compare(a_element, b_element, case_sensitive)) + { + return false; + } + } + + /* doing this twice, once on a and b to prevent true comparison if a subset of b + * TODO: Do this the proper way, this is just a fix for now */ + cJSON_ArrayForEach(b_element, b) + { + a_element = get_object_item(a, b_element->string, case_sensitive); + if (a_element == NULL) + { + return false; + } + + if (!cJSON_Compare(b_element, a_element, case_sensitive)) + { + return false; + } + } + + return true; + } + + default: + return false; + } +} + +CJSON_PUBLIC(void *) cJSON_malloc(size_t size) +{ + return global_hooks.allocate(size); +} + +CJSON_PUBLIC(void) cJSON_free(void *object) +{ + global_hooks.deallocate(object); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/external/tinyxml2/tinyxml2.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/external/tinyxml2/tinyxml2.cpp new file mode 100644 index 0000000000..ebe0fd9eec --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/external/tinyxml2/tinyxml2.cpp @@ -0,0 +1,2802 @@ +/* +Original code by Lee Thomason (www.grinninglizard.com) + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any +damages arising from the use of this software. + +Permission is granted to anyone to use this software for any +purpose, including commercial applications, and to alter it and +redistribute it freely, subject to the following restrictions: + +1. The origin of this software must not be misrepresented; you must +not claim that you wrote the original software. If you use this +software in a product, an acknowledgment in the product documentation +would be appreciated but is not required. + +2. Altered source versions must be plainly marked as such, and +must not be misrepresented as being the original software. + +3. This notice may not be removed or altered from any source +distribution. +*/ + +/* +This file has been modified from its original version by Amazon: + (1) Memory management operations use aws memory management api + (2) #includes all use <> +*/ + +#include <aws/core/external/tinyxml2/tinyxml2.h> + +#include <new> // yes, this one new style header, is in the Android SDK. +#if defined(ANDROID_NDK) || defined(__BORLANDC__) || defined(__QNXNTO__) +# include <stddef.h> +# include <stdarg.h> +#else +# include <cstddef> +# include <cstdarg> +#endif + +#if defined(_MSC_VER) && (_MSC_VER >= 1400 ) && (!defined WINCE) + // Microsoft Visual Studio, version 2005 and higher. Not WinCE. + /*int _snprintf_s( + char *buffer, + size_t sizeOfBuffer, + size_t count, + const char *format [, + argument] ... + );*/ + static inline int TIXML_SNPRINTF( char* buffer, size_t size, const char* format, ... ) + { + va_list va; + va_start( va, format ); + int result = vsnprintf_s( buffer, size, _TRUNCATE, format, va ); + va_end( va ); + return result; + } + + static inline int TIXML_VSNPRINTF( char* buffer, size_t size, const char* format, va_list va ) + { + int result = vsnprintf_s( buffer, size, _TRUNCATE, format, va ); + return result; + } + + #define TIXML_VSCPRINTF _vscprintf + #define TIXML_SSCANF sscanf_s +#elif defined _MSC_VER + // Microsoft Visual Studio 2003 and earlier or WinCE + #define TIXML_SNPRINTF _snprintf + #define TIXML_VSNPRINTF _vsnprintf + #define TIXML_SSCANF sscanf + #if (_MSC_VER < 1400 ) && (!defined WINCE) + // Microsoft Visual Studio 2003 and not WinCE. + #define TIXML_VSCPRINTF _vscprintf // VS2003's C runtime has this, but VC6 C runtime or WinCE SDK doesn't have. + #else + // Microsoft Visual Studio 2003 and earlier or WinCE. + static inline int TIXML_VSCPRINTF( const char* format, va_list va ) + { + int len = 512; + for (;;) { + len = len*2; + char* str = Aws::NewArray<char>(len, ALLOCATION_TAG); + const int required = _vsnprintf(str, len, format, va); + Aws::DeleteArray(str); + if ( required != -1 ) { + TIXMLASSERT( required >= 0 ); + len = required; + break; + } + } + TIXMLASSERT( len >= 0 ); + return len; + } + #endif +#else + // GCC version 3 and higher + //#warning( "Using sn* functions." ) + #define TIXML_SNPRINTF snprintf + #define TIXML_VSNPRINTF vsnprintf + static inline int TIXML_VSCPRINTF( const char* format, va_list va ) + { + int len = vsnprintf( 0, 0, format, va ); + TIXMLASSERT( len >= 0 ); + return len; + } + #define TIXML_SSCANF sscanf +#endif + + +static const char LINE_FEED = (char)0x0a; // all line endings are normalized to LF +static const char LF = LINE_FEED; +static const char CARRIAGE_RETURN = (char)0x0d; // CR gets filtered out +static const char CR = CARRIAGE_RETURN; +static const char SINGLE_QUOTE = '\''; +static const char DOUBLE_QUOTE = '\"'; + +// Bunch of unicode info at: +// http://www.unicode.org/faq/utf_bom.html +// ef bb bf (Microsoft "lead bytes") - designates UTF-8 + +static const unsigned char TIXML_UTF_LEAD_0 = 0xefU; +static const unsigned char TIXML_UTF_LEAD_1 = 0xbbU; +static const unsigned char TIXML_UTF_LEAD_2 = 0xbfU; + +namespace Aws +{ +namespace External +{ +namespace tinyxml2 +{ + +struct Entity { + const char* pattern; + int length; + char value; +}; + +static const int NUM_ENTITIES = 5; +static const Entity entities[NUM_ENTITIES] = { + { "quot", 4, DOUBLE_QUOTE }, + { "amp", 3, '&' }, + { "apos", 4, SINGLE_QUOTE }, + { "lt", 2, '<' }, + { "gt", 2, '>' } +}; + + +StrPair::~StrPair() +{ + Reset(); +} + + +void StrPair::TransferTo( StrPair* other ) +{ + if ( this == other ) { + return; + } + // This in effect implements the assignment operator by "moving" + // ownership (as in auto_ptr). + + TIXMLASSERT( other != 0 ); + TIXMLASSERT( other->_flags == 0 ); + TIXMLASSERT( other->_start == 0 ); + TIXMLASSERT( other->_end == 0 ); + + other->Reset(); + + other->_flags = _flags; + other->_start = _start; + other->_end = _end; + + _flags = 0; + _start = 0; + _end = 0; +} + + +void StrPair::Reset() +{ + if ( _flags & NEEDS_DELETE ) { + Aws::DeleteArray(_start); + } + _flags = 0; + _start = 0; + _end = 0; +} + + +void StrPair::SetStr( const char* str, int flags ) +{ + TIXMLASSERT( str ); + Reset(); + size_t len = strlen( str ); + TIXMLASSERT( _start == 0 ); + _start = Aws::NewArray<char>(len+1, ALLOCATION_TAG); + memcpy( _start, str, len+1 ); + _end = _start + len; + _flags = flags | NEEDS_DELETE; +} + + +char* StrPair::ParseText( char* p, const char* endTag, int strFlags, int* curLineNumPtr ) +{ + TIXMLASSERT( p ); + TIXMLASSERT( endTag && *endTag ); + TIXMLASSERT(curLineNumPtr); + + char* start = p; + char endChar = *endTag; + size_t length = strlen( endTag ); + + // Inner loop of text parsing. + while ( *p ) { + if ( *p == endChar && strncmp( p, endTag, length ) == 0 ) { + Set( start, p, strFlags ); + return p + length; + } else if (*p == '\n') { + ++(*curLineNumPtr); + } + ++p; + TIXMLASSERT( p ); + } + return 0; +} + + +char* StrPair::ParseName( char* p ) +{ + if ( !p || !(*p) ) { + return 0; + } + if ( !XMLUtil::IsNameStartChar( *p ) ) { + return 0; + } + + char* const start = p; + ++p; + while ( *p && XMLUtil::IsNameChar( *p ) ) { + ++p; + } + + Set( start, p, 0 ); + return p; +} + + +void StrPair::CollapseWhitespace() +{ + // Adjusting _start would cause undefined behavior on delete[] + TIXMLASSERT( ( _flags & NEEDS_DELETE ) == 0 ); + // Trim leading space. + _start = XMLUtil::SkipWhiteSpace( _start, 0 ); + + if ( *_start ) { + const char* p = _start; // the read pointer + char* q = _start; // the write pointer + + while( *p ) { + if ( XMLUtil::IsWhiteSpace( *p )) { + p = XMLUtil::SkipWhiteSpace( p, 0 ); + if ( *p == 0 ) { + break; // don't write to q; this trims the trailing space. + } + *q = ' '; + ++q; + } + *q = *p; + ++q; + ++p; + } + *q = 0; + } +} + + +const char* StrPair::GetStr() +{ + TIXMLASSERT( _start ); + TIXMLASSERT( _end ); + if ( _flags & NEEDS_FLUSH ) { + *_end = 0; + _flags ^= NEEDS_FLUSH; + + if ( _flags ) { + const char* p = _start; // the read pointer + char* q = _start; // the write pointer + + while( p < _end ) { + if ( (_flags & NEEDS_NEWLINE_NORMALIZATION) && *p == CR ) { + // CR-LF pair becomes LF + // CR alone becomes LF + // LF-CR becomes LF + if ( *(p+1) == LF ) { + p += 2; + } + else { + ++p; + } + *q = LF; + ++q; + } + else if ( (_flags & NEEDS_NEWLINE_NORMALIZATION) && *p == LF ) { + if ( *(p+1) == CR ) { + p += 2; + } + else { + ++p; + } + *q = LF; + ++q; + } + else if ( (_flags & NEEDS_ENTITY_PROCESSING) && *p == '&' ) { + // Entities handled by tinyXML2: + // - special entities in the entity table [in/out] + // - numeric character reference [in] + // 中 or 中 + + if ( *(p+1) == '#' ) { + const int buflen = 10; + char buf[buflen] = { 0 }; + int len = 0; + char* adjusted = const_cast<char*>( XMLUtil::GetCharacterRef( p, buf, &len ) ); + if ( adjusted == 0 ) { + *q = *p; + ++p; + ++q; + } + else { + TIXMLASSERT( 0 <= len && len <= buflen ); + TIXMLASSERT( q + len <= adjusted ); + p = adjusted; + memcpy( q, buf, len ); + q += len; + } + } + else { + bool entityFound = false; + for( int i = 0; i < NUM_ENTITIES; ++i ) { + const Entity& entity = entities[i]; + if ( strncmp( p + 1, entity.pattern, entity.length ) == 0 + && *( p + entity.length + 1 ) == ';' ) { + // Found an entity - convert. + *q = entity.value; + ++q; + p += entity.length + 2; + entityFound = true; + break; + } + } + if ( !entityFound ) { + // fixme: treat as error? + ++p; + ++q; + } + } + } + else { + *q = *p; + ++p; + ++q; + } + } + *q = 0; + } + // The loop below has plenty going on, and this + // is a less useful mode. Break it out. + if ( _flags & NEEDS_WHITESPACE_COLLAPSING ) { + CollapseWhitespace(); + } + _flags = (_flags & NEEDS_DELETE); + } + TIXMLASSERT( _start ); + return _start; +} + + + + +// --------- XMLUtil ----------- // + +const char* XMLUtil::writeBoolTrue = "true"; +const char* XMLUtil::writeBoolFalse = "false"; + +void XMLUtil::SetBoolSerialization(const char* writeTrue, const char* writeFalse) +{ + static const char* defTrue = "true"; + static const char* defFalse = "false"; + + writeBoolTrue = (writeTrue) ? writeTrue : defTrue; + writeBoolFalse = (writeFalse) ? writeFalse : defFalse; +} + + +const char* XMLUtil::ReadBOM( const char* p, bool* bom ) +{ + TIXMLASSERT( p ); + TIXMLASSERT( bom ); + *bom = false; + const unsigned char* pu = reinterpret_cast<const unsigned char*>(p); + // Check for BOM: + if ( *(pu+0) == TIXML_UTF_LEAD_0 + && *(pu+1) == TIXML_UTF_LEAD_1 + && *(pu+2) == TIXML_UTF_LEAD_2 ) { + *bom = true; + p += 3; + } + TIXMLASSERT( p ); + return p; +} + + +void XMLUtil::ConvertUTF32ToUTF8( unsigned long input, char* output, int* length ) +{ + const unsigned long BYTE_MASK = 0xBF; + const unsigned long BYTE_MARK = 0x80; + const unsigned long FIRST_BYTE_MARK[7] = { 0x00, 0x00, 0xC0, 0xE0, 0xF0, 0xF8, 0xFC }; + + if (input < 0x80) { + *length = 1; + } + else if ( input < 0x800 ) { + *length = 2; + } + else if ( input < 0x10000 ) { + *length = 3; + } + else if ( input < 0x200000 ) { + *length = 4; + } + else { + *length = 0; // This code won't convert this correctly anyway. + return; + } + + output += *length; + + // Scary scary fall throughs are annotated with carefully designed comments + // to suppress compiler warnings such as -Wimplicit-fallthrough in gcc + switch (*length) { + case 4: + --output; + *output = (char)((input | BYTE_MARK) & BYTE_MASK); + input >>= 6; + //fall through + case 3: + --output; + *output = (char)((input | BYTE_MARK) & BYTE_MASK); + input >>= 6; + //fall through + case 2: + --output; + *output = (char)((input | BYTE_MARK) & BYTE_MASK); + input >>= 6; + //fall through + case 1: + --output; + *output = (char)(input | FIRST_BYTE_MARK[*length]); + break; + default: + TIXMLASSERT( false ); + } +} + + +const char* XMLUtil::GetCharacterRef( const char* p, char* value, int* length ) +{ + // Presume an entity, and pull it out. + *length = 0; + + if ( *(p+1) == '#' && *(p+2) ) { + unsigned long ucs = 0; + TIXMLASSERT( sizeof( ucs ) >= 4 ); + ptrdiff_t delta = 0; + unsigned mult = 1; + static const char SEMICOLON = ';'; + + if ( *(p+2) == 'x' ) { + // Hexadecimal. + const char* q = p+3; + if ( !(*q) ) { + return 0; + } + + q = strchr( q, SEMICOLON ); + + if ( !q ) { + return 0; + } + TIXMLASSERT( *q == SEMICOLON ); + + delta = q-p; + --q; + + while ( *q != 'x' ) { + unsigned int digit = 0; + + if ( *q >= '0' && *q <= '9' ) { + digit = *q - '0'; + } + else if ( *q >= 'a' && *q <= 'f' ) { + digit = *q - 'a' + 10; + } + else if ( *q >= 'A' && *q <= 'F' ) { + digit = *q - 'A' + 10; + } + else { + return 0; + } + TIXMLASSERT( digit < 16 ); + TIXMLASSERT( digit == 0 || mult <= UINT_MAX / digit ); + const unsigned int digitScaled = mult * digit; + TIXMLASSERT( ucs <= ULONG_MAX - digitScaled ); + ucs += digitScaled; + TIXMLASSERT( mult <= UINT_MAX / 16 ); + mult *= 16; + --q; + } + } + else { + // Decimal. + const char* q = p+2; + if ( !(*q) ) { + return 0; + } + + q = strchr( q, SEMICOLON ); + + if ( !q ) { + return 0; + } + TIXMLASSERT( *q == SEMICOLON ); + + delta = q-p; + --q; + + while ( *q != '#' ) { + if ( *q >= '0' && *q <= '9' ) { + const unsigned int digit = *q - '0'; + TIXMLASSERT( digit < 10 ); + TIXMLASSERT( digit == 0 || mult <= UINT_MAX / digit ); + const unsigned int digitScaled = mult * digit; + TIXMLASSERT( ucs <= ULONG_MAX - digitScaled ); + ucs += digitScaled; + } + else { + return 0; + } + TIXMLASSERT( mult <= UINT_MAX / 10 ); + mult *= 10; + --q; + } + } + // convert the UCS to UTF-8 + ConvertUTF32ToUTF8( ucs, value, length ); + return p + delta + 1; + } + return p+1; +} + + +void XMLUtil::ToStr( int v, char* buffer, int bufferSize ) +{ + TIXML_SNPRINTF( buffer, bufferSize, "%d", v ); +} + + +void XMLUtil::ToStr( unsigned v, char* buffer, int bufferSize ) +{ + TIXML_SNPRINTF( buffer, bufferSize, "%u", v ); +} + + +void XMLUtil::ToStr( bool v, char* buffer, int bufferSize ) +{ + TIXML_SNPRINTF( buffer, bufferSize, "%s", v ? writeBoolTrue : writeBoolFalse); +} + +/* + ToStr() of a number is a very tricky topic. + https://github.com/leethomason/tinyxml2/issues/106 +*/ +void XMLUtil::ToStr( float v, char* buffer, int bufferSize ) +{ + TIXML_SNPRINTF( buffer, bufferSize, "%.8g", v ); +} + + +void XMLUtil::ToStr( double v, char* buffer, int bufferSize ) +{ + TIXML_SNPRINTF( buffer, bufferSize, "%.17g", v ); +} + + +void XMLUtil::ToStr(int64_t v, char* buffer, int bufferSize) +{ + // horrible syntax trick to make the compiler happy about %lld + TIXML_SNPRINTF(buffer, bufferSize, "%lld", (long long)v); +} + + +bool XMLUtil::ToInt( const char* str, int* value ) +{ + if ( TIXML_SSCANF( str, "%d", value ) == 1 ) { + return true; + } + return false; +} + +bool XMLUtil::ToUnsigned( const char* str, unsigned *value ) +{ + if ( TIXML_SSCANF( str, "%u", value ) == 1 ) { + return true; + } + return false; +} + +bool XMLUtil::ToBool( const char* str, bool* value ) +{ + int ival = 0; + if ( ToInt( str, &ival )) { + *value = (ival==0) ? false : true; + return true; + } + if ( StringEqual( str, "true" ) ) { + *value = true; + return true; + } + else if ( StringEqual( str, "false" ) ) { + *value = false; + return true; + } + return false; +} + + +bool XMLUtil::ToFloat( const char* str, float* value ) +{ + if ( TIXML_SSCANF( str, "%f", value ) == 1 ) { + return true; + } + return false; +} + + +bool XMLUtil::ToDouble( const char* str, double* value ) +{ + if ( TIXML_SSCANF( str, "%lf", value ) == 1 ) { + return true; + } + return false; +} + + +bool XMLUtil::ToInt64(const char* str, int64_t* value) +{ + long long v = 0; // horrible syntax trick to make the compiler happy about %lld + if (TIXML_SSCANF(str, "%lld", &v) == 1) { + *value = (int64_t)v; + return true; + } + return false; +} + + +char* XMLDocument::Identify( char* p, XMLNode** node ) +{ + TIXMLASSERT( node ); + TIXMLASSERT( p ); + char* const start = p; + int const startLine = _parseCurLineNum; + p = XMLUtil::SkipWhiteSpace( p, &_parseCurLineNum ); + if( !*p ) { + *node = 0; + TIXMLASSERT( p ); + return p; + } + + // These strings define the matching patterns: + static const char* xmlHeader = { "<?" }; + static const char* commentHeader = { "<!--" }; + static const char* cdataHeader = { "<![CDATA[" }; + static const char* dtdHeader = { "<!" }; + static const char* elementHeader = { "<" }; // and a header for everything else; check last. + + static const int xmlHeaderLen = 2; + static const int commentHeaderLen = 4; + static const int cdataHeaderLen = 9; + static const int dtdHeaderLen = 2; + static const int elementHeaderLen = 1; + + TIXMLASSERT( sizeof( XMLComment ) == sizeof( XMLUnknown ) ); // use same memory pool + TIXMLASSERT( sizeof( XMLComment ) == sizeof( XMLDeclaration ) ); // use same memory pool + XMLNode* returnNode = 0; + if ( XMLUtil::StringEqual( p, xmlHeader, xmlHeaderLen ) ) { + returnNode = CreateUnlinkedNode<XMLDeclaration>( _commentPool ); + returnNode->_parseLineNum = _parseCurLineNum; + p += xmlHeaderLen; + } + else if ( XMLUtil::StringEqual( p, commentHeader, commentHeaderLen ) ) { + returnNode = CreateUnlinkedNode<XMLComment>( _commentPool ); + returnNode->_parseLineNum = _parseCurLineNum; + p += commentHeaderLen; + } + else if ( XMLUtil::StringEqual( p, cdataHeader, cdataHeaderLen ) ) { + XMLText* text = CreateUnlinkedNode<XMLText>( _textPool ); + returnNode = text; + returnNode->_parseLineNum = _parseCurLineNum; + p += cdataHeaderLen; + text->SetCData( true ); + } + else if ( XMLUtil::StringEqual( p, dtdHeader, dtdHeaderLen ) ) { + returnNode = CreateUnlinkedNode<XMLUnknown>( _commentPool ); + returnNode->_parseLineNum = _parseCurLineNum; + p += dtdHeaderLen; + } + else if ( XMLUtil::StringEqual( p, elementHeader, elementHeaderLen ) ) { + returnNode = CreateUnlinkedNode<XMLElement>( _elementPool ); + returnNode->_parseLineNum = _parseCurLineNum; + p += elementHeaderLen; + } + else { + returnNode = CreateUnlinkedNode<XMLText>( _textPool ); + returnNode->_parseLineNum = _parseCurLineNum; // Report line of first non-whitespace character + p = start; // Back it up, all the text counts. + _parseCurLineNum = startLine; + } + + TIXMLASSERT( returnNode ); + TIXMLASSERT( p ); + *node = returnNode; + return p; +} + + +bool XMLDocument::Accept( XMLVisitor* visitor ) const +{ + TIXMLASSERT( visitor ); + if ( visitor->VisitEnter( *this ) ) { + for ( const XMLNode* node=FirstChild(); node; node=node->NextSibling() ) { + if ( !node->Accept( visitor ) ) { + break; + } + } + } + return visitor->VisitExit( *this ); +} + + +// --------- XMLNode ----------- // + +XMLNode::XMLNode( XMLDocument* doc ) : + _document( doc ), + _parent( 0 ), + _value(), + _parseLineNum( 0 ), + _firstChild( 0 ), _lastChild( 0 ), + _prev( 0 ), _next( 0 ), + _userData( 0 ), + _memPool( 0 ) +{ +} + + +XMLNode::~XMLNode() +{ + DeleteChildren(); + if ( _parent ) { + _parent->Unlink( this ); + } +} + +const char* XMLNode::Value() const +{ + // Edge case: XMLDocuments don't have a Value. Return null. + if ( this->ToDocument() ) + return 0; + return _value.GetStr(); +} + +void XMLNode::SetValue( const char* str, bool staticMem ) +{ + if ( staticMem ) { + _value.SetInternedStr( str ); + } + else { + _value.SetStr( str ); + } +} + +XMLNode* XMLNode::DeepClone(XMLDocument* target) const +{ + XMLNode* clone = this->ShallowClone(target); + if (!clone) return 0; + + for (const XMLNode* child = this->FirstChild(); child; child = child->NextSibling()) { + XMLNode* childClone = child->DeepClone(target); + TIXMLASSERT(childClone); + clone->InsertEndChild(childClone); + } + return clone; +} + +void XMLNode::DeleteChildren() +{ + while( _firstChild ) { + TIXMLASSERT( _lastChild ); + DeleteChild( _firstChild ); + } + _firstChild = _lastChild = 0; +} + + +void XMLNode::Unlink( XMLNode* child ) +{ + TIXMLASSERT( child ); + TIXMLASSERT( child->_document == _document ); + TIXMLASSERT( child->_parent == this ); + if ( child == _firstChild ) { + _firstChild = _firstChild->_next; + } + if ( child == _lastChild ) { + _lastChild = _lastChild->_prev; + } + + if ( child->_prev ) { + child->_prev->_next = child->_next; + } + if ( child->_next ) { + child->_next->_prev = child->_prev; + } + child->_next = 0; + child->_prev = 0; + child->_parent = 0; +} + + +void XMLNode::DeleteChild( XMLNode* node ) +{ + TIXMLASSERT( node ); + TIXMLASSERT( node->_document == _document ); + TIXMLASSERT( node->_parent == this ); + Unlink( node ); + TIXMLASSERT(node->_prev == 0); + TIXMLASSERT(node->_next == 0); + TIXMLASSERT(node->_parent == 0); + DeleteNode( node ); +} + + +XMLNode* XMLNode::InsertEndChild( XMLNode* addThis ) +{ + TIXMLASSERT( addThis ); + if ( addThis->_document != _document ) { + TIXMLASSERT( false ); + return 0; + } + InsertChildPreamble( addThis ); + + if ( _lastChild ) { + TIXMLASSERT( _firstChild ); + TIXMLASSERT( _lastChild->_next == 0 ); + _lastChild->_next = addThis; + addThis->_prev = _lastChild; + _lastChild = addThis; + + addThis->_next = 0; + } + else { + TIXMLASSERT( _firstChild == 0 ); + _firstChild = _lastChild = addThis; + + addThis->_prev = 0; + addThis->_next = 0; + } + addThis->_parent = this; + return addThis; +} + + +XMLNode* XMLNode::InsertFirstChild( XMLNode* addThis ) +{ + TIXMLASSERT( addThis ); + if ( addThis->_document != _document ) { + TIXMLASSERT( false ); + return 0; + } + InsertChildPreamble( addThis ); + + if ( _firstChild ) { + TIXMLASSERT( _lastChild ); + TIXMLASSERT( _firstChild->_prev == 0 ); + + _firstChild->_prev = addThis; + addThis->_next = _firstChild; + _firstChild = addThis; + + addThis->_prev = 0; + } + else { + TIXMLASSERT( _lastChild == 0 ); + _firstChild = _lastChild = addThis; + + addThis->_prev = 0; + addThis->_next = 0; + } + addThis->_parent = this; + return addThis; +} + + +XMLNode* XMLNode::InsertAfterChild( XMLNode* afterThis, XMLNode* addThis ) +{ + TIXMLASSERT( addThis ); + if ( addThis->_document != _document ) { + TIXMLASSERT( false ); + return 0; + } + + TIXMLASSERT( afterThis ); + + if ( afterThis->_parent != this ) { + TIXMLASSERT( false ); + return 0; + } + if ( afterThis == addThis ) { + // Current state: BeforeThis -> AddThis -> OneAfterAddThis + // Now AddThis must disappear from it's location and then + // reappear between BeforeThis and OneAfterAddThis. + // So just leave it where it is. + return addThis; + } + + if ( afterThis->_next == 0 ) { + // The last node or the only node. + return InsertEndChild( addThis ); + } + InsertChildPreamble( addThis ); + addThis->_prev = afterThis; + addThis->_next = afterThis->_next; + afterThis->_next->_prev = addThis; + afterThis->_next = addThis; + addThis->_parent = this; + return addThis; +} + + + + +const XMLElement* XMLNode::FirstChildElement( const char* name ) const +{ + for( const XMLNode* node = _firstChild; node; node = node->_next ) { + const XMLElement* element = node->ToElementWithName( name ); + if ( element ) { + return element; + } + } + return 0; +} + + +const XMLElement* XMLNode::LastChildElement( const char* name ) const +{ + for( const XMLNode* node = _lastChild; node; node = node->_prev ) { + const XMLElement* element = node->ToElementWithName( name ); + if ( element ) { + return element; + } + } + return 0; +} + + +const XMLElement* XMLNode::NextSiblingElement( const char* name ) const +{ + for( const XMLNode* node = _next; node; node = node->_next ) { + const XMLElement* element = node->ToElementWithName( name ); + if ( element ) { + return element; + } + } + return 0; +} + + +const XMLElement* XMLNode::PreviousSiblingElement( const char* name ) const +{ + for( const XMLNode* node = _prev; node; node = node->_prev ) { + const XMLElement* element = node->ToElementWithName( name ); + if ( element ) { + return element; + } + } + return 0; +} + + +char* XMLNode::ParseDeep( char* p, StrPair* parentEndTag, int* curLineNumPtr ) +{ + // This is a recursive method, but thinking about it "at the current level" + // it is a pretty simple flat list: + // <foo/> + // <!-- comment --> + // + // With a special case: + // <foo> + // </foo> + // <!-- comment --> + // + // Where the closing element (/foo) *must* be the next thing after the opening + // element, and the names must match. BUT the tricky bit is that the closing + // element will be read by the child. + // + // 'endTag' is the end tag for this node, it is returned by a call to a child. + // 'parentEnd' is the end tag for the parent, which is filled in and returned. + + while( p && *p ) { + XMLNode* node = 0; + + p = _document->Identify( p, &node ); + TIXMLASSERT( p ); + if ( node == 0 ) { + break; + } + + int initialLineNum = node->_parseLineNum; + + StrPair endTag; + p = node->ParseDeep( p, &endTag, curLineNumPtr ); + if ( !p ) { + DeleteNode( node ); + if ( !_document->Error() ) { + _document->SetError( XML_ERROR_PARSING, initialLineNum, 0); + } + break; + } + + XMLDeclaration* decl = node->ToDeclaration(); + if ( decl ) { + // Declarations are only allowed at document level + bool wellLocated = ( ToDocument() != 0 ); + if ( wellLocated ) { + // Multiple declarations are allowed but all declarations + // must occur before anything else + for ( const XMLNode* existingNode = _document->FirstChild(); existingNode; existingNode = existingNode->NextSibling() ) { + if ( !existingNode->ToDeclaration() ) { + wellLocated = false; + break; + } + } + } + if ( !wellLocated ) { + _document->SetError( XML_ERROR_PARSING_DECLARATION, initialLineNum, "XMLDeclaration value=%s", decl->Value()); + DeleteNode( node ); + break; + } + } + + XMLElement* ele = node->ToElement(); + if ( ele ) { + // We read the end tag. Return it to the parent. + if ( ele->ClosingType() == XMLElement::CLOSING ) { + if ( parentEndTag ) { + ele->_value.TransferTo( parentEndTag ); + } + node->_memPool->SetTracked(); // created and then immediately deleted. + DeleteNode( node ); + return p; + } + + // Handle an end tag returned to this level. + // And handle a bunch of annoying errors. + bool mismatch = false; + if ( endTag.Empty() ) { + if ( ele->ClosingType() == XMLElement::OPEN ) { + mismatch = true; + } + } + else { + if ( ele->ClosingType() != XMLElement::OPEN ) { + mismatch = true; + } + else if ( !XMLUtil::StringEqual( endTag.GetStr(), ele->Name() ) ) { + mismatch = true; + } + } + if ( mismatch ) { + _document->SetError( XML_ERROR_MISMATCHED_ELEMENT, initialLineNum, "XMLElement name=%s", ele->Name()); + DeleteNode( node ); + break; + } + } + InsertEndChild( node ); + } + return 0; +} + +/*static*/ void XMLNode::DeleteNode( XMLNode* node ) +{ + if ( node == 0 ) { + return; + } + TIXMLASSERT(node->_document); + if (!node->ToDocument()) { + node->_document->MarkInUse(node); + } + + MemPool* pool = node->_memPool; + node->~XMLNode(); + pool->Free( node ); +} + +void XMLNode::InsertChildPreamble( XMLNode* insertThis ) const +{ + TIXMLASSERT( insertThis ); + TIXMLASSERT( insertThis->_document == _document ); + + if (insertThis->_parent) { + insertThis->_parent->Unlink( insertThis ); + } + else { + insertThis->_document->MarkInUse(insertThis); + insertThis->_memPool->SetTracked(); + } +} + +const XMLElement* XMLNode::ToElementWithName( const char* name ) const +{ + const XMLElement* element = this->ToElement(); + if ( element == 0 ) { + return 0; + } + if ( name == 0 ) { + return element; + } + if ( XMLUtil::StringEqual( element->Name(), name ) ) { + return element; + } + return 0; +} + +// --------- XMLText ---------- // +char* XMLText::ParseDeep( char* p, StrPair*, int* curLineNumPtr ) +{ + if ( this->CData() ) { + p = _value.ParseText( p, "]]>", StrPair::NEEDS_NEWLINE_NORMALIZATION, curLineNumPtr ); + if ( !p ) { + _document->SetError( XML_ERROR_PARSING_CDATA, _parseLineNum, 0 ); + } + return p; + } + else { + int flags = _document->ProcessEntities() ? StrPair::TEXT_ELEMENT : StrPair::TEXT_ELEMENT_LEAVE_ENTITIES; + if ( _document->WhitespaceMode() == COLLAPSE_WHITESPACE ) { + flags |= StrPair::NEEDS_WHITESPACE_COLLAPSING; + } + + p = _value.ParseText( p, "<", flags, curLineNumPtr ); + if ( p && *p ) { + return p-1; + } + if ( !p ) { + _document->SetError( XML_ERROR_PARSING_TEXT, _parseLineNum, 0 ); + } + } + return 0; +} + + +XMLNode* XMLText::ShallowClone( XMLDocument* doc ) const +{ + if ( !doc ) { + doc = _document; + } + XMLText* text = doc->NewText( Value() ); // fixme: this will always allocate memory. Intern? + text->SetCData( this->CData() ); + return text; +} + + +bool XMLText::ShallowEqual( const XMLNode* compare ) const +{ + TIXMLASSERT( compare ); + const XMLText* text = compare->ToText(); + return ( text && XMLUtil::StringEqual( text->Value(), Value() ) ); +} + + +bool XMLText::Accept( XMLVisitor* visitor ) const +{ + TIXMLASSERT( visitor ); + return visitor->Visit( *this ); +} + + +// --------- XMLComment ---------- // + +XMLComment::XMLComment( XMLDocument* doc ) : XMLNode( doc ) +{ +} + + +XMLComment::~XMLComment() +{ +} + + +char* XMLComment::ParseDeep( char* p, StrPair*, int* curLineNumPtr ) +{ + // Comment parses as text. + p = _value.ParseText( p, "-->", StrPair::COMMENT, curLineNumPtr ); + if ( p == 0 ) { + _document->SetError( XML_ERROR_PARSING_COMMENT, _parseLineNum, 0 ); + } + return p; +} + + +XMLNode* XMLComment::ShallowClone( XMLDocument* doc ) const +{ + if ( !doc ) { + doc = _document; + } + XMLComment* comment = doc->NewComment( Value() ); // fixme: this will always allocate memory. Intern? + return comment; +} + + +bool XMLComment::ShallowEqual( const XMLNode* compare ) const +{ + TIXMLASSERT( compare ); + const XMLComment* comment = compare->ToComment(); + return ( comment && XMLUtil::StringEqual( comment->Value(), Value() )); +} + + +bool XMLComment::Accept( XMLVisitor* visitor ) const +{ + TIXMLASSERT( visitor ); + return visitor->Visit( *this ); +} + + +// --------- XMLDeclaration ---------- // + +XMLDeclaration::XMLDeclaration( XMLDocument* doc ) : XMLNode( doc ) +{ +} + + +XMLDeclaration::~XMLDeclaration() +{ + //printf( "~XMLDeclaration\n" ); +} + + +char* XMLDeclaration::ParseDeep( char* p, StrPair*, int* curLineNumPtr ) +{ + // Declaration parses as text. + p = _value.ParseText( p, "?>", StrPair::NEEDS_NEWLINE_NORMALIZATION, curLineNumPtr ); + if ( p == 0 ) { + _document->SetError( XML_ERROR_PARSING_DECLARATION, _parseLineNum, 0 ); + } + return p; +} + + +XMLNode* XMLDeclaration::ShallowClone( XMLDocument* doc ) const +{ + if ( !doc ) { + doc = _document; + } + XMLDeclaration* dec = doc->NewDeclaration( Value() ); // fixme: this will always allocate memory. Intern? + return dec; +} + + +bool XMLDeclaration::ShallowEqual( const XMLNode* compare ) const +{ + TIXMLASSERT( compare ); + const XMLDeclaration* declaration = compare->ToDeclaration(); + return ( declaration && XMLUtil::StringEqual( declaration->Value(), Value() )); +} + + + +bool XMLDeclaration::Accept( XMLVisitor* visitor ) const +{ + TIXMLASSERT( visitor ); + return visitor->Visit( *this ); +} + +// --------- XMLUnknown ---------- // + +XMLUnknown::XMLUnknown( XMLDocument* doc ) : XMLNode( doc ) +{ +} + + +XMLUnknown::~XMLUnknown() +{ +} + + +char* XMLUnknown::ParseDeep( char* p, StrPair*, int* curLineNumPtr ) +{ + // Unknown parses as text. + p = _value.ParseText( p, ">", StrPair::NEEDS_NEWLINE_NORMALIZATION, curLineNumPtr ); + if ( !p ) { + _document->SetError( XML_ERROR_PARSING_UNKNOWN, _parseLineNum, 0 ); + } + return p; +} + + +XMLNode* XMLUnknown::ShallowClone( XMLDocument* doc ) const +{ + if ( !doc ) { + doc = _document; + } + XMLUnknown* text = doc->NewUnknown( Value() ); // fixme: this will always allocate memory. Intern? + return text; +} + + +bool XMLUnknown::ShallowEqual( const XMLNode* compare ) const +{ + TIXMLASSERT( compare ); + const XMLUnknown* unknown = compare->ToUnknown(); + return ( unknown && XMLUtil::StringEqual( unknown->Value(), Value() )); +} + + +bool XMLUnknown::Accept( XMLVisitor* visitor ) const +{ + TIXMLASSERT( visitor ); + return visitor->Visit( *this ); +} + +// --------- XMLAttribute ---------- // + +const char* XMLAttribute::Name() const +{ + return _name.GetStr(); +} + +const char* XMLAttribute::Value() const +{ + return _value.GetStr(); +} + +char* XMLAttribute::ParseDeep( char* p, bool processEntities, int* curLineNumPtr ) +{ + // Parse using the name rules: bug fix, was using ParseText before + p = _name.ParseName( p ); + if ( !p || !*p ) { + return 0; + } + + // Skip white space before = + p = XMLUtil::SkipWhiteSpace( p, curLineNumPtr ); + if ( *p != '=' ) { + return 0; + } + + ++p; // move up to opening quote + p = XMLUtil::SkipWhiteSpace( p, curLineNumPtr ); + if ( *p != '\"' && *p != '\'' ) { + return 0; + } + + char endTag[2] = { *p, 0 }; + ++p; // move past opening quote + + p = _value.ParseText( p, endTag, processEntities ? StrPair::ATTRIBUTE_VALUE : StrPair::ATTRIBUTE_VALUE_LEAVE_ENTITIES, curLineNumPtr ); + return p; +} + + +void XMLAttribute::SetName( const char* n ) +{ + _name.SetStr( n ); +} + + +XMLError XMLAttribute::QueryIntValue( int* value ) const +{ + if ( XMLUtil::ToInt( Value(), value )) { + return XML_SUCCESS; + } + return XML_WRONG_ATTRIBUTE_TYPE; +} + + +XMLError XMLAttribute::QueryUnsignedValue( unsigned int* value ) const +{ + if ( XMLUtil::ToUnsigned( Value(), value )) { + return XML_SUCCESS; + } + return XML_WRONG_ATTRIBUTE_TYPE; +} + + +XMLError XMLAttribute::QueryInt64Value(int64_t* value) const +{ + if (XMLUtil::ToInt64(Value(), value)) { + return XML_SUCCESS; + } + return XML_WRONG_ATTRIBUTE_TYPE; +} + + +XMLError XMLAttribute::QueryBoolValue( bool* value ) const +{ + if ( XMLUtil::ToBool( Value(), value )) { + return XML_SUCCESS; + } + return XML_WRONG_ATTRIBUTE_TYPE; +} + + +XMLError XMLAttribute::QueryFloatValue( float* value ) const +{ + if ( XMLUtil::ToFloat( Value(), value )) { + return XML_SUCCESS; + } + return XML_WRONG_ATTRIBUTE_TYPE; +} + + +XMLError XMLAttribute::QueryDoubleValue( double* value ) const +{ + if ( XMLUtil::ToDouble( Value(), value )) { + return XML_SUCCESS; + } + return XML_WRONG_ATTRIBUTE_TYPE; +} + + +void XMLAttribute::SetAttribute( const char* v ) +{ + _value.SetStr( v ); +} + + +void XMLAttribute::SetAttribute( int v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + _value.SetStr( buf ); +} + + +void XMLAttribute::SetAttribute( unsigned v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + _value.SetStr( buf ); +} + + +void XMLAttribute::SetAttribute(int64_t v) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr(v, buf, BUF_SIZE); + _value.SetStr(buf); +} + + + +void XMLAttribute::SetAttribute( bool v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + _value.SetStr( buf ); +} + +void XMLAttribute::SetAttribute( double v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + _value.SetStr( buf ); +} + +void XMLAttribute::SetAttribute( float v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + _value.SetStr( buf ); +} + + +// --------- XMLElement ---------- // +XMLElement::XMLElement( XMLDocument* doc ) : XMLNode( doc ), + _closingType( OPEN ), + _rootAttribute( 0 ) +{ +} + + +XMLElement::~XMLElement() +{ + while( _rootAttribute ) { + XMLAttribute* next = _rootAttribute->_next; + DeleteAttribute( _rootAttribute ); + _rootAttribute = next; + } +} + + +const XMLAttribute* XMLElement::FindAttribute( const char* name ) const +{ + for( XMLAttribute* a = _rootAttribute; a; a = a->_next ) { + if ( XMLUtil::StringEqual( a->Name(), name ) ) { + return a; + } + } + return 0; +} + + +const char* XMLElement::Attribute( const char* name, const char* value ) const +{ + const XMLAttribute* a = FindAttribute( name ); + if ( !a ) { + return 0; + } + if ( !value || XMLUtil::StringEqual( a->Value(), value )) { + return a->Value(); + } + return 0; +} + +int XMLElement::IntAttribute(const char* name, int defaultValue) const +{ + int i = defaultValue; + QueryIntAttribute(name, &i); + return i; +} + +unsigned XMLElement::UnsignedAttribute(const char* name, unsigned defaultValue) const +{ + unsigned i = defaultValue; + QueryUnsignedAttribute(name, &i); + return i; +} + +int64_t XMLElement::Int64Attribute(const char* name, int64_t defaultValue) const +{ + int64_t i = defaultValue; + QueryInt64Attribute(name, &i); + return i; +} + +bool XMLElement::BoolAttribute(const char* name, bool defaultValue) const +{ + bool b = defaultValue; + QueryBoolAttribute(name, &b); + return b; +} + +double XMLElement::DoubleAttribute(const char* name, double defaultValue) const +{ + double d = defaultValue; + QueryDoubleAttribute(name, &d); + return d; +} + +float XMLElement::FloatAttribute(const char* name, float defaultValue) const +{ + float f = defaultValue; + QueryFloatAttribute(name, &f); + return f; +} + +const char* XMLElement::GetText() const +{ + if ( FirstChild() && FirstChild()->ToText() ) { + return FirstChild()->Value(); + } + return 0; +} + + +void XMLElement::SetText( const char* inText ) +{ + if ( FirstChild() && FirstChild()->ToText() ) + FirstChild()->SetValue( inText ); + else { + XMLText* theText = GetDocument()->NewText( inText ); + InsertFirstChild( theText ); + } +} + + +void XMLElement::SetText( int v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + SetText( buf ); +} + + +void XMLElement::SetText( unsigned v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + SetText( buf ); +} + + +void XMLElement::SetText(int64_t v) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr(v, buf, BUF_SIZE); + SetText(buf); +} + + +void XMLElement::SetText( bool v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + SetText( buf ); +} + + +void XMLElement::SetText( float v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + SetText( buf ); +} + + +void XMLElement::SetText( double v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + SetText( buf ); +} + + +XMLError XMLElement::QueryIntText( int* ival ) const +{ + if ( FirstChild() && FirstChild()->ToText() ) { + const char* t = FirstChild()->Value(); + if ( XMLUtil::ToInt( t, ival ) ) { + return XML_SUCCESS; + } + return XML_CAN_NOT_CONVERT_TEXT; + } + return XML_NO_TEXT_NODE; +} + + +XMLError XMLElement::QueryUnsignedText( unsigned* uval ) const +{ + if ( FirstChild() && FirstChild()->ToText() ) { + const char* t = FirstChild()->Value(); + if ( XMLUtil::ToUnsigned( t, uval ) ) { + return XML_SUCCESS; + } + return XML_CAN_NOT_CONVERT_TEXT; + } + return XML_NO_TEXT_NODE; +} + + +XMLError XMLElement::QueryInt64Text(int64_t* ival) const +{ + if (FirstChild() && FirstChild()->ToText()) { + const char* t = FirstChild()->Value(); + if (XMLUtil::ToInt64(t, ival)) { + return XML_SUCCESS; + } + return XML_CAN_NOT_CONVERT_TEXT; + } + return XML_NO_TEXT_NODE; +} + + +XMLError XMLElement::QueryBoolText( bool* bval ) const +{ + if ( FirstChild() && FirstChild()->ToText() ) { + const char* t = FirstChild()->Value(); + if ( XMLUtil::ToBool( t, bval ) ) { + return XML_SUCCESS; + } + return XML_CAN_NOT_CONVERT_TEXT; + } + return XML_NO_TEXT_NODE; +} + + +XMLError XMLElement::QueryDoubleText( double* dval ) const +{ + if ( FirstChild() && FirstChild()->ToText() ) { + const char* t = FirstChild()->Value(); + if ( XMLUtil::ToDouble( t, dval ) ) { + return XML_SUCCESS; + } + return XML_CAN_NOT_CONVERT_TEXT; + } + return XML_NO_TEXT_NODE; +} + + +XMLError XMLElement::QueryFloatText( float* fval ) const +{ + if ( FirstChild() && FirstChild()->ToText() ) { + const char* t = FirstChild()->Value(); + if ( XMLUtil::ToFloat( t, fval ) ) { + return XML_SUCCESS; + } + return XML_CAN_NOT_CONVERT_TEXT; + } + return XML_NO_TEXT_NODE; +} + +int XMLElement::IntText(int defaultValue) const +{ + int i = defaultValue; + QueryIntText(&i); + return i; +} + +unsigned XMLElement::UnsignedText(unsigned defaultValue) const +{ + unsigned i = defaultValue; + QueryUnsignedText(&i); + return i; +} + +int64_t XMLElement::Int64Text(int64_t defaultValue) const +{ + int64_t i = defaultValue; + QueryInt64Text(&i); + return i; +} + +bool XMLElement::BoolText(bool defaultValue) const +{ + bool b = defaultValue; + QueryBoolText(&b); + return b; +} + +double XMLElement::DoubleText(double defaultValue) const +{ + double d = defaultValue; + QueryDoubleText(&d); + return d; +} + +float XMLElement::FloatText(float defaultValue) const +{ + float f = defaultValue; + QueryFloatText(&f); + return f; +} + + +XMLAttribute* XMLElement::FindOrCreateAttribute( const char* name ) +{ + XMLAttribute* last = 0; + XMLAttribute* attrib = 0; + for( attrib = _rootAttribute; + attrib; + last = attrib, attrib = attrib->_next ) { + if ( XMLUtil::StringEqual( attrib->Name(), name ) ) { + break; + } + } + if ( !attrib ) { + attrib = CreateAttribute(); + TIXMLASSERT( attrib ); + if ( last ) { + TIXMLASSERT( last->_next == 0 ); + last->_next = attrib; + } + else { + TIXMLASSERT( _rootAttribute == 0 ); + _rootAttribute = attrib; + } + attrib->SetName( name ); + } + return attrib; +} + + +void XMLElement::DeleteAttribute( const char* name ) +{ + XMLAttribute* prev = 0; + for( XMLAttribute* a=_rootAttribute; a; a=a->_next ) { + if ( XMLUtil::StringEqual( name, a->Name() ) ) { + if ( prev ) { + prev->_next = a->_next; + } + else { + _rootAttribute = a->_next; + } + DeleteAttribute( a ); + break; + } + prev = a; + } +} + + +char* XMLElement::ParseAttributes( char* p, int* curLineNumPtr ) +{ + XMLAttribute* prevAttribute = 0; + + // Read the attributes. + while( p ) { + p = XMLUtil::SkipWhiteSpace( p, curLineNumPtr ); + if ( !(*p) ) { + _document->SetError( XML_ERROR_PARSING_ELEMENT, _parseLineNum, "XMLElement name=%s", Name() ); + return 0; + } + + // attribute. + if (XMLUtil::IsNameStartChar( *p ) ) { + XMLAttribute* attrib = CreateAttribute(); + TIXMLASSERT( attrib ); + attrib->_parseLineNum = _document->_parseCurLineNum; + + int attrLineNum = attrib->_parseLineNum; + + p = attrib->ParseDeep( p, _document->ProcessEntities(), curLineNumPtr ); + if ( !p || Attribute( attrib->Name() ) ) { + DeleteAttribute( attrib ); + _document->SetError( XML_ERROR_PARSING_ATTRIBUTE, attrLineNum, "XMLElement name=%s", Name() ); + return 0; + } + // There is a minor bug here: if the attribute in the source xml + // document is duplicated, it will not be detected and the + // attribute will be doubly added. However, tracking the 'prevAttribute' + // avoids re-scanning the attribute list. Preferring performance for + // now, may reconsider in the future. + if ( prevAttribute ) { + TIXMLASSERT( prevAttribute->_next == 0 ); + prevAttribute->_next = attrib; + } + else { + TIXMLASSERT( _rootAttribute == 0 ); + _rootAttribute = attrib; + } + prevAttribute = attrib; + } + // end of the tag + else if ( *p == '>' ) { + ++p; + break; + } + // end of the tag + else if ( *p == '/' && *(p+1) == '>' ) { + _closingType = CLOSED; + return p+2; // done; sealed element. + } + else { + _document->SetError( XML_ERROR_PARSING_ELEMENT, _parseLineNum, 0 ); + return 0; + } + } + return p; +} + +void XMLElement::DeleteAttribute( XMLAttribute* attribute ) +{ + if ( attribute == 0 ) { + return; + } + MemPool* pool = attribute->_memPool; + attribute->~XMLAttribute(); + pool->Free( attribute ); +} + +XMLAttribute* XMLElement::CreateAttribute() +{ + TIXMLASSERT( sizeof( XMLAttribute ) == _document->_attributePool.ItemSize() ); + XMLAttribute* attrib = new (_document->_attributePool.Alloc() ) XMLAttribute(); + TIXMLASSERT( attrib ); + attrib->_memPool = &_document->_attributePool; + attrib->_memPool->SetTracked(); + return attrib; +} + +// +// <ele></ele> +// <ele>foo<b>bar</b></ele> +// +char* XMLElement::ParseDeep( char* p, StrPair* parentEndTag, int* curLineNumPtr ) +{ + // Read the element name. + p = XMLUtil::SkipWhiteSpace( p, curLineNumPtr ); + + // The closing element is the </element> form. It is + // parsed just like a regular element then deleted from + // the DOM. + if ( *p == '/' ) { + _closingType = CLOSING; + ++p; + } + + p = _value.ParseName( p ); + if ( _value.Empty() ) { + return 0; + } + + p = ParseAttributes( p, curLineNumPtr ); + if ( !p || !*p || _closingType != OPEN ) { + return p; + } + + p = XMLNode::ParseDeep( p, parentEndTag, curLineNumPtr ); + return p; +} + + + +XMLNode* XMLElement::ShallowClone( XMLDocument* doc ) const +{ + if ( !doc ) { + doc = _document; + } + XMLElement* element = doc->NewElement( Value() ); // fixme: this will always allocate memory. Intern? + for( const XMLAttribute* a=FirstAttribute(); a; a=a->Next() ) { + element->SetAttribute( a->Name(), a->Value() ); // fixme: this will always allocate memory. Intern? + } + return element; +} + + +bool XMLElement::ShallowEqual( const XMLNode* compare ) const +{ + TIXMLASSERT( compare ); + const XMLElement* other = compare->ToElement(); + if ( other && XMLUtil::StringEqual( other->Name(), Name() )) { + + const XMLAttribute* a=FirstAttribute(); + const XMLAttribute* b=other->FirstAttribute(); + + while ( a && b ) { + if ( !XMLUtil::StringEqual( a->Value(), b->Value() ) ) { + return false; + } + a = a->Next(); + b = b->Next(); + } + if ( a || b ) { + // different count + return false; + } + return true; + } + return false; +} + + +bool XMLElement::Accept( XMLVisitor* visitor ) const +{ + TIXMLASSERT( visitor ); + if ( visitor->VisitEnter( *this, _rootAttribute ) ) { + for ( const XMLNode* node=FirstChild(); node; node=node->NextSibling() ) { + if ( !node->Accept( visitor ) ) { + break; + } + } + } + return visitor->VisitExit( *this ); +} + + +// --------- XMLDocument ----------- // + +// Warning: List must match 'enum XMLError' +const char* XMLDocument::_errorNames[XML_ERROR_COUNT] = { + "XML_SUCCESS", + "XML_NO_ATTRIBUTE", + "XML_WRONG_ATTRIBUTE_TYPE", + "XML_ERROR_FILE_NOT_FOUND", + "XML_ERROR_FILE_COULD_NOT_BE_OPENED", + "XML_ERROR_FILE_READ_ERROR", + "UNUSED_XML_ERROR_ELEMENT_MISMATCH", + "XML_ERROR_PARSING_ELEMENT", + "XML_ERROR_PARSING_ATTRIBUTE", + "UNUSED_XML_ERROR_IDENTIFYING_TAG", + "XML_ERROR_PARSING_TEXT", + "XML_ERROR_PARSING_CDATA", + "XML_ERROR_PARSING_COMMENT", + "XML_ERROR_PARSING_DECLARATION", + "XML_ERROR_PARSING_UNKNOWN", + "XML_ERROR_EMPTY_DOCUMENT", + "XML_ERROR_MISMATCHED_ELEMENT", + "XML_ERROR_PARSING", + "XML_CAN_NOT_CONVERT_TEXT", + "XML_NO_TEXT_NODE" +}; + + +XMLDocument::XMLDocument( bool processEntities, Whitespace whitespaceMode ) : + XMLNode( 0 ), + _writeBOM( false ), + _processEntities( processEntities ), + _errorID(XML_SUCCESS), + _whitespaceMode( whitespaceMode ), + _errorStr(), + _errorLineNum( 0 ), + _charBuffer( 0 ), + _parseCurLineNum( 0 ), + _unlinked(), + _elementPool(), + _attributePool(), + _textPool(), + _commentPool() +{ + // avoid VC++ C4355 warning about 'this' in initializer list (C4355 is off by default in VS2012+) + _document = this; +} + + +XMLDocument::~XMLDocument() +{ + Clear(); +} + + +void XMLDocument::MarkInUse(XMLNode* node) +{ + TIXMLASSERT(node); + TIXMLASSERT(node->_parent == 0); + + for (int i = 0; i < _unlinked.Size(); ++i) { + if (node == _unlinked[i]) { + _unlinked.SwapRemove(i); + break; + } + } +} + +void XMLDocument::Clear() +{ + DeleteChildren(); + while( _unlinked.Size()) { + DeleteNode(_unlinked[0]); // Will remove from _unlinked as part of delete. + } + +#ifdef DEBUG + const bool hadError = Error(); +#endif + ClearError(); + + Aws::DeleteArray(_charBuffer); + _charBuffer = 0; + +#if 0 + _textPool.Trace( "text" ); + _elementPool.Trace( "element" ); + _commentPool.Trace( "comment" ); + _attributePool.Trace( "attribute" ); +#endif + +#ifdef DEBUG + if ( !hadError ) { + TIXMLASSERT( _elementPool.CurrentAllocs() == _elementPool.Untracked() ); + TIXMLASSERT( _attributePool.CurrentAllocs() == _attributePool.Untracked() ); + TIXMLASSERT( _textPool.CurrentAllocs() == _textPool.Untracked() ); + TIXMLASSERT( _commentPool.CurrentAllocs() == _commentPool.Untracked() ); + } +#endif +} + + +void XMLDocument::DeepCopy(XMLDocument* target) const +{ + TIXMLASSERT(target); + if (target == this) { + return; // technically success - a no-op. + } + + target->Clear(); + for (const XMLNode* node = this->FirstChild(); node; node = node->NextSibling()) { + target->InsertEndChild(node->DeepClone(target)); + } +} + +XMLElement* XMLDocument::NewElement( const char* name ) +{ + XMLElement* ele = CreateUnlinkedNode<XMLElement>( _elementPool ); + ele->SetName( name ); + return ele; +} + + +XMLComment* XMLDocument::NewComment( const char* str ) +{ + XMLComment* comment = CreateUnlinkedNode<XMLComment>( _commentPool ); + comment->SetValue( str ); + return comment; +} + + +XMLText* XMLDocument::NewText( const char* str ) +{ + XMLText* text = CreateUnlinkedNode<XMLText>( _textPool ); + text->SetValue( str ); + return text; +} + + +XMLDeclaration* XMLDocument::NewDeclaration( const char* str ) +{ + XMLDeclaration* dec = CreateUnlinkedNode<XMLDeclaration>( _commentPool ); + dec->SetValue( str ? str : "xml version=\"1.0\" encoding=\"UTF-8\"" ); + return dec; +} + + +XMLUnknown* XMLDocument::NewUnknown( const char* str ) +{ + XMLUnknown* unk = CreateUnlinkedNode<XMLUnknown>( _commentPool ); + unk->SetValue( str ); + return unk; +} + +static FILE* callfopen( const char* filepath, const char* mode ) +{ + TIXMLASSERT( filepath ); + TIXMLASSERT( mode ); +#if defined(_MSC_VER) && (_MSC_VER >= 1400 ) && (!defined WINCE) + FILE* fp = 0; + errno_t err = fopen_s( &fp, filepath, mode ); + if ( err ) { + return 0; + } +#else + FILE* fp = fopen( filepath, mode ); +#endif + return fp; +} + +void XMLDocument::DeleteNode( XMLNode* node ) { + TIXMLASSERT( node ); + TIXMLASSERT(node->_document == this ); + if (node->_parent) { + node->_parent->DeleteChild( node ); + } + else { + // Isn't in the tree. + // Use the parent delete. + // Also, we need to mark it tracked: we 'know' + // it was never used. + node->_memPool->SetTracked(); + // Call the static XMLNode version: + XMLNode::DeleteNode(node); + } +} + + +XMLError XMLDocument::LoadFile( const char* filename ) +{ + Clear(); + FILE* fp = callfopen( filename, "rb" ); + if ( !fp ) { + SetError( XML_ERROR_FILE_NOT_FOUND, 0, "filename=%s", filename ? filename : "<null>"); + return _errorID; + } + LoadFile( fp ); + fclose( fp ); + return _errorID; +} + +// This is likely overengineered template art to have a check that unsigned long value incremented +// by one still fits into size_t. If size_t type is larger than unsigned long type +// (x86_64-w64-mingw32 target) then the check is redundant and gcc and clang emit +// -Wtype-limits warning. This piece makes the compiler select code with a check when a check +// is useful and code with no check when a check is redundant depending on how size_t and unsigned long +// types sizes relate to each other. +template +<bool = (sizeof(unsigned long) >= sizeof(size_t))> +struct LongFitsIntoSizeTMinusOne { + static bool Fits( unsigned long value ) + { + return value < (size_t)-1; + } +}; + +template <> +struct LongFitsIntoSizeTMinusOne<false> { + static bool Fits( unsigned long ) + { + return true; + } +}; + +XMLError XMLDocument::LoadFile( FILE* fp ) +{ + Clear(); + + fseek( fp, 0, SEEK_SET ); + if ( fgetc( fp ) == EOF && ferror( fp ) != 0 ) { + SetError( XML_ERROR_FILE_READ_ERROR, 0, 0 ); + return _errorID; + } + + fseek( fp, 0, SEEK_END ); + const long filelength = ftell( fp ); + fseek( fp, 0, SEEK_SET ); + if ( filelength == -1L ) { + SetError( XML_ERROR_FILE_READ_ERROR, 0, 0 ); + return _errorID; + } + TIXMLASSERT( filelength >= 0 ); + + if ( !LongFitsIntoSizeTMinusOne<>::Fits( filelength ) ) { + // Cannot handle files which won't fit in buffer together with null terminator + SetError( XML_ERROR_FILE_READ_ERROR, 0, 0 ); + return _errorID; + } + + if ( filelength == 0 ) { + SetError( XML_ERROR_EMPTY_DOCUMENT, 0, 0 ); + return _errorID; + } + + const size_t size = filelength; + TIXMLASSERT( _charBuffer == 0 ); + _charBuffer = Aws::NewArray <char>(size+1, ALLOCATION_TAG); + size_t read = fread( _charBuffer, 1, size, fp ); + if ( read != size ) { + SetError( XML_ERROR_FILE_READ_ERROR, 0, 0 ); + return _errorID; + } + + _charBuffer[size] = 0; + + Parse(); + return _errorID; +} + + +XMLError XMLDocument::SaveFile( const char* filename, bool compact ) +{ + FILE* fp = callfopen( filename, "w" ); + if ( !fp ) { + SetError( XML_ERROR_FILE_COULD_NOT_BE_OPENED, 0, "filename=%s", filename ? filename : "<null>"); + return _errorID; + } + SaveFile(fp, compact); + fclose( fp ); + return _errorID; +} + + +XMLError XMLDocument::SaveFile( FILE* fp, bool compact ) +{ + // Clear any error from the last save, otherwise it will get reported + // for *this* call. + ClearError(); + XMLPrinter stream( fp, compact ); + Print( &stream ); + return _errorID; +} + + +XMLError XMLDocument::Parse( const char* p, size_t len ) +{ + Clear(); + + if ( len == 0 || !p || !*p ) { + SetError( XML_ERROR_EMPTY_DOCUMENT, 0, 0 ); + return _errorID; + } + if ( len == (size_t)(-1) ) { + len = strlen( p ); + } + TIXMLASSERT( _charBuffer == 0 ); + _charBuffer = Aws::NewArray<char>(len+1, ALLOCATION_TAG); + memcpy( _charBuffer, p, len ); + _charBuffer[len] = 0; + + Parse(); + if ( Error() ) { + // clean up now essentially dangling memory. + // and the parse fail can put objects in the + // pools that are dead and inaccessible. + DeleteChildren(); + _elementPool.Clear(); + _attributePool.Clear(); + _textPool.Clear(); + _commentPool.Clear(); + } + return _errorID; +} + + +void XMLDocument::Print( XMLPrinter* streamer ) const +{ + if ( streamer ) { + Accept( streamer ); + } + else { + XMLPrinter stdoutStreamer( stdout ); + Accept( &stdoutStreamer ); + } +} + + +void XMLDocument::SetError( XMLError error, int lineNum, const char* format, ... ) +{ + TIXMLASSERT( error >= 0 && error < XML_ERROR_COUNT ); + _errorID = error; + _errorLineNum = lineNum; + _errorStr.Reset(); + + if (format) { + size_t BUFFER_SIZE = 1000; + char* buffer = Aws::NewArray<char>(BUFFER_SIZE, ALLOCATION_TAG); + TIXML_SNPRINTF(buffer, BUFFER_SIZE, "Error=%s ErrorID=%d (0x%x) Line number=%d: ", ErrorIDToName(error), int(error), int(error), lineNum); + size_t len = strlen(buffer); + + va_list va; + va_start( va, format ); + TIXML_VSNPRINTF( buffer + len, BUFFER_SIZE - len, format, va ); + va_end( va ); + + _errorStr.SetStr(buffer); + Aws::DeleteArray(buffer); + } +} + + +/*static*/ const char* XMLDocument::ErrorIDToName(XMLError errorID) +{ + TIXMLASSERT( errorID >= 0 && errorID < XML_ERROR_COUNT ); + const char* errorName = _errorNames[errorID]; + TIXMLASSERT( errorName && errorName[0] ); + return errorName; +} + +const char* XMLDocument::ErrorStr() const +{ + return _errorStr.Empty() ? "" : _errorStr.GetStr(); +} + + +void XMLDocument::PrintError() const +{ + printf("%s\n", ErrorStr()); +} + +const char* XMLDocument::ErrorName() const +{ + return ErrorIDToName(_errorID); +} + +void XMLDocument::Parse() +{ + TIXMLASSERT( NoChildren() ); // Clear() must have been called previously + TIXMLASSERT( _charBuffer ); + _parseCurLineNum = 1; + _parseLineNum = 1; + char* p = _charBuffer; + p = XMLUtil::SkipWhiteSpace( p, &_parseCurLineNum ); + p = const_cast<char*>( XMLUtil::ReadBOM( p, &_writeBOM ) ); + if ( !*p ) { + SetError( XML_ERROR_EMPTY_DOCUMENT, 0, 0 ); + return; + } + ParseDeep(p, 0, &_parseCurLineNum ); +} + +XMLPrinter::XMLPrinter( FILE* file, bool compact, int depth ) : + _elementJustOpened( false ), + _stack(), + _firstElement( true ), + _fp( file ), + _depth( depth ), + _textDepth( -1 ), + _processEntities( true ), + _compactMode( compact ), + _buffer() +{ + for( int i=0; i<ENTITY_RANGE; ++i ) { + _entityFlag[i] = false; + _restrictedEntityFlag[i] = false; + } + for( int i=0; i<NUM_ENTITIES; ++i ) { + const char entityValue = entities[i].value; + const unsigned char flagIndex = (unsigned char)entityValue; + TIXMLASSERT( flagIndex < ENTITY_RANGE ); + _entityFlag[flagIndex] = true; + } + _restrictedEntityFlag[(unsigned char)'&'] = true; + _restrictedEntityFlag[(unsigned char)'<'] = true; + _restrictedEntityFlag[(unsigned char)'>'] = true; // not required, but consistency is nice + _buffer.Push( 0 ); +} + + +void XMLPrinter::Print( const char* format, ... ) +{ + va_list va; + va_start( va, format ); + + if ( _fp ) { + vfprintf( _fp, format, va ); + } + else { + const int len = TIXML_VSCPRINTF( format, va ); + // Close out and re-start the va-args + va_end( va ); + TIXMLASSERT( len >= 0 ); + va_start( va, format ); + TIXMLASSERT( _buffer.Size() > 0 && _buffer[_buffer.Size() - 1] == 0 ); + char* p = _buffer.PushArr( len ) - 1; // back up over the null terminator. + TIXML_VSNPRINTF( p, len+1, format, va ); + } + va_end( va ); +} + + +void XMLPrinter::Write( const char* data, size_t size ) +{ + if ( _fp ) { + fwrite ( data , sizeof(char), size, _fp); + } + else { + char* p = _buffer.PushArr( static_cast<int>(size) ) - 1; // back up over the null terminator. + memcpy( p, data, size ); + p[size] = 0; + } +} + + +void XMLPrinter::Putc( char ch ) +{ + if ( _fp ) { + fputc ( ch, _fp); + } + else { + char* p = _buffer.PushArr( sizeof(char) ) - 1; // back up over the null terminator. + p[0] = ch; + p[1] = 0; + } +} + + +void XMLPrinter::PrintSpace( int depth ) +{ + for( int i=0; i<depth; ++i ) { + Write( " " ); + } +} + + +void XMLPrinter::PrintString( const char* p, bool restricted ) +{ + // Look for runs of bytes between entities to print. + const char* q = p; + + if ( _processEntities ) { + const bool* flag = restricted ? _restrictedEntityFlag : _entityFlag; + while ( *q ) { + TIXMLASSERT( p <= q ); + // Remember, char is sometimes signed. (How many times has that bitten me?) + if ( *q > 0 && *q < ENTITY_RANGE ) { + // Check for entities. If one is found, flush + // the stream up until the entity, write the + // entity, and keep looking. + if ( flag[(unsigned char)(*q)] ) { + while ( p < q ) { + const size_t delta = q - p; + const int toPrint = ( INT_MAX < delta ) ? INT_MAX : (int)delta; + Write( p, toPrint ); + p += toPrint; + } + bool entityPatternPrinted = false; + for( int i=0; i<NUM_ENTITIES; ++i ) { + if ( entities[i].value == *q ) { + Putc( '&' ); + Write( entities[i].pattern, entities[i].length ); + Putc( ';' ); + entityPatternPrinted = true; + break; + } + } + if ( !entityPatternPrinted ) { + // TIXMLASSERT( entityPatternPrinted ) causes gcc -Wunused-but-set-variable in release + TIXMLASSERT( false ); + } + ++p; + } + } + ++q; + TIXMLASSERT( p <= q ); + } + } + // Flush the remaining string. This will be the entire + // string if an entity wasn't found. + TIXMLASSERT( p <= q ); + if ( !_processEntities || ( p < q ) ) { + const size_t delta = q - p; + const int toPrint = ( INT_MAX < delta ) ? INT_MAX : (int)delta; + Write( p, toPrint ); + } +} + + +void XMLPrinter::PushHeader( bool writeBOM, bool writeDec ) +{ + if ( writeBOM ) { + static const unsigned char bom[] = { TIXML_UTF_LEAD_0, TIXML_UTF_LEAD_1, TIXML_UTF_LEAD_2, 0 }; + Write( reinterpret_cast< const char* >( bom ) ); + } + if ( writeDec ) { + PushDeclaration( "xml version=\"1.0\"" ); + } +} + + +void XMLPrinter::OpenElement( const char* name, bool compactMode ) +{ + SealElementIfJustOpened(); + _stack.Push( name ); + + if ( _textDepth < 0 && !_firstElement && !compactMode ) { + Putc( '\n' ); + } + if ( !compactMode ) { + PrintSpace( _depth ); + } + + Write ( "<" ); + Write ( name ); + + _elementJustOpened = true; + _firstElement = false; + ++_depth; +} + + +void XMLPrinter::PushAttribute( const char* name, const char* value ) +{ + TIXMLASSERT( _elementJustOpened ); + Putc ( ' ' ); + Write( name ); + Write( "=\"" ); + PrintString( value, false ); + Putc ( '\"' ); +} + + +void XMLPrinter::PushAttribute( const char* name, int v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + PushAttribute( name, buf ); +} + + +void XMLPrinter::PushAttribute( const char* name, unsigned v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + PushAttribute( name, buf ); +} + + +void XMLPrinter::PushAttribute(const char* name, int64_t v) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr(v, buf, BUF_SIZE); + PushAttribute(name, buf); +} + + +void XMLPrinter::PushAttribute( const char* name, bool v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + PushAttribute( name, buf ); +} + + +void XMLPrinter::PushAttribute( const char* name, double v ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( v, buf, BUF_SIZE ); + PushAttribute( name, buf ); +} + + +void XMLPrinter::CloseElement( bool compactMode ) +{ + --_depth; + const char* name = _stack.Pop(); + + if ( _elementJustOpened ) { + Write( "/>" ); + } + else { + if ( _textDepth < 0 && !compactMode) { + Putc( '\n' ); + PrintSpace( _depth ); + } + Write ( "</" ); + Write ( name ); + Write ( ">" ); + } + + if ( _textDepth == _depth ) { + _textDepth = -1; + } + if ( _depth == 0 && !compactMode) { + Putc( '\n' ); + } + _elementJustOpened = false; +} + + +void XMLPrinter::SealElementIfJustOpened() +{ + if ( !_elementJustOpened ) { + return; + } + _elementJustOpened = false; + Putc( '>' ); +} + + +void XMLPrinter::PushText( const char* text, bool cdata ) +{ + _textDepth = _depth-1; + + SealElementIfJustOpened(); + if ( cdata ) { + Write( "<![CDATA[" ); + Write( text ); + Write( "]]>" ); + } + else { + PrintString( text, true ); + } +} + +void XMLPrinter::PushText( int64_t value ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( value, buf, BUF_SIZE ); + PushText( buf, false ); +} + +void XMLPrinter::PushText( int value ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( value, buf, BUF_SIZE ); + PushText( buf, false ); +} + + +void XMLPrinter::PushText( unsigned value ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( value, buf, BUF_SIZE ); + PushText( buf, false ); +} + + +void XMLPrinter::PushText( bool value ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( value, buf, BUF_SIZE ); + PushText( buf, false ); +} + + +void XMLPrinter::PushText( float value ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( value, buf, BUF_SIZE ); + PushText( buf, false ); +} + + +void XMLPrinter::PushText( double value ) +{ + char buf[BUF_SIZE]; + XMLUtil::ToStr( value, buf, BUF_SIZE ); + PushText( buf, false ); +} + + +void XMLPrinter::PushComment( const char* comment ) +{ + SealElementIfJustOpened(); + if ( _textDepth < 0 && !_firstElement && !_compactMode) { + Putc( '\n' ); + PrintSpace( _depth ); + } + _firstElement = false; + + Write( "<!--" ); + Write( comment ); + Write( "-->" ); +} + + +void XMLPrinter::PushDeclaration( const char* value ) +{ + SealElementIfJustOpened(); + if ( _textDepth < 0 && !_firstElement && !_compactMode) { + Putc( '\n' ); + PrintSpace( _depth ); + } + _firstElement = false; + + Write( "<?" ); + Write( value ); + Write( "?>" ); +} + + +void XMLPrinter::PushUnknown( const char* value ) +{ + SealElementIfJustOpened(); + if ( _textDepth < 0 && !_firstElement && !_compactMode) { + Putc( '\n' ); + PrintSpace( _depth ); + } + _firstElement = false; + + Write( "<!" ); + Write( value ); + Putc( '>' ); +} + + +bool XMLPrinter::VisitEnter( const XMLDocument& doc ) +{ + _processEntities = doc.ProcessEntities(); + if ( doc.HasBOM() ) { + PushHeader( true, false ); + } + return true; +} + + +bool XMLPrinter::VisitEnter( const XMLElement& element, const XMLAttribute* attribute ) +{ + const XMLElement* parentElem = 0; + if ( element.Parent() ) { + parentElem = element.Parent()->ToElement(); + } + const bool compactMode = parentElem ? CompactMode( *parentElem ) : _compactMode; + OpenElement( element.Name(), compactMode ); + while ( attribute ) { + PushAttribute( attribute->Name(), attribute->Value() ); + attribute = attribute->Next(); + } + return true; +} + + +bool XMLPrinter::VisitExit( const XMLElement& element ) +{ + CloseElement( CompactMode(element) ); + return true; +} + + +bool XMLPrinter::Visit( const XMLText& text ) +{ + PushText( text.Value(), text.CData() ); + return true; +} + + +bool XMLPrinter::Visit( const XMLComment& comment ) +{ + PushComment( comment.Value() ); + return true; +} + +bool XMLPrinter::Visit( const XMLDeclaration& declaration ) +{ + PushDeclaration( declaration.Value() ); + return true; +} + + +bool XMLPrinter::Visit( const XMLUnknown& unknown ) +{ + PushUnknown( unknown.Value() ); + return true; +} + +} // namespace tinyxml2 +} // namespace External +} // namespace Aws
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpClient.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpClient.cpp new file mode 100644 index 0000000000..8542023393 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpClient.cpp @@ -0,0 +1,49 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/HttpClient.h> +#include <aws/core/http/HttpRequest.h> + +using namespace Aws; +using namespace Aws::Http; + +HttpClient::HttpClient() : + m_disableRequestProcessing( false ), + m_requestProcessingSignalLock(), + m_requestProcessingSignal() +{ +} + +void HttpClient::DisableRequestProcessing() +{ + m_disableRequestProcessing = true; + m_requestProcessingSignal.notify_all(); +} + +void HttpClient::EnableRequestProcessing() +{ + m_disableRequestProcessing = false; +} + +bool HttpClient::IsRequestProcessingEnabled() const +{ + return m_disableRequestProcessing.load() == false; +} + +void HttpClient::RetryRequestSleep(std::chrono::milliseconds sleepTime) +{ + std::unique_lock< std::mutex > signalLocker(m_requestProcessingSignalLock); + m_requestProcessingSignal.wait_for(signalLocker, sleepTime, [this](){ return m_disableRequestProcessing.load() == true; }); +} + +bool HttpClient::ContinueRequest(const Aws::Http::HttpRequest& request) const +{ + if (request.GetContinueRequestHandler()) + { + return request.GetContinueRequestHandler()(&request); + } + + return true; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpClientFactory.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpClientFactory.cpp new file mode 100644 index 0000000000..a556e39a5d --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpClientFactory.cpp @@ -0,0 +1,203 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/HttpClientFactory.h> + +#if ENABLE_CURL_CLIENT +#include <aws/core/http/curl/CurlHttpClient.h> +#include <signal.h> + +#elif ENABLE_WINDOWS_CLIENT +#include <aws/core/client/ClientConfiguration.h> +#if ENABLE_WINDOWS_IXML_HTTP_REQUEST_2_CLIENT +#error #include <aws/core/http/windows/IXmlHttpRequest2HttpClient.h> +#if BYPASS_DEFAULT_PROXY +#error #include <aws/core/http/windows/WinHttpSyncHttpClient.h> +#endif +#else +#error #include <aws/core/http/windows/WinINetSyncHttpClient.h> +#error #include <aws/core/http/windows/WinHttpSyncHttpClient.h> +#endif +#endif + +#include <aws/core/http/standard/StandardHttpRequest.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <cassert> + +using namespace Aws::Client; +using namespace Aws::Http; +using namespace Aws::Utils::Logging; + +namespace Aws +{ + namespace Http + { + static std::shared_ptr<HttpClientFactory>& GetHttpClientFactory() + { + static std::shared_ptr<HttpClientFactory> s_HttpClientFactory(nullptr); + return s_HttpClientFactory; + } + static bool s_InitCleanupCurlFlag(false); + static bool s_InstallSigPipeHandler(false); + + static const char* HTTP_CLIENT_FACTORY_ALLOCATION_TAG = "HttpClientFactory"; + +#if ENABLE_CURL_CLIENT && !defined(_WIN32) + static void LogAndSwallowHandler(int signal) + { + switch(signal) + { + case SIGPIPE: + AWS_LOGSTREAM_ERROR(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, "Received a SIGPIPE error"); + break; + default: + AWS_LOGSTREAM_ERROR(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, "Unhandled system SIGNAL error" << signal); + } + } +#endif + + class DefaultHttpClientFactory : public HttpClientFactory + { + std::shared_ptr<HttpClient> CreateHttpClient(const ClientConfiguration& clientConfiguration) const override + { + // Figure out whether the selected option is available but fail gracefully and return a default of some type if not + // Windows clients: Http and Inet are always options, Curl MIGHT be an option if USE_CURL_CLIENT is on, and http is "default" + // Other clients: Curl is your default +#if ENABLE_WINDOWS_CLIENT +#if ENABLE_WINDOWS_IXML_HTTP_REQUEST_2_CLIENT +#if BYPASS_DEFAULT_PROXY + switch (clientConfiguration.httpLibOverride) + { + case TransferLibType::WIN_HTTP_CLIENT: + AWS_LOGSTREAM_INFO(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, "Creating WinHTTP http client."); + return Aws::MakeShared<WinHttpSyncHttpClient>(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, clientConfiguration); + case TransferLibType::WIN_INET_CLIENT: + AWS_LOGSTREAM_WARN(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, "WinINet http client is not supported with the current build configuration."); + // fall-through + default: + AWS_LOGSTREAM_INFO(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, "Creating IXMLHttpRequest http client."); + return Aws::MakeShared<IXmlHttpRequest2HttpClient>(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, clientConfiguration); + } +#else + return Aws::MakeShared<IXmlHttpRequest2HttpClient>(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, clientConfiguration); +#endif // BYPASS_DEFAULT_PROXY +#else + switch (clientConfiguration.httpLibOverride) + { + case TransferLibType::WIN_INET_CLIENT: + return Aws::MakeShared<WinINetSyncHttpClient>(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, clientConfiguration); + + default: + return Aws::MakeShared<WinHttpSyncHttpClient>(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, clientConfiguration); + } +#endif // ENABLE_WINDOWS_IXML_HTTP_REQUEST_2_CLIENT +#elif ENABLE_CURL_CLIENT + return Aws::MakeShared<CurlHttpClient>(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, clientConfiguration); +#else + // When neither of these clients is enabled, gcc gives a warning (converted + // to error by -Werror) about the unused clientConfiguration parameter. We + // prevent that warning with AWS_UNREFERENCED_PARAM. + AWS_UNREFERENCED_PARAM(clientConfiguration); + AWS_LOGSTREAM_WARN(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, "SDK was built without an Http implementation, default http client factory can't create an Http client instance."); + return nullptr; +#endif + } + + std::shared_ptr<HttpRequest> CreateHttpRequest(const Aws::String &uri, HttpMethod method, + const Aws::IOStreamFactory &streamFactory) const override + { + return CreateHttpRequest(URI(uri), method, streamFactory); + } + + std::shared_ptr<HttpRequest> CreateHttpRequest(const URI& uri, HttpMethod method, const Aws::IOStreamFactory& streamFactory) const override + { + auto request = Aws::MakeShared<Standard::StandardHttpRequest>(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, uri, method); + request->SetResponseStreamFactory(streamFactory); + + return request; + } + + void InitStaticState() override + { +#if ENABLE_CURL_CLIENT + if(s_InitCleanupCurlFlag) + { + CurlHttpClient::InitGlobalState(); + } +#if !defined (_WIN32) + if(s_InstallSigPipeHandler) + { + ::signal(SIGPIPE, LogAndSwallowHandler); + } +#endif +#elif ENABLE_WINDOWS_IXML_HTTP_REQUEST_2_CLIENT + IXmlHttpRequest2HttpClient::InitCOM(); +#endif + } + + virtual void CleanupStaticState() override + { +#if ENABLE_CURL_CLIENT + if(s_InitCleanupCurlFlag) + { + CurlHttpClient::CleanupGlobalState(); + } +#endif + } + }; + + void SetInitCleanupCurlFlag(bool initCleanupFlag) + { + s_InitCleanupCurlFlag = initCleanupFlag; + } + + void SetInstallSigPipeHandlerFlag(bool install) + { + s_InstallSigPipeHandler = install; + } + + void InitHttp() + { + if(!GetHttpClientFactory()) + { + GetHttpClientFactory() = Aws::MakeShared<DefaultHttpClientFactory>(HTTP_CLIENT_FACTORY_ALLOCATION_TAG); + } + GetHttpClientFactory()->InitStaticState(); + } + + void CleanupHttp() + { + if(GetHttpClientFactory()) + { + GetHttpClientFactory()->CleanupStaticState(); + GetHttpClientFactory() = nullptr; + } + } + + void SetHttpClientFactory(const std::shared_ptr<HttpClientFactory>& factory) + { + CleanupHttp(); + GetHttpClientFactory() = factory; + } + + std::shared_ptr<HttpClient> CreateHttpClient(const Aws::Client::ClientConfiguration& clientConfiguration) + { + assert(GetHttpClientFactory()); + return GetHttpClientFactory()->CreateHttpClient(clientConfiguration); + } + + std::shared_ptr<HttpRequest> CreateHttpRequest(const Aws::String& uri, HttpMethod method, const Aws::IOStreamFactory& streamFactory) + { + assert(GetHttpClientFactory()); + return GetHttpClientFactory()->CreateHttpRequest(uri, method, streamFactory); + } + + std::shared_ptr<HttpRequest> CreateHttpRequest(const URI& uri, HttpMethod method, const Aws::IOStreamFactory& streamFactory) + { + assert(GetHttpClientFactory()); + return GetHttpClientFactory()->CreateHttpRequest(uri, method, streamFactory); + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpRequest.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpRequest.cpp new file mode 100644 index 0000000000..95cb626c22 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpRequest.cpp @@ -0,0 +1,40 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/HttpRequest.h> + +namespace Aws +{ +namespace Http +{ + +const char DATE_HEADER[] = "date"; +const char AWS_DATE_HEADER[] = "X-Amz-Date"; +const char AWS_SECURITY_TOKEN[] = "X-Amz-Security-Token"; +const char ACCEPT_HEADER[] = "accept"; +const char ACCEPT_CHAR_SET_HEADER[] = "accept-charset"; +const char ACCEPT_ENCODING_HEADER[] = "accept-encoding"; +const char AUTHORIZATION_HEADER[] = "authorization"; +const char AWS_AUTHORIZATION_HEADER[] = "authorization"; +const char COOKIE_HEADER[] = "cookie"; +const char CONTENT_LENGTH_HEADER[] = "content-length"; +const char CONTENT_TYPE_HEADER[] = "content-type"; +const char TRANSFER_ENCODING_HEADER[] = "transfer-encoding"; +const char USER_AGENT_HEADER[] = "user-agent"; +const char VIA_HEADER[] = "via"; +const char HOST_HEADER[] = "host"; +const char AMZ_TARGET_HEADER[] = "x-amz-target"; +const char X_AMZ_EXPIRES_HEADER[] = "X-Amz-Expires"; +const char CONTENT_MD5_HEADER[] = "content-md5"; +const char API_VERSION_HEADER[] = "x-amz-api-version"; +const char SDK_INVOCATION_ID_HEADER[] = "amz-sdk-invocation-id"; +const char SDK_REQUEST_HEADER[] = "amz-sdk-request"; +const char CHUNKED_VALUE[] = "chunked"; + +} // Http +} // Aws + + + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpTypes.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpTypes.cpp new file mode 100644 index 0000000000..4d313e52f3 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/HttpTypes.cpp @@ -0,0 +1,42 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/HttpTypes.h> +#include <cassert> + +using namespace Aws::Http; + +namespace Aws +{ +namespace Http +{ + +namespace HttpMethodMapper +{ +const char* GetNameForHttpMethod(HttpMethod httpMethod) +{ + switch (httpMethod) + { + case HttpMethod::HTTP_GET: + return "GET"; + case HttpMethod::HTTP_POST: + return "POST"; + case HttpMethod::HTTP_DELETE: + return "DELETE"; + case HttpMethod::HTTP_PUT: + return "PUT"; + case HttpMethod::HTTP_HEAD: + return "HEAD"; + case HttpMethod::HTTP_PATCH: + return "PATCH"; + default: + assert(0); + return "GET"; + } +} + +} // namespace HttpMethodMapper +} // namespace Http +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/Scheme.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/Scheme.cpp new file mode 100644 index 0000000000..5dcea06aab --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/Scheme.cpp @@ -0,0 +1,54 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/Scheme.h> +#include <aws/core/utils/memory/stl/AWSString.h> +#include <aws/core/utils/StringUtils.h> + +using namespace Aws::Http; +using namespace Aws::Utils; + +namespace Aws +{ +namespace Http +{ +namespace SchemeMapper +{ + + const char* ToString(Scheme scheme) + { + switch (scheme) + { + case Scheme::HTTP: + return "http"; + case Scheme::HTTPS: + return "https"; + default: + return "http"; + } + } + + Scheme FromString(const char* name) + { + Aws::String trimmedString = StringUtils::Trim(name); + Aws::String loweredTrimmedString = StringUtils::ToLower(trimmedString.c_str()); + + if (loweredTrimmedString == "http") + { + return Scheme::HTTP; + } + //this branch is technically unneeded, but it is here so we don't have a subtle bug + //creep in as we extend this enum. + else if (loweredTrimmedString == "https") + { + return Scheme::HTTPS; + } + + return Scheme::HTTPS; + } + +} // namespace SchemeMapper +} // namespace Http +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/URI.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/URI.cpp new file mode 100644 index 0000000000..a2239df54b --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/URI.cpp @@ -0,0 +1,510 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/URI.h> + +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <aws/core/utils/memory/stl/AWSSet.h> + +#include <cstdlib> +#include <cctype> +#include <cassert> +#include <algorithm> +#include <iomanip> + +using namespace Aws::Http; +using namespace Aws::Utils; + +namespace Aws +{ +namespace Http +{ + +const char* SEPARATOR = "://"; + +} // namespace Http +} // namespace Aws + +URI::URI() : m_scheme(Scheme::HTTP), m_port(HTTP_DEFAULT_PORT) +{ +} + +URI::URI(const Aws::String& uri) : m_scheme(Scheme::HTTP), m_port(HTTP_DEFAULT_PORT) +{ + ParseURIParts(uri); +} + +URI::URI(const char* uri) : m_scheme(Scheme::HTTP), m_port(HTTP_DEFAULT_PORT) +{ + ParseURIParts(uri); +} + +URI& URI::operator =(const Aws::String& uri) +{ + this->ParseURIParts(uri); + return *this; +} + +URI& URI::operator =(const char* uri) +{ + this->ParseURIParts(uri); + return *this; +} + +bool URI::operator ==(const URI& other) const +{ + return CompareURIParts(other); +} + +bool URI::operator ==(const Aws::String& other) const +{ + return CompareURIParts(other); +} + +bool URI::operator ==(const char* other) const +{ + return CompareURIParts(other); +} + +bool URI::operator !=(const URI& other) const +{ + return !(*this == other); +} + +bool URI::operator !=(const Aws::String& other) const +{ + return !(*this == other); +} + +bool URI::operator !=(const char* other) const +{ + return !(*this == other); +} + +void URI::SetScheme(Scheme value) +{ + assert(value == Scheme::HTTP || value == Scheme::HTTPS); + + if (value == Scheme::HTTP) + { + m_port = m_port == HTTPS_DEFAULT_PORT || m_port == 0 ? HTTP_DEFAULT_PORT : m_port; + m_scheme = value; + } + else if (value == Scheme::HTTPS) + { + m_port = m_port == HTTP_DEFAULT_PORT || m_port == 0 ? HTTPS_DEFAULT_PORT : m_port; + m_scheme = value; + } +} + +Aws::String URI::URLEncodePathRFC3986(const Aws::String& path) +{ + if(path.empty()) + { + return path; + } + + const Aws::Vector<Aws::String> pathParts = StringUtils::Split(path, '/'); + Aws::StringStream ss; + ss << std::hex << std::uppercase; + + // escape characters appearing in a URL path according to RFC 3986 + for (const auto& segment : pathParts) + { + ss << '/'; + for(unsigned char c : segment) // alnum results in UB if the value of c is not unsigned char & is not EOF + { + // §2.3 unreserved characters + if (StringUtils::IsAlnum(c)) + { + ss << c; + continue; + } + switch(c) + { + // §2.3 unreserved characters + case '-': case '_': case '.': case '~': + // The path section of the URL allow reserved characters to appear unescaped + // RFC 3986 §2.2 Reserved characters + // NOTE: this implementation does not accurately implement the RFC on purpose to accommodate for + // discrepancies in the implementations of URL encoding between AWS services for legacy reasons. + case '$': case '&': case ',': + case ':': case '=': case '@': + ss << c; + break; + default: + ss << '%' << std::setfill('0') << std::setw(2) << (int)((unsigned char)c) << std::setw(0); + } + } + } + + //if the last character was also a slash, then add that back here. + if (path.back() == '/') + { + ss << '/'; + } + + return ss.str(); +} + +Aws::String URI::URLEncodePath(const Aws::String& path) +{ + Aws::Vector<Aws::String> pathParts = StringUtils::Split(path, '/'); + Aws::StringStream ss; + + for (Aws::Vector<Aws::String>::iterator iter = pathParts.begin(); iter != pathParts.end(); ++iter) + { + ss << '/' << StringUtils::URLEncode(iter->c_str()); + } + + //if the last character was also a slash, then add that back here. + if (path.length() > 0 && path[path.length() - 1] == '/') + { + ss << '/'; + } + + if (path.length() > 0 && path[0] != '/') + { + return ss.str().substr(1); + } + else + { + return ss.str(); + } +} + +void URI::SetPath(const Aws::String& value) +{ + const Aws::Vector<Aws::String> pathParts = StringUtils::Split(value, '/'); + Aws::String path; + path.reserve(value.length() + 1/* in case we have to append slash before the path. */); + + for (const auto& segment : pathParts) + { + path.push_back('/'); + path.append(segment); + } + + if (value.back() == '/') + { + path.push_back('/'); + } + m_path = std::move(path); +} + +//ugh, this isn't even part of the canonicalization spec. It is part of how our services have implemented their signers though.... +//it doesn't really hurt anything to reorder it though, so go ahead and sort the values for parameters with the same key +void InsertValueOrderedParameter(QueryStringParameterCollection& queryParams, const Aws::String& key, const Aws::String& value) +{ + auto entriesAtKey = queryParams.equal_range(key); + for (auto& entry = entriesAtKey.first; entry != entriesAtKey.second; ++entry) + { + if (entry->second > value) + { + queryParams.emplace_hint(entry, key, value); + return; + } + } + + queryParams.emplace(key, value); +} + +QueryStringParameterCollection URI::GetQueryStringParameters(bool decode) const +{ + Aws::String queryString = GetQueryString(); + + QueryStringParameterCollection parameterCollection; + + //if we actually have a query string + if (queryString.size() > 0) + { + size_t currentPos = 1, locationOfNextDelimiter = 1; + + //while we have params left to parse + while (currentPos < queryString.size()) + { + //find next key/value pair + locationOfNextDelimiter = queryString.find('&', currentPos); + + Aws::String keyValuePair; + + //if this isn't the last parameter + if (locationOfNextDelimiter != Aws::String::npos) + { + keyValuePair = queryString.substr(currentPos, locationOfNextDelimiter - currentPos); + } + //if it is the last parameter + else + { + keyValuePair = queryString.substr(currentPos); + } + + //split on = + size_t locationOfEquals = keyValuePair.find('='); + Aws::String key = keyValuePair.substr(0, locationOfEquals); + Aws::String value = keyValuePair.substr(locationOfEquals + 1); + + if(decode) + { + InsertValueOrderedParameter(parameterCollection, StringUtils::URLDecode(key.c_str()), StringUtils::URLDecode(value.c_str())); + } + else + { + InsertValueOrderedParameter(parameterCollection, key, value); + } + + currentPos += keyValuePair.size() + 1; + } + } + + return parameterCollection; +} + +void URI::CanonicalizeQueryString() +{ + QueryStringParameterCollection sortedParameters = GetQueryStringParameters(false); + Aws::StringStream queryStringStream; + + bool first = true; + + if(sortedParameters.size() > 0) + { + queryStringStream << "?"; + } + + if(m_queryString.find('=') != std::string::npos) + { + for (QueryStringParameterCollection::iterator iter = sortedParameters.begin(); + iter != sortedParameters.end(); ++iter) + { + if (!first) + { + queryStringStream << "&"; + } + + first = false; + queryStringStream << iter->first.c_str() << "=" << iter->second.c_str(); + } + + m_queryString = queryStringStream.str(); + } +} + +void URI::AddQueryStringParameter(const char* key, const Aws::String& value) +{ + if (m_queryString.size() <= 0) + { + m_queryString.append("?"); + } + else + { + m_queryString.append("&"); + } + + m_queryString.append(StringUtils::URLEncode(key) + "=" + StringUtils::URLEncode(value.c_str())); +} + +void URI::AddQueryStringParameter(const Aws::Map<Aws::String, Aws::String>& queryStringPairs) +{ + for(const auto& entry: queryStringPairs) + { + AddQueryStringParameter(entry.first.c_str(), entry.second); + } +} + +void URI::SetQueryString(const Aws::String& str) +{ + m_queryString = ""; + + if (str.empty()) return; + + if (str.front() != '?') + { + m_queryString.append("?").append(str); + } + else + { + m_queryString = str; + } +} + +Aws::String URI::GetURIString(bool includeQueryString) const +{ + assert(m_authority.size() > 0); + + Aws::StringStream ss; + ss << SchemeMapper::ToString(m_scheme) << SEPARATOR << m_authority; + + if (m_scheme == Scheme::HTTP && m_port != HTTP_DEFAULT_PORT) + { + ss << ":" << m_port; + } + else if (m_scheme == Scheme::HTTPS && m_port != HTTPS_DEFAULT_PORT) + { + ss << ":" << m_port; + } + + if(m_path != "/") + { + ss << URLEncodePathRFC3986(m_path); + } + + if(includeQueryString) + { + ss << m_queryString; + } + + return ss.str(); +} + +void URI::ParseURIParts(const Aws::String& uri) +{ + ExtractAndSetScheme(uri); + ExtractAndSetAuthority(uri); + ExtractAndSetPort(uri); + ExtractAndSetPath(uri); + ExtractAndSetQueryString(uri); +} + +void URI::ExtractAndSetScheme(const Aws::String& uri) +{ + size_t posOfSeparator = uri.find(SEPARATOR); + + if (posOfSeparator != Aws::String::npos) + { + Aws::String schemePortion = uri.substr(0, posOfSeparator); + SetScheme(SchemeMapper::FromString(schemePortion.c_str())); + } + else + { + SetScheme(Scheme::HTTP); + } +} + +void URI::ExtractAndSetAuthority(const Aws::String& uri) +{ + size_t authorityStart = uri.find(SEPARATOR); + + if (authorityStart == Aws::String::npos) + { + authorityStart = 0; + } + else + { + authorityStart += 3; + } + + size_t posOfEndOfAuthorityPort = uri.find(':', authorityStart); + size_t posOfEndOfAuthoritySlash = uri.find('/', authorityStart); + size_t posOfEndOfAuthorityQuery = uri.find('?', authorityStart); + size_t posEndOfAuthority = (std::min)({posOfEndOfAuthorityPort, posOfEndOfAuthoritySlash, posOfEndOfAuthorityQuery}); + if (posEndOfAuthority == Aws::String::npos) + { + posEndOfAuthority = uri.length(); + } + + SetAuthority(uri.substr(authorityStart, posEndOfAuthority - authorityStart)); +} + +void URI::ExtractAndSetPort(const Aws::String& uri) +{ + size_t authorityStart = uri.find(SEPARATOR); + + if(authorityStart == Aws::String::npos) + { + authorityStart = 0; + } + else + { + authorityStart += 3; + } + + size_t positionOfPortDelimiter = uri.find(':', authorityStart); + + bool hasPort = positionOfPortDelimiter != Aws::String::npos; + + if ((uri.find('/', authorityStart) < positionOfPortDelimiter) || (uri.find('?', authorityStart) < positionOfPortDelimiter)) + { + hasPort = false; + } + + if (hasPort) + { + Aws::String strPort; + + size_t i = positionOfPortDelimiter + 1; + char currentDigit = uri[i]; + + while (std::isdigit(currentDigit)) + { + strPort += currentDigit; + currentDigit = uri[++i]; + } + + SetPort(static_cast<uint16_t>(atoi(strPort.c_str()))); + } +} + +void URI::ExtractAndSetPath(const Aws::String& uri) +{ + size_t authorityStart = uri.find(SEPARATOR); + + if (authorityStart == Aws::String::npos) + { + authorityStart = 0; + } + else + { + authorityStart += 3; + } + + size_t pathEnd = uri.find('?'); + + if (pathEnd == Aws::String::npos) + { + pathEnd = uri.length(); + } + + Aws::String authorityAndPath = uri.substr(authorityStart, pathEnd - authorityStart); + + size_t pathStart = authorityAndPath.find('/'); + + if (pathStart != Aws::String::npos) + { + SetPath(authorityAndPath.substr(pathStart, pathEnd - pathStart)); + } + else + { + SetPath("/"); + } +} + +void URI::ExtractAndSetQueryString(const Aws::String& uri) +{ + size_t queryStart = uri.find('?'); + + if (queryStart != Aws::String::npos) + { + m_queryString = uri.substr(queryStart); + } +} + +Aws::String URI::GetFormParameters() const +{ + if(m_queryString.length() == 0) + { + return ""; + } + else + { + return m_queryString.substr(1); + } +} + +bool URI::CompareURIParts(const URI& other) const +{ + return m_scheme == other.m_scheme && m_authority == other.m_authority && m_path == other.m_path && m_queryString == other.m_queryString; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/curl/CurlHandleContainer.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/curl/CurlHandleContainer.cpp new file mode 100644 index 0000000000..1a965cd795 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/curl/CurlHandleContainer.cpp @@ -0,0 +1,153 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/curl/CurlHandleContainer.h> +#include <aws/core/utils/logging/LogMacros.h> + +#include <algorithm> + +using namespace Aws::Utils::Logging; +using namespace Aws::Http; + +static const char* CURL_HANDLE_CONTAINER_TAG = "CurlHandleContainer"; + + +CurlHandleContainer::CurlHandleContainer(unsigned maxSize, long httpRequestTimeout, long connectTimeout, bool enableTcpKeepAlive, + unsigned long tcpKeepAliveIntervalMs, long lowSpeedTime, unsigned long lowSpeedLimit) : + m_maxPoolSize(maxSize), m_httpRequestTimeout(httpRequestTimeout), m_connectTimeout(connectTimeout), m_enableTcpKeepAlive(enableTcpKeepAlive), + m_tcpKeepAliveIntervalMs(tcpKeepAliveIntervalMs), m_lowSpeedTime(lowSpeedTime), m_lowSpeedLimit(lowSpeedLimit), m_poolSize(0) +{ + AWS_LOGSTREAM_INFO(CURL_HANDLE_CONTAINER_TAG, "Initializing CurlHandleContainer with size " << maxSize); +} + +CurlHandleContainer::~CurlHandleContainer() +{ + AWS_LOGSTREAM_INFO(CURL_HANDLE_CONTAINER_TAG, "Cleaning up CurlHandleContainer."); + for (CURL* handle : m_handleContainer.ShutdownAndWait(m_poolSize)) + { + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "Cleaning up " << handle); + curl_easy_cleanup(handle); + } +} + +CURL* CurlHandleContainer::AcquireCurlHandle() +{ + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "Attempting to acquire curl connection."); + + if(!m_handleContainer.HasResourcesAvailable()) + { + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "No current connections available in pool. Attempting to create new connections."); + CheckAndGrowPool(); + } + + CURL* handle = m_handleContainer.Acquire(); + AWS_LOGSTREAM_INFO(CURL_HANDLE_CONTAINER_TAG, "Connection has been released. Continuing."); + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "Returning connection handle " << handle); + return handle; +} + +void CurlHandleContainer::ReleaseCurlHandle(CURL* handle) +{ + if (handle) + { + curl_easy_reset(handle); + SetDefaultOptionsOnHandle(handle); + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "Releasing curl handle " << handle); + m_handleContainer.Release(handle); + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "Notified waiting threads."); + } +} + +void CurlHandleContainer::DestroyCurlHandle(CURL* handle) +{ + if (!handle) + { + return; + } + + curl_easy_cleanup(handle); + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "Destroy curl handle: " << handle); + { + std::lock_guard<std::mutex> locker(m_containerLock); + // Other threads could be blocked and waiting on m_handleContainer.Acquire() + // If the handle is not released back to the pool, it could create a deadlock + // Create a new handle and release that into the pool + handle = CreateCurlHandleInPool(); + } + if (handle) + { + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "Created replacement handle and released to pool: " << handle); + } +} + + +CURL* CurlHandleContainer::CreateCurlHandleInPool() +{ + CURL* curlHandle = curl_easy_init(); + + if (curlHandle) + { + SetDefaultOptionsOnHandle(curlHandle); + m_handleContainer.Release(curlHandle); + } + else + { + AWS_LOGSTREAM_ERROR(CURL_HANDLE_CONTAINER_TAG, "curl_easy_init failed to allocate."); + } + return curlHandle; +} + +bool CurlHandleContainer::CheckAndGrowPool() +{ + std::lock_guard<std::mutex> locker(m_containerLock); + if (m_poolSize < m_maxPoolSize) + { + unsigned multiplier = m_poolSize > 0 ? m_poolSize : 1; + unsigned amountToAdd = (std::min)(multiplier * 2, m_maxPoolSize - m_poolSize); + AWS_LOGSTREAM_DEBUG(CURL_HANDLE_CONTAINER_TAG, "attempting to grow pool size by " << amountToAdd); + + unsigned actuallyAdded = 0; + for (unsigned i = 0; i < amountToAdd; ++i) + { + CURL* curlHandle = CreateCurlHandleInPool(); + + if (curlHandle) + { + ++actuallyAdded; + } + else + { + break; + } + } + + AWS_LOGSTREAM_INFO(CURL_HANDLE_CONTAINER_TAG, "Pool grown by " << actuallyAdded); + m_poolSize += actuallyAdded; + + return actuallyAdded > 0; + } + + AWS_LOGSTREAM_INFO(CURL_HANDLE_CONTAINER_TAG, "Pool cannot be grown any further, already at max size."); + + return false; +} + +void CurlHandleContainer::SetDefaultOptionsOnHandle(CURL* handle) +{ + //for timeouts to work in a multi-threaded context, + //always turn signals off. This also forces dns queries to + //not be included in the timeout calculations. + curl_easy_setopt(handle, CURLOPT_NOSIGNAL, 1L); + curl_easy_setopt(handle, CURLOPT_TIMEOUT_MS, m_httpRequestTimeout); + curl_easy_setopt(handle, CURLOPT_CONNECTTIMEOUT_MS, m_connectTimeout); + curl_easy_setopt(handle, CURLOPT_LOW_SPEED_LIMIT, m_lowSpeedLimit); + curl_easy_setopt(handle, CURLOPT_LOW_SPEED_TIME, m_lowSpeedTime < 1000 ? (m_lowSpeedTime == 0 ? 0 : 1) : m_lowSpeedTime / 1000); + curl_easy_setopt(handle, CURLOPT_TCP_KEEPALIVE, m_enableTcpKeepAlive ? 1L : 0L); + curl_easy_setopt(handle, CURLOPT_TCP_KEEPINTVL, m_tcpKeepAliveIntervalMs / 1000); + curl_easy_setopt(handle, CURLOPT_TCP_KEEPIDLE, m_tcpKeepAliveIntervalMs / 1000); +#ifdef CURL_HAS_H2 + curl_easy_setopt(handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_0); +#endif +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp new file mode 100644 index 0000000000..2fb9cc9643 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp @@ -0,0 +1,730 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/curl/CurlHttpClient.h> +#include <aws/core/http/HttpRequest.h> +#include <aws/core/http/standard/StandardHttpResponse.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/ratelimiter/RateLimiterInterface.h> +#include <aws/core/utils/DateTime.h> +#include <aws/core/monitoring/HttpClientMetrics.h> +#include <cassert> +#include <algorithm> + + +using namespace Aws::Client; +using namespace Aws::Http; +using namespace Aws::Http::Standard; +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; +using namespace Aws::Monitoring; + +#ifdef AWS_CUSTOM_MEMORY_MANAGEMENT + +static const char* MemTag = "libcurl"; +static size_t offset = sizeof(size_t); + +static void* malloc_callback(size_t size) +{ + char* newMem = reinterpret_cast<char*>(Aws::Malloc(MemTag, size + offset)); + std::size_t* pointerToSize = reinterpret_cast<std::size_t*>(newMem); + *pointerToSize = size; + return reinterpret_cast<void*>(newMem + offset); +} + +static void free_callback(void* ptr) +{ + if(ptr) + { + char* shiftedMemory = reinterpret_cast<char*>(ptr); + Aws::Free(shiftedMemory - offset); + } +} + +static void* realloc_callback(void* ptr, size_t size) +{ + if(!ptr) + { + return malloc_callback(size); + } + + + if(!size && ptr) + { + free_callback(ptr); + return nullptr; + } + + char* originalLenCharPtr = reinterpret_cast<char*>(ptr) - offset; + size_t originalLen = *reinterpret_cast<size_t*>(originalLenCharPtr); + + char* rawMemory = reinterpret_cast<char*>(Aws::Malloc(MemTag, size + offset)); + if(rawMemory) + { + std::size_t* pointerToSize = reinterpret_cast<std::size_t*>(rawMemory); + *pointerToSize = size; + + size_t copyLength = (std::min)(originalLen, size); +#ifdef _MSC_VER + memcpy_s(rawMemory + offset, size, ptr, copyLength); +#else + memcpy(rawMemory + offset, ptr, copyLength); +#endif + free_callback(ptr); + return reinterpret_cast<void*>(rawMemory + offset); + } + else + { + return ptr; + } + +} + +static void* calloc_callback(size_t nmemb, size_t size) +{ + size_t dataSize = nmemb * size; + char* newMem = reinterpret_cast<char*>(Aws::Malloc(MemTag, dataSize + offset)); + std::size_t* pointerToSize = reinterpret_cast<std::size_t*>(newMem); + *pointerToSize = dataSize; +#ifdef _MSC_VER + memset_s(newMem + offset, dataSize, 0, dataSize); +#else + memset(newMem + offset, 0, dataSize); +#endif + + return reinterpret_cast<void*>(newMem + offset); +} + +static char* strdup_callback(const char* str) +{ + size_t len = strlen(str) + 1; + size_t newLen = len + offset; + char* newMem = reinterpret_cast<char*>(Aws::Malloc(MemTag, newLen)); + + if(newMem) + { + std::size_t* pointerToSize = reinterpret_cast<std::size_t*>(newMem); + *pointerToSize = len; +#ifdef _MSC_VER + memcpy_s(newMem + offset, len, str, len); +#else + memcpy(newMem + offset, str, len); +#endif + return newMem + offset; + } + return nullptr; +} + +#endif + +struct CurlWriteCallbackContext +{ + CurlWriteCallbackContext(const CurlHttpClient* client, + HttpRequest* request, + HttpResponse* response, + Aws::Utils::RateLimits::RateLimiterInterface* rateLimiter) : + m_client(client), + m_request(request), + m_response(response), + m_rateLimiter(rateLimiter), + m_numBytesResponseReceived(0) + {} + + const CurlHttpClient* m_client; + HttpRequest* m_request; + HttpResponse* m_response; + Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter; + int64_t m_numBytesResponseReceived; +}; + +struct CurlReadCallbackContext +{ + CurlReadCallbackContext(const CurlHttpClient* client, HttpRequest* request, Aws::Utils::RateLimits::RateLimiterInterface* limiter) : + m_client(client), + m_rateLimiter(limiter), + m_request(request) + {} + + const CurlHttpClient* m_client; + CURL* m_curlHandle; + Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter; + HttpRequest* m_request; +}; + +static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient"; + +static size_t WriteData(char* ptr, size_t size, size_t nmemb, void* userdata) +{ + if (ptr) + { + CurlWriteCallbackContext* context = reinterpret_cast<CurlWriteCallbackContext*>(userdata); + + const CurlHttpClient* client = context->m_client; + if(!client->ContinueRequest(*context->m_request) || !client->IsRequestProcessingEnabled()) + { + return 0; + } + + HttpResponse* response = context->m_response; + size_t sizeToWrite = size * nmemb; + if (context->m_rateLimiter) + { + context->m_rateLimiter->ApplyAndPayForCost(static_cast<int64_t>(sizeToWrite)); + } + + response->GetResponseBody().write(ptr, static_cast<std::streamsize>(sizeToWrite)); + if (context->m_request->IsEventStreamRequest()) + { + response->GetResponseBody().flush(); + } + auto& receivedHandler = context->m_request->GetDataReceivedEventHandler(); + if (receivedHandler) + { + receivedHandler(context->m_request, context->m_response, static_cast<long long>(sizeToWrite)); + } + + AWS_LOGSTREAM_TRACE(CURL_HTTP_CLIENT_TAG, sizeToWrite << " bytes written to response."); + context->m_numBytesResponseReceived += sizeToWrite; + return sizeToWrite; + } + return 0; +} + +static size_t WriteHeader(char* ptr, size_t size, size_t nmemb, void* userdata) +{ + if (ptr) + { + CurlWriteCallbackContext* context = reinterpret_cast<CurlWriteCallbackContext*>(userdata); + AWS_LOGSTREAM_TRACE(CURL_HTTP_CLIENT_TAG, ptr); + HttpResponse* response = context->m_response; + Aws::String headerLine(ptr); + Aws::Vector<Aws::String> keyValuePair = StringUtils::Split(headerLine, ':', 2); + + if (keyValuePair.size() == 2) + { + response->AddHeader(StringUtils::Trim(keyValuePair[0].c_str()), StringUtils::Trim(keyValuePair[1].c_str())); + } + + return size * nmemb; + } + return 0; +} + + +static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata) +{ + CurlReadCallbackContext* context = reinterpret_cast<CurlReadCallbackContext*>(userdata); + if(context == nullptr) + { + return 0; + } + + const CurlHttpClient* client = context->m_client; + if(!client->ContinueRequest(*context->m_request) || !client->IsRequestProcessingEnabled()) + { + return CURL_READFUNC_ABORT; + } + + HttpRequest* request = context->m_request; + const std::shared_ptr<Aws::IOStream>& ioStream = request->GetContentBody(); + + const size_t amountToRead = size * nmemb; + if (ioStream != nullptr && amountToRead > 0) + { + if (request->IsEventStreamRequest()) + { + // Waiting for next available character to read. + // Without peek(), readsome() will keep reading 0 byte from the stream. + ioStream->peek(); + ioStream->readsome(ptr, amountToRead); + } + else + { + ioStream->read(ptr, amountToRead); + } + size_t amountRead = static_cast<size_t>(ioStream->gcount()); + auto& sentHandler = request->GetDataSentEventHandler(); + if (sentHandler) + { + sentHandler(request, static_cast<long long>(amountRead)); + } + + if (context->m_rateLimiter) + { + context->m_rateLimiter->ApplyAndPayForCost(static_cast<int64_t>(amountRead)); + } + + return amountRead; + } + + return 0; +} + +static size_t SeekBody(void* userdata, curl_off_t offset, int origin) +{ + CurlReadCallbackContext* context = reinterpret_cast<CurlReadCallbackContext*>(userdata); + if(context == nullptr) + { + return CURL_SEEKFUNC_FAIL; + } + + const CurlHttpClient* client = context->m_client; + if(!client->ContinueRequest(*context->m_request) || !client->IsRequestProcessingEnabled()) + { + return CURL_SEEKFUNC_FAIL; + } + + HttpRequest* request = context->m_request; + const std::shared_ptr<Aws::IOStream>& ioStream = request->GetContentBody(); + + std::ios_base::seekdir dir; + switch(origin) + { + case SEEK_SET: + dir = std::ios_base::beg; + break; + case SEEK_CUR: + dir = std::ios_base::cur; + break; + case SEEK_END: + dir = std::ios_base::end; + break; + default: + return CURL_SEEKFUNC_FAIL; + } + + ioStream->clear(); + ioStream->seekg(offset, dir); + if (ioStream->fail()) { + return CURL_SEEKFUNC_CANTSEEK; + } + + return CURL_SEEKFUNC_OK; +} + +void SetOptCodeForHttpMethod(CURL* requestHandle, const std::shared_ptr<HttpRequest>& request) +{ + switch (request->GetMethod()) + { + case HttpMethod::HTTP_GET: + curl_easy_setopt(requestHandle, CURLOPT_HTTPGET, 1L); + break; + case HttpMethod::HTTP_POST: + if (request->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER) && request->GetHeaderValue(Aws::Http::CONTENT_LENGTH_HEADER) == "0") + { + curl_easy_setopt(requestHandle, CURLOPT_CUSTOMREQUEST, "POST"); + } + else + { + curl_easy_setopt(requestHandle, CURLOPT_POST, 1L); + } + break; + case HttpMethod::HTTP_PUT: + if ((!request->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER) || request->GetHeaderValue(Aws::Http::CONTENT_LENGTH_HEADER) == "0") && + !request->HasHeader(Aws::Http::TRANSFER_ENCODING_HEADER)) + { + curl_easy_setopt(requestHandle, CURLOPT_CUSTOMREQUEST, "PUT"); + } + else + { + curl_easy_setopt(requestHandle, CURLOPT_PUT, 1L); + } + break; + case HttpMethod::HTTP_HEAD: + curl_easy_setopt(requestHandle, CURLOPT_HTTPGET, 1L); + curl_easy_setopt(requestHandle, CURLOPT_NOBODY, 1L); + break; + case HttpMethod::HTTP_PATCH: + if ((!request->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER)|| request->GetHeaderValue(Aws::Http::CONTENT_LENGTH_HEADER) == "0") && + !request->HasHeader(Aws::Http::TRANSFER_ENCODING_HEADER)) + { + curl_easy_setopt(requestHandle, CURLOPT_CUSTOMREQUEST, "PATCH"); + } + else + { + curl_easy_setopt(requestHandle, CURLOPT_POST, 1L); + curl_easy_setopt(requestHandle, CURLOPT_CUSTOMREQUEST, "PATCH"); + } + + break; + case HttpMethod::HTTP_DELETE: + curl_easy_setopt(requestHandle, CURLOPT_CUSTOMREQUEST, "DELETE"); + break; + default: + assert(0); + curl_easy_setopt(requestHandle, CURLOPT_CUSTOMREQUEST, "GET"); + break; + } +} + + +std::atomic<bool> CurlHttpClient::isInit(false); + +void CurlHttpClient::InitGlobalState() +{ + if (!isInit) + { + auto curlVersionData = curl_version_info(CURLVERSION_NOW); + AWS_LOGSTREAM_INFO(CURL_HTTP_CLIENT_TAG, "Initializing Curl library with version: " << curlVersionData->version + << ", ssl version: " << curlVersionData->ssl_version); + isInit = true; +#ifdef AWS_CUSTOM_MEMORY_MANAGEMENT + curl_global_init_mem(CURL_GLOBAL_ALL, &malloc_callback, &free_callback, &realloc_callback, &strdup_callback, &calloc_callback); +#else + curl_global_init(CURL_GLOBAL_ALL); +#endif + } +} + + +void CurlHttpClient::CleanupGlobalState() +{ + curl_global_cleanup(); +} + +Aws::String CurlInfoTypeToString(curl_infotype type) +{ + switch(type) + { + case CURLINFO_TEXT: + return "Text"; + + case CURLINFO_HEADER_IN: + return "HeaderIn"; + + case CURLINFO_HEADER_OUT: + return "HeaderOut"; + + case CURLINFO_DATA_IN: + return "DataIn"; + + case CURLINFO_DATA_OUT: + return "DataOut"; + + case CURLINFO_SSL_DATA_IN: + return "SSLDataIn"; + + case CURLINFO_SSL_DATA_OUT: + return "SSLDataOut"; + + default: + return "Unknown"; + } +} + +int CurlDebugCallback(CURL *handle, curl_infotype type, char *data, size_t size, void *userptr) +{ + AWS_UNREFERENCED_PARAM(handle); + AWS_UNREFERENCED_PARAM(userptr); + + if(type == CURLINFO_SSL_DATA_IN || type == CURLINFO_SSL_DATA_OUT) + { + AWS_LOGSTREAM_DEBUG("CURL", "(" << CurlInfoTypeToString(type) << ") " << size << "bytes"); + } + else + { + Aws::String debugString(data, size); + AWS_LOGSTREAM_DEBUG("CURL", "(" << CurlInfoTypeToString(type) << ") " << debugString); + } + + return 0; +} + + +CurlHttpClient::CurlHttpClient(const ClientConfiguration& clientConfig) : + Base(), + m_curlHandleContainer(clientConfig.maxConnections, clientConfig.httpRequestTimeoutMs, clientConfig.connectTimeoutMs, clientConfig.enableTcpKeepAlive, + clientConfig.tcpKeepAliveIntervalMs, clientConfig.requestTimeoutMs, clientConfig.lowSpeedLimit), + m_isUsingProxy(!clientConfig.proxyHost.empty()), m_proxyUserName(clientConfig.proxyUserName), + m_proxyPassword(clientConfig.proxyPassword), m_proxyScheme(SchemeMapper::ToString(clientConfig.proxyScheme)), m_proxyHost(clientConfig.proxyHost), + m_proxySSLCertPath(clientConfig.proxySSLCertPath), m_proxySSLCertType(clientConfig.proxySSLCertType), + m_proxySSLKeyPath(clientConfig.proxySSLKeyPath), m_proxySSLKeyType(clientConfig.proxySSLKeyType), + m_proxyKeyPasswd(clientConfig.proxySSLKeyPassword), + m_proxyPort(clientConfig.proxyPort), m_verifySSL(clientConfig.verifySSL), m_caPath(clientConfig.caPath), + m_caFile(clientConfig.caFile), m_proxyCaPath(clientConfig.proxyCaPath), m_proxyCaFile(clientConfig.proxyCaFile), + m_disableExpectHeader(clientConfig.disableExpectHeader) +{ + if (clientConfig.followRedirects == FollowRedirectsPolicy::NEVER || + (clientConfig.followRedirects == FollowRedirectsPolicy::DEFAULT && clientConfig.region == Aws::Region::AWS_GLOBAL)) + { + m_allowRedirects = false; + } + else + { + m_allowRedirects = true; + } +} + + +std::shared_ptr<HttpResponse> CurlHttpClient::MakeRequest(const std::shared_ptr<HttpRequest>& request, + Aws::Utils::RateLimits::RateLimiterInterface* readLimiter, + Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter) const +{ + URI uri = request->GetUri(); + Aws::String url = uri.GetURIString(); + std::shared_ptr<HttpResponse> response = Aws::MakeShared<StandardHttpResponse>(CURL_HTTP_CLIENT_TAG, request); + + AWS_LOGSTREAM_TRACE(CURL_HTTP_CLIENT_TAG, "Making request to " << url); + struct curl_slist* headers = NULL; + + if (writeLimiter != nullptr) + { + writeLimiter->ApplyAndPayForCost(request->GetSize()); + } + + Aws::StringStream headerStream; + HeaderValueCollection requestHeaders = request->GetHeaders(); + + AWS_LOGSTREAM_TRACE(CURL_HTTP_CLIENT_TAG, "Including headers:"); + for (auto& requestHeader : requestHeaders) + { + headerStream.str(""); + headerStream << requestHeader.first << ": " << requestHeader.second; + Aws::String headerString = headerStream.str(); + AWS_LOGSTREAM_TRACE(CURL_HTTP_CLIENT_TAG, headerString); + headers = curl_slist_append(headers, headerString.c_str()); + } + + if (!request->HasHeader(Http::TRANSFER_ENCODING_HEADER)) + { + headers = curl_slist_append(headers, "transfer-encoding:"); + } + + if (!request->HasHeader(Http::CONTENT_LENGTH_HEADER)) + { + headers = curl_slist_append(headers, "content-length:"); + } + + if (!request->HasHeader(Http::CONTENT_TYPE_HEADER)) + { + headers = curl_slist_append(headers, "content-type:"); + } + + // Discard Expect header so as to avoid using multiple payloads to send a http request (header + body) + if (m_disableExpectHeader) + { + headers = curl_slist_append(headers, "Expect:"); + } + + CURL* connectionHandle = m_curlHandleContainer.AcquireCurlHandle(); + + if (connectionHandle) + { + AWS_LOGSTREAM_DEBUG(CURL_HTTP_CLIENT_TAG, "Obtained connection handle " << connectionHandle); + + if (headers) + { + curl_easy_setopt(connectionHandle, CURLOPT_HTTPHEADER, headers); + } + + CurlWriteCallbackContext writeContext(this, request.get(), response.get(), readLimiter); + CurlReadCallbackContext readContext(this, request.get(), writeLimiter); + + SetOptCodeForHttpMethod(connectionHandle, request); + + curl_easy_setopt(connectionHandle, CURLOPT_URL, url.c_str()); + curl_easy_setopt(connectionHandle, CURLOPT_WRITEFUNCTION, WriteData); + curl_easy_setopt(connectionHandle, CURLOPT_WRITEDATA, &writeContext); + curl_easy_setopt(connectionHandle, CURLOPT_HEADERFUNCTION, WriteHeader); + curl_easy_setopt(connectionHandle, CURLOPT_HEADERDATA, &writeContext); + + //we only want to override the default path if someone has explicitly told us to. + if(!m_caPath.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_CAPATH, m_caPath.c_str()); + } + if(!m_caFile.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_CAINFO, m_caFile.c_str()); + } + + // only set by android test builds because the emulator is missing a cert needed for aws services +#ifdef TEST_CERT_PATH + curl_easy_setopt(connectionHandle, CURLOPT_CAPATH, TEST_CERT_PATH); +#endif // TEST_CERT_PATH + + if (m_verifySSL) + { + curl_easy_setopt(connectionHandle, CURLOPT_SSL_VERIFYPEER, 1L); + curl_easy_setopt(connectionHandle, CURLOPT_SSL_VERIFYHOST, 2L); + +#if LIBCURL_VERSION_MAJOR >= 7 +#if LIBCURL_VERSION_MINOR >= 34 + curl_easy_setopt(connectionHandle, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1); +#endif //LIBCURL_VERSION_MINOR +#endif //LIBCURL_VERSION_MAJOR + } + else + { + curl_easy_setopt(connectionHandle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(connectionHandle, CURLOPT_SSL_VERIFYHOST, 0L); + } + + if (m_allowRedirects) + { + curl_easy_setopt(connectionHandle, CURLOPT_FOLLOWLOCATION, 1L); + } + else + { + curl_easy_setopt(connectionHandle, CURLOPT_FOLLOWLOCATION, 0L); + } + +#ifdef ENABLE_CURL_LOGGING + curl_easy_setopt(connectionHandle, CURLOPT_VERBOSE, 1); + curl_easy_setopt(connectionHandle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback); +#endif + if (m_isUsingProxy) + { + Aws::StringStream ss; + ss << m_proxyScheme << "://" << m_proxyHost; + curl_easy_setopt(connectionHandle, CURLOPT_PROXY, ss.str().c_str()); + curl_easy_setopt(connectionHandle, CURLOPT_PROXYPORT, (long) m_proxyPort); + if(!m_proxyCaPath.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXY_CAPATH, m_proxyCaPath.c_str()); + } + if(!m_proxyCaFile.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXY_CAINFO, m_proxyCaFile.c_str()); + } + if (!m_proxyUserName.empty() || !m_proxyPassword.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXYUSERNAME, m_proxyUserName.c_str()); + curl_easy_setopt(connectionHandle, CURLOPT_PROXYPASSWORD, m_proxyPassword.c_str()); + } +#ifdef CURL_HAS_TLS_PROXY + if (!m_proxySSLCertPath.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXY_SSLCERT, m_proxySSLCertPath.c_str()); + if (!m_proxySSLCertType.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXY_SSLCERTTYPE, m_proxySSLCertType.c_str()); + } + } + if (!m_proxySSLKeyPath.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXY_SSLKEY, m_proxySSLKeyPath.c_str()); + if (!m_proxySSLKeyType.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXY_SSLKEYTYPE, m_proxySSLKeyType.c_str()); + } + if (!m_proxyKeyPasswd.empty()) + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXY_KEYPASSWD, m_proxyKeyPasswd.c_str()); + } + } +#endif //CURL_HAS_TLS_PROXY + } + else + { + curl_easy_setopt(connectionHandle, CURLOPT_PROXY, ""); + } + + if (request->GetContentBody()) + { + curl_easy_setopt(connectionHandle, CURLOPT_READFUNCTION, ReadBody); + curl_easy_setopt(connectionHandle, CURLOPT_READDATA, &readContext); + curl_easy_setopt(connectionHandle, CURLOPT_SEEKFUNCTION, SeekBody); + curl_easy_setopt(connectionHandle, CURLOPT_SEEKDATA, &readContext); + } + + OverrideOptionsOnConnectionHandle(connectionHandle); + Aws::Utils::DateTime startTransmissionTime = Aws::Utils::DateTime::Now(); + CURLcode curlResponseCode = curl_easy_perform(connectionHandle); + bool shouldContinueRequest = ContinueRequest(*request); + if (curlResponseCode != CURLE_OK && shouldContinueRequest) + { + response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION); + Aws::StringStream ss; + ss << "curlCode: " << curlResponseCode << ", " << curl_easy_strerror(curlResponseCode); + response->SetClientErrorMessage(ss.str()); + AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Curl returned error code " << curlResponseCode + << " - " << curl_easy_strerror(curlResponseCode)); + } + else if(!shouldContinueRequest) + { + response->SetClientErrorType(CoreErrors::USER_CANCELLED); + response->SetClientErrorMessage("Request cancelled by user's continuation handler"); + } + else + { + long responseCode; + curl_easy_getinfo(connectionHandle, CURLINFO_RESPONSE_CODE, &responseCode); + response->SetResponseCode(static_cast<HttpResponseCode>(responseCode)); + AWS_LOGSTREAM_DEBUG(CURL_HTTP_CLIENT_TAG, "Returned http response code " << responseCode); + + char* contentType = nullptr; + curl_easy_getinfo(connectionHandle, CURLINFO_CONTENT_TYPE, &contentType); + if (contentType) + { + response->SetContentType(contentType); + AWS_LOGSTREAM_DEBUG(CURL_HTTP_CLIENT_TAG, "Returned content type " << contentType); + } + + if (request->GetMethod() != HttpMethod::HTTP_HEAD && + writeContext.m_client->IsRequestProcessingEnabled() && + response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER)) + { + const Aws::String& contentLength = response->GetHeader(Aws::Http::CONTENT_LENGTH_HEADER); + int64_t numBytesResponseReceived = writeContext.m_numBytesResponseReceived; + AWS_LOGSTREAM_TRACE(CURL_HTTP_CLIENT_TAG, "Response content-length header: " << contentLength); + AWS_LOGSTREAM_TRACE(CURL_HTTP_CLIENT_TAG, "Response body length: " << numBytesResponseReceived); + if (StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived) + { + response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage("Response body length doesn't match the content-length header."); + AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Response body length doesn't match the content-length header."); + } + } + + AWS_LOGSTREAM_DEBUG(CURL_HTTP_CLIENT_TAG, "Releasing curl handle " << connectionHandle); + } + + double timep; + CURLcode ret = curl_easy_getinfo(connectionHandle, CURLINFO_NAMELOOKUP_TIME, &timep); // DNS Resolve Latency, seconds. + if (ret == CURLE_OK) + { + request->AddRequestMetric(GetHttpClientMetricNameByType(HttpClientMetricsType::DnsLatency), static_cast<int64_t>(timep * 1000));// to milliseconds + } + + ret = curl_easy_getinfo(connectionHandle, CURLINFO_STARTTRANSFER_TIME, &timep); // Connect Latency + if (ret == CURLE_OK) + { + request->AddRequestMetric(GetHttpClientMetricNameByType(HttpClientMetricsType::ConnectLatency), static_cast<int64_t>(timep * 1000)); + } + + ret = curl_easy_getinfo(connectionHandle, CURLINFO_APPCONNECT_TIME, &timep); // Ssl Latency + if (ret == CURLE_OK) + { + request->AddRequestMetric(GetHttpClientMetricNameByType(HttpClientMetricsType::SslLatency), static_cast<int64_t>(timep * 1000)); + } + + const char* ip = nullptr; + auto curlGetInfoResult = curl_easy_getinfo(connectionHandle, CURLINFO_PRIMARY_IP, &ip); // Get the IP address of the remote endpoint + if (curlGetInfoResult == CURLE_OK && ip) + { + request->SetResolvedRemoteHost(ip); + } + if (curlResponseCode != CURLE_OK) + { + m_curlHandleContainer.DestroyCurlHandle(connectionHandle); + } + else + { + m_curlHandleContainer.ReleaseCurlHandle(connectionHandle); + } + //go ahead and flush the response body stream + response->GetResponseBody().flush(); + request->AddRequestMetric(GetHttpClientMetricNameByType(HttpClientMetricsType::RequestLatency), (DateTime::Now() - startTransmissionTime).count()); + } + + if (headers) + { + curl_slist_free_all(headers); + } + + return response; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/standard/StandardHttpRequest.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/standard/StandardHttpRequest.cpp new file mode 100644 index 0000000000..47a0ee4fac --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/standard/StandardHttpRequest.cpp @@ -0,0 +1,104 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/standard/StandardHttpRequest.h> + +#include <aws/core/utils/StringUtils.h> + +#include <iostream> +#include <algorithm> +#include <cassert> + +using namespace Aws::Http; +using namespace Aws::Http::Standard; +using namespace Aws::Utils; + +static bool IsDefaultPort(const URI& uri) +{ + switch(uri.GetPort()) + { + case 80: + return uri.GetScheme() == Scheme::HTTP; + case 443: + return uri.GetScheme() == Scheme::HTTPS; + default: + return false; + } +} + +StandardHttpRequest::StandardHttpRequest(const URI& uri, HttpMethod method) : + HttpRequest(uri, method), + bodyStream(nullptr), + m_responseStreamFactory() +{ + if(IsDefaultPort(uri)) + { + StandardHttpRequest::SetHeaderValue(HOST_HEADER, uri.GetAuthority()); + } + else + { + Aws::StringStream host; + host << uri.GetAuthority() << ":" << uri.GetPort(); + StandardHttpRequest::SetHeaderValue(HOST_HEADER, host.str()); + } +} + +HeaderValueCollection StandardHttpRequest::GetHeaders() const +{ + HeaderValueCollection headers; + + for (HeaderValueCollection::const_iterator iter = headerMap.begin(); iter != headerMap.end(); ++iter) + { + headers.emplace(HeaderValuePair(iter->first, iter->second)); + } + + return headers; +} + +const Aws::String& StandardHttpRequest::GetHeaderValue(const char* headerName) const +{ + auto iter = headerMap.find(headerName); + assert (iter != headerMap.end()); + return iter->second; +} + +void StandardHttpRequest::SetHeaderValue(const char* headerName, const Aws::String& headerValue) +{ + headerMap[StringUtils::ToLower(headerName)] = StringUtils::Trim(headerValue.c_str()); +} + +void StandardHttpRequest::SetHeaderValue(const Aws::String& headerName, const Aws::String& headerValue) +{ + headerMap[StringUtils::ToLower(headerName.c_str())] = StringUtils::Trim(headerValue.c_str()); +} + +void StandardHttpRequest::DeleteHeader(const char* headerName) +{ + headerMap.erase(StringUtils::ToLower(headerName)); +} + +bool StandardHttpRequest::HasHeader(const char* headerName) const +{ + return headerMap.find(StringUtils::ToLower(headerName)) != headerMap.end(); +} + +int64_t StandardHttpRequest::GetSize() const +{ + int64_t size = 0; + + std::for_each(headerMap.cbegin(), headerMap.cend(), [&](const HeaderValueCollection::value_type& kvPair){ size += kvPair.first.length(); size += kvPair.second.length(); }); + + return size; +} + +const Aws::IOStreamFactory& StandardHttpRequest::GetResponseStreamFactory() const +{ + return m_responseStreamFactory; +} + +void StandardHttpRequest::SetResponseStreamFactory(const Aws::IOStreamFactory& factory) +{ + m_responseStreamFactory = factory; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/standard/StandardHttpResponse.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/standard/StandardHttpResponse.cpp new file mode 100644 index 0000000000..92d7a062b6 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/http/standard/StandardHttpResponse.cpp @@ -0,0 +1,46 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/http/standard/StandardHttpResponse.h> + +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/memory/AWSMemory.h> + +#include <istream> + +using namespace Aws::Http; +using namespace Aws::Http::Standard; +using namespace Aws::Utils; + + +HeaderValueCollection StandardHttpResponse::GetHeaders() const +{ + HeaderValueCollection headerValueCollection; + + for (Aws::Map<Aws::String, Aws::String>::const_iterator iter = headerMap.begin(); iter != headerMap.end(); ++iter) + { + headerValueCollection.emplace(HeaderValuePair(iter->first, iter->second)); + } + + return headerValueCollection; +} + +bool StandardHttpResponse::HasHeader(const char* headerName) const +{ + return headerMap.find(StringUtils::ToLower(headerName)) != headerMap.end(); +} + +const Aws::String& StandardHttpResponse::GetHeader(const Aws::String& headerName) const +{ + Aws::Map<Aws::String, Aws::String>::const_iterator foundValue = headerMap.find(StringUtils::ToLower(headerName.c_str())); + return foundValue->second; +} + +void StandardHttpResponse::AddHeader(const Aws::String& headerName, const Aws::String& headerValue) +{ + headerMap[StringUtils::ToLower(headerName.c_str())] = headerValue; +} + + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp new file mode 100644 index 0000000000..24145e4d92 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp @@ -0,0 +1,506 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/internal/AWSHttpResourceClient.h> +#include <aws/core/client/DefaultRetryStrategy.h> +#include <aws/core/http/HttpClient.h> +#include <aws/core/http/HttpClientFactory.h> +#include <aws/core/http/HttpResponse.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/platform/Environment.h> +#include <aws/core/client/AWSError.h> +#include <aws/core/client/CoreErrors.h> +#include <aws/core/utils/xml/XmlSerializer.h> +#include <mutex> +#include <sstream> + +using namespace Aws; +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; +using namespace Aws::Utils::Xml; +using namespace Aws::Http; +using namespace Aws::Client; +using namespace Aws::Internal; + +static const char EC2_SECURITY_CREDENTIALS_RESOURCE[] = "/latest/meta-data/iam/security-credentials"; +static const char EC2_REGION_RESOURCE[] = "/latest/meta-data/placement/availability-zone"; +static const char EC2_IMDS_TOKEN_RESOURCE[] = "/latest/api/token"; +static const char EC2_IMDS_TOKEN_TTL_DEFAULT_VALUE[] = "21600"; +static const char EC2_IMDS_TOKEN_TTL_HEADER[] = "x-aws-ec2-metadata-token-ttl-seconds"; +static const char EC2_IMDS_TOKEN_HEADER[] = "x-aws-ec2-metadata-token"; +static const char RESOURCE_CLIENT_CONFIGURATION_ALLOCATION_TAG[] = "AWSHttpResourceClient"; +static const char EC2_METADATA_CLIENT_LOG_TAG[] = "EC2MetadataClient"; +static const char ECS_CREDENTIALS_CLIENT_LOG_TAG[] = "ECSCredentialsClient"; + +namespace Aws +{ + namespace Client + { + Aws::String ComputeUserAgentString(); + } + + namespace Internal + { + static ClientConfiguration MakeDefaultHttpResourceClientConfiguration(const char *logtag) + { + ClientConfiguration res; + + res.maxConnections = 2; + res.scheme = Scheme::HTTP; + + #if defined(WIN32) && defined(BYPASS_DEFAULT_PROXY) + // For security reasons, we must bypass any proxy settings when fetching sensitive information, for example + // user credentials. On Windows, IXMLHttpRequest2 does not support bypassing proxy settings, therefore, + // we force using WinHTTP client. On POSIX systems, CURL is set to bypass proxy settings by default. + res.httpLibOverride = TransferLibType::WIN_HTTP_CLIENT; + AWS_LOGSTREAM_INFO(logtag, "Overriding the current HTTP client to WinHTTP to bypass proxy settings."); + #else + (void) logtag; // To disable warning about unused variable + #endif + // Explicitly set the proxy settings to empty/zero to avoid relying on defaults that could potentially change + // in the future. + res.proxyHost = ""; + res.proxyUserName = ""; + res.proxyPassword = ""; + res.proxyPort = 0; + + // EC2MetadataService throttles by delaying the response so the service client should set a large read timeout. + // EC2MetadataService delay is in order of seconds so it only make sense to retry after a couple of seconds. + res.connectTimeoutMs = 1000; + res.requestTimeoutMs = 1000; + res.retryStrategy = Aws::MakeShared<DefaultRetryStrategy>(RESOURCE_CLIENT_CONFIGURATION_ALLOCATION_TAG, 1, 1000); + + return res; + } + + AWSHttpResourceClient::AWSHttpResourceClient(const Aws::Client::ClientConfiguration& clientConfiguration, const char* logtag) + : m_logtag(logtag), m_retryStrategy(clientConfiguration.retryStrategy), m_httpClient(nullptr) + { + AWS_LOGSTREAM_INFO(m_logtag.c_str(), + "Creating AWSHttpResourceClient with max connections " + << clientConfiguration.maxConnections + << " and scheme " + << SchemeMapper::ToString(clientConfiguration.scheme)); + + m_httpClient = CreateHttpClient(clientConfiguration); + } + + AWSHttpResourceClient::AWSHttpResourceClient(const char* logtag) + : AWSHttpResourceClient(MakeDefaultHttpResourceClientConfiguration(logtag), logtag) + { + } + + AWSHttpResourceClient::~AWSHttpResourceClient() + { + } + + Aws::String AWSHttpResourceClient::GetResource(const char* endpoint, const char* resource, const char* authToken) const + { + return GetResourceWithAWSWebServiceResult(endpoint, resource, authToken).GetPayload(); + } + + AmazonWebServiceResult<Aws::String> AWSHttpResourceClient::GetResourceWithAWSWebServiceResult(const char *endpoint, const char *resource, const char *authToken) const + { + Aws::StringStream ss; + ss << endpoint; + if (resource) + { + ss << resource; + } + std::shared_ptr<HttpRequest> request(CreateHttpRequest(ss.str(), HttpMethod::HTTP_GET, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + + request->SetUserAgent(ComputeUserAgentString()); + + if (authToken) + { + request->SetHeaderValue(Aws::Http::AWS_AUTHORIZATION_HEADER, authToken); + } + + return GetResourceWithAWSWebServiceResult(request); + } + + AmazonWebServiceResult<Aws::String> AWSHttpResourceClient::GetResourceWithAWSWebServiceResult(const std::shared_ptr<HttpRequest> &httpRequest) const + { + AWS_LOGSTREAM_TRACE(m_logtag.c_str(), "Retrieving credentials from " << httpRequest->GetURIString()); + + for (long retries = 0;; retries++) + { + std::shared_ptr<HttpResponse> response(m_httpClient->MakeRequest(httpRequest)); + + if (response->GetResponseCode() == HttpResponseCode::OK) + { + Aws::IStreamBufIterator eos; + return {Aws::String(Aws::IStreamBufIterator(response->GetResponseBody()), eos), response->GetHeaders(), HttpResponseCode::OK}; + } + + const Aws::Client::AWSError<Aws::Client::CoreErrors> error = [this, &response]() { + if (response->HasClientError() || response->GetResponseBody().tellp() < 1) + { + AWS_LOGSTREAM_ERROR(m_logtag.c_str(), "Http request to retrieve credentials failed"); + return AWSError<CoreErrors>(CoreErrors::NETWORK_CONNECTION, true); // Retryable + } + else if (m_errorMarshaller) + { + return m_errorMarshaller->Marshall(*response); + } + else + { + const auto responseCode = response->GetResponseCode(); + + AWS_LOGSTREAM_ERROR(m_logtag.c_str(), "Http request to retrieve credentials failed with error code " + << static_cast<int>(responseCode)); + return CoreErrorsMapper::GetErrorForHttpResponseCode(responseCode); + } + }(); + + if (!m_retryStrategy->ShouldRetry(error, retries)) + { + AWS_LOGSTREAM_ERROR(m_logtag.c_str(), "Can not retrive resource from " << httpRequest->GetURIString()); + return {{}, response->GetHeaders(), error.GetResponseCode()}; + } + auto sleepMillis = m_retryStrategy->CalculateDelayBeforeNextRetry(error, retries); + AWS_LOGSTREAM_WARN(m_logtag.c_str(), "Request failed, now waiting " << sleepMillis << " ms before attempting again."); + m_httpClient->RetryRequestSleep(std::chrono::milliseconds(sleepMillis)); + } + } + + EC2MetadataClient::EC2MetadataClient(const char* endpoint) + : AWSHttpResourceClient(EC2_METADATA_CLIENT_LOG_TAG), m_endpoint(endpoint), m_tokenRequired(true) + { + } + + EC2MetadataClient::EC2MetadataClient(const Aws::Client::ClientConfiguration &clientConfiguration, const char *endpoint) + : AWSHttpResourceClient(clientConfiguration, EC2_METADATA_CLIENT_LOG_TAG), m_endpoint(endpoint), m_tokenRequired(true) + { + } + + EC2MetadataClient::~EC2MetadataClient() + { + + } + + Aws::String EC2MetadataClient::GetResource(const char* resourcePath) const + { + return GetResource(m_endpoint.c_str(), resourcePath, nullptr/*authToken*/); + } + + Aws::String EC2MetadataClient::GetDefaultCredentials() const + { + std::unique_lock<std::recursive_mutex> locker(m_tokenMutex); + if (m_tokenRequired) + { + return GetDefaultCredentialsSecurely(); + } + + AWS_LOGSTREAM_TRACE(m_logtag.c_str(), "Getting default credentials for ec2 instance"); + auto result = GetResourceWithAWSWebServiceResult(m_endpoint.c_str(), EC2_SECURITY_CREDENTIALS_RESOURCE, nullptr); + Aws::String credentialsString = result.GetPayload(); + auto httpResponseCode = result.GetResponseCode(); + + // Note, if service is insane, it might return 404 for our initial secure call, + // then when we fall back to insecure call, it might return 401 ask for secure call, + // Then, SDK might get into a recursive loop call situation between secure and insecure call. + if (httpResponseCode == Http::HttpResponseCode::UNAUTHORIZED) + { + m_tokenRequired = true; + return {}; + } + locker.unlock(); + + Aws::String trimmedCredentialsString = StringUtils::Trim(credentialsString.c_str()); + if (trimmedCredentialsString.empty()) return {}; + + Aws::Vector<Aws::String> securityCredentials = StringUtils::Split(trimmedCredentialsString, '\n'); + + AWS_LOGSTREAM_DEBUG(m_logtag.c_str(), "Calling EC2MetadataService resource, " << EC2_SECURITY_CREDENTIALS_RESOURCE + << " returned credential string " << trimmedCredentialsString); + + if (securityCredentials.size() == 0) + { + AWS_LOGSTREAM_WARN(m_logtag.c_str(), "Initial call to ec2Metadataservice to get credentials failed"); + return {}; + } + + Aws::StringStream ss; + ss << EC2_SECURITY_CREDENTIALS_RESOURCE << "/" << securityCredentials[0]; + AWS_LOGSTREAM_DEBUG(m_logtag.c_str(), "Calling EC2MetadataService resource " << ss.str()); + return GetResource(ss.str().c_str()); + } + + Aws::String EC2MetadataClient::GetDefaultCredentialsSecurely() const + { + std::unique_lock<std::recursive_mutex> locker(m_tokenMutex); + if (!m_tokenRequired) + { + return GetDefaultCredentials(); + } + + Aws::StringStream ss; + ss << m_endpoint << EC2_IMDS_TOKEN_RESOURCE; + std::shared_ptr<HttpRequest> tokenRequest(CreateHttpRequest(ss.str(), HttpMethod::HTTP_PUT, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + tokenRequest->SetHeaderValue(EC2_IMDS_TOKEN_TTL_HEADER, EC2_IMDS_TOKEN_TTL_DEFAULT_VALUE); + auto userAgentString = ComputeUserAgentString(); + tokenRequest->SetUserAgent(userAgentString); + AWS_LOGSTREAM_TRACE(m_logtag.c_str(), "Calling EC2MetadataService to get token"); + auto result = GetResourceWithAWSWebServiceResult(tokenRequest); + Aws::String tokenString = result.GetPayload(); + Aws::String trimmedTokenString = StringUtils::Trim(tokenString.c_str()); + + if (result.GetResponseCode() == HttpResponseCode::BAD_REQUEST) + { + return {}; + } + else if (result.GetResponseCode() != HttpResponseCode::OK || trimmedTokenString.empty()) + { + m_tokenRequired = false; + AWS_LOGSTREAM_TRACE(m_logtag.c_str(), "Calling EC2MetadataService to get token failed, falling back to less secure way."); + return GetDefaultCredentials(); + } + m_token = trimmedTokenString; + locker.unlock(); + ss.str(""); + ss << m_endpoint << EC2_SECURITY_CREDENTIALS_RESOURCE; + std::shared_ptr<HttpRequest> profileRequest(CreateHttpRequest(ss.str(), HttpMethod::HTTP_GET, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + profileRequest->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, trimmedTokenString); + profileRequest->SetUserAgent(userAgentString); + Aws::String profileString = GetResourceWithAWSWebServiceResult(profileRequest).GetPayload(); + + Aws::String trimmedProfileString = StringUtils::Trim(profileString.c_str()); + Aws::Vector<Aws::String> securityCredentials = StringUtils::Split(trimmedProfileString, '\n'); + + AWS_LOGSTREAM_DEBUG(m_logtag.c_str(), "Calling EC2MetadataService resource, " << EC2_SECURITY_CREDENTIALS_RESOURCE + << " with token returned profile string " << trimmedProfileString); + if (securityCredentials.size() == 0) + { + AWS_LOGSTREAM_WARN(m_logtag.c_str(), "Calling EC2Metadataservice to get profiles failed"); + return {}; + } + + ss.str(""); + ss << m_endpoint << EC2_SECURITY_CREDENTIALS_RESOURCE << "/" << securityCredentials[0]; + std::shared_ptr<HttpRequest> credentialsRequest(CreateHttpRequest(ss.str(), HttpMethod::HTTP_GET, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + credentialsRequest->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, trimmedTokenString); + credentialsRequest->SetUserAgent(userAgentString); + AWS_LOGSTREAM_DEBUG(m_logtag.c_str(), "Calling EC2MetadataService resource " << ss.str() << " with token."); + return GetResourceWithAWSWebServiceResult(credentialsRequest).GetPayload(); + } + + Aws::String EC2MetadataClient::GetCurrentRegion() const + { + if (!m_region.empty()) + { + return m_region; + } + + AWS_LOGSTREAM_TRACE(m_logtag.c_str(), "Getting current region for ec2 instance"); + + Aws::StringStream ss; + ss << m_endpoint << EC2_REGION_RESOURCE; + std::shared_ptr<HttpRequest> regionRequest(CreateHttpRequest(ss.str(), HttpMethod::HTTP_GET, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + { + std::lock_guard<std::recursive_mutex> locker(m_tokenMutex); + if (m_tokenRequired) + { + regionRequest->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, m_token); + } + } + regionRequest->SetUserAgent(ComputeUserAgentString()); + Aws::String azString = GetResourceWithAWSWebServiceResult(regionRequest).GetPayload(); + + if (azString.empty()) + { + AWS_LOGSTREAM_INFO(m_logtag.c_str() , + "Unable to pull region from instance metadata service "); + return {}; + } + + Aws::String trimmedAZString = StringUtils::Trim(azString.c_str()); + AWS_LOGSTREAM_DEBUG(m_logtag.c_str(), "Calling EC2MetadataService resource " + << EC2_REGION_RESOURCE << " , returned credential string " << trimmedAZString); + + Aws::String region; + region.reserve(trimmedAZString.length()); + + bool digitFound = false; + for (auto character : trimmedAZString) + { + if(digitFound && !isdigit(character)) + { + break; + } + if (isdigit(character)) + { + digitFound = true; + } + + region.append(1, character); + } + + AWS_LOGSTREAM_INFO(m_logtag.c_str(), "Detected current region as " << region); + m_region = region; + return region; + } + + #ifdef _MSC_VER + // VS2015 compiler's bug, warning s_ec2metadataClient: symbol will be dynamically initialized (implementation limitation) + AWS_SUPPRESS_WARNING(4592, + static std::shared_ptr<EC2MetadataClient> s_ec2metadataClient(nullptr); + ) + #else + static std::shared_ptr<EC2MetadataClient> s_ec2metadataClient(nullptr); + #endif + + void InitEC2MetadataClient() + { + if (s_ec2metadataClient) + { + return; + } + s_ec2metadataClient = Aws::MakeShared<EC2MetadataClient>(EC2_METADATA_CLIENT_LOG_TAG); + } + + void CleanupEC2MetadataClient() + { + if (!s_ec2metadataClient) + { + return; + } + s_ec2metadataClient = nullptr; + } + + std::shared_ptr<EC2MetadataClient> GetEC2MetadataClient() + { + return s_ec2metadataClient; + } + + + ECSCredentialsClient::ECSCredentialsClient(const char* resourcePath, const char* endpoint, const char* token) + : AWSHttpResourceClient(ECS_CREDENTIALS_CLIENT_LOG_TAG), + m_resourcePath(resourcePath), m_endpoint(endpoint), m_token(token) + { + } + + ECSCredentialsClient::ECSCredentialsClient(const Aws::Client::ClientConfiguration& clientConfiguration, const char* resourcePath, const char* endpoint, const char* token) + : AWSHttpResourceClient(clientConfiguration, ECS_CREDENTIALS_CLIENT_LOG_TAG), + m_resourcePath(resourcePath), m_endpoint(endpoint), m_token(token) + { + } + + static const char STS_RESOURCE_CLIENT_LOG_TAG[] = "STSResourceClient"; + STSCredentialsClient::STSCredentialsClient(const Aws::Client::ClientConfiguration& clientConfiguration) + : AWSHttpResourceClient(clientConfiguration, STS_RESOURCE_CLIENT_LOG_TAG) + { + SetErrorMarshaller(Aws::MakeUnique<Aws::Client::XmlErrorMarshaller>(STS_RESOURCE_CLIENT_LOG_TAG)); + + Aws::StringStream ss; + if (clientConfiguration.scheme == Aws::Http::Scheme::HTTP) + { + ss << "http://"; + } + else + { + ss << "https://"; + } + + static const int CN_NORTH_1_HASH = Aws::Utils::HashingUtils::HashString(Aws::Region::CN_NORTH_1); + static const int CN_NORTHWEST_1_HASH = Aws::Utils::HashingUtils::HashString(Aws::Region::CN_NORTHWEST_1); + auto hash = Aws::Utils::HashingUtils::HashString(clientConfiguration.region.c_str()); + + ss << "sts." << clientConfiguration.region << ".amazonaws.com"; + if (hash == CN_NORTH_1_HASH || hash == CN_NORTHWEST_1_HASH) + { + ss << ".cn"; + } + m_endpoint = ss.str(); + + AWS_LOGSTREAM_INFO(STS_RESOURCE_CLIENT_LOG_TAG, "Creating STS ResourceClient with endpoint: " << m_endpoint); + } + + STSCredentialsClient::STSAssumeRoleWithWebIdentityResult STSCredentialsClient::GetAssumeRoleWithWebIdentityCredentials(const STSAssumeRoleWithWebIdentityRequest& request) + { + //Calculate query string + Aws::StringStream ss; + ss << "Action=AssumeRoleWithWebIdentity" + << "&Version=2011-06-15" + << "&RoleSessionName=" << Aws::Utils::StringUtils::URLEncode(request.roleSessionName.c_str()) + << "&RoleArn=" << Aws::Utils::StringUtils::URLEncode(request.roleArn.c_str()) + << "&WebIdentityToken=" << Aws::Utils::StringUtils::URLEncode(request.webIdentityToken.c_str()); + + std::shared_ptr<HttpRequest> httpRequest(CreateHttpRequest(m_endpoint, HttpMethod::HTTP_POST, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod)); + + httpRequest->SetUserAgent(ComputeUserAgentString()); + + std::shared_ptr<Aws::IOStream> body = Aws::MakeShared<Aws::StringStream>("STS_RESOURCE_CLIENT_LOG_TAG"); + *body << ss.str(); + + httpRequest->AddContentBody(body); + body->seekg(0, body->end); + auto streamSize = body->tellg(); + body->seekg(0, body->beg); + Aws::StringStream contentLength; + contentLength << streamSize; + httpRequest->SetContentLength(contentLength.str()); + httpRequest->SetContentType("application/x-www-form-urlencoded"); + + Aws::String credentialsStr = GetResourceWithAWSWebServiceResult(httpRequest).GetPayload(); + + //Parse credentials + STSAssumeRoleWithWebIdentityResult result; + if (credentialsStr.empty()) + { + AWS_LOGSTREAM_WARN(STS_RESOURCE_CLIENT_LOG_TAG, "Get an empty credential from sts"); + return result; + } + + const Utils::Xml::XmlDocument xmlDocument = XmlDocument::CreateFromXmlString(credentialsStr); + XmlNode rootNode = xmlDocument.GetRootElement(); + XmlNode resultNode = rootNode; + if (!rootNode.IsNull() && (rootNode.GetName() != "AssumeRoleWithWebIdentityResult")) + { + resultNode = rootNode.FirstChild("AssumeRoleWithWebIdentityResult"); + } + + if (!resultNode.IsNull()) + { + XmlNode credentialsNode = resultNode.FirstChild("Credentials"); + if (!credentialsNode.IsNull()) + { + XmlNode accessKeyIdNode = credentialsNode.FirstChild("AccessKeyId"); + if (!accessKeyIdNode.IsNull()) + { + result.creds.SetAWSAccessKeyId(accessKeyIdNode.GetText()); + } + + XmlNode secretAccessKeyNode = credentialsNode.FirstChild("SecretAccessKey"); + if (!secretAccessKeyNode.IsNull()) + { + result.creds.SetAWSSecretKey(secretAccessKeyNode.GetText()); + } + + XmlNode sessionTokenNode = credentialsNode.FirstChild("SessionToken"); + if (!sessionTokenNode.IsNull()) + { + result.creds.SetSessionToken(sessionTokenNode.GetText()); + } + + XmlNode expirationNode = credentialsNode.FirstChild("Expiration"); + if (!expirationNode.IsNull()) + { + result.creds.SetExpiration(DateTime(StringUtils::Trim(expirationNode.GetText().c_str()).c_str(), DateFormat::ISO_8601)); + } + } + } + return result; + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/DefaultMonitoring.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/DefaultMonitoring.cpp new file mode 100644 index 0000000000..9953004bc3 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/DefaultMonitoring.cpp @@ -0,0 +1,340 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/memory/AWSMemory.h> +#include <aws/core/monitoring/DefaultMonitoring.h> +#include <aws/core/utils/DateTime.h> +#include <aws/core/utils/json/JsonSerializer.h> +#include <aws/core/utils/Outcome.h> +#include <aws/core/client/AWSClient.h> +#include <aws/core/auth/AWSCredentialsProvider.h> +#include <aws/core/platform/Environment.h> +#include <aws/core/config/AWSProfileConfigLoader.h> +#include <aws/core/utils/logging/LogMacros.h> +using namespace Aws::Utils; + +namespace Aws +{ + namespace Monitoring + { + static const char DEFAULT_MONITORING_ALLOC_TAG[] = "DefaultMonitoringAllocTag"; + static const int CLIENT_ID_LENGTH_LIMIT = 256; + static const int USER_AGENT_LENGTH_LIMIT = 256; + static const int ERROR_MESSAGE_LENGTH_LIMIT = 512; + + const char DEFAULT_MONITORING_CLIENT_ID[] = ""; // default to empty; + const char DEFAULT_MONITORING_HOST[] = "127.0.0.1"; // default to loopback ip address instead of "localhost" based on design specification. + unsigned short DEFAULT_MONITORING_PORT = 31000; //default to 31000; + bool DEFAULT_MONITORING_ENABLE = false; //default to false; + + const int DefaultMonitoring::DEFAULT_MONITORING_VERSION = 1; + const char DefaultMonitoring::DEFAULT_CSM_CONFIG_ENABLED[] = "csm_enabled"; + const char DefaultMonitoring::DEFAULT_CSM_CONFIG_CLIENT_ID[] = "csm_client_id"; + const char DefaultMonitoring::DEFAULT_CSM_CONFIG_HOST[] = "csm_host"; + const char DefaultMonitoring::DEFAULT_CSM_CONFIG_PORT[] = "csm_port"; + const char DefaultMonitoring::DEFAULT_CSM_ENVIRONMENT_VAR_ENABLED[] = "AWS_CSM_ENABLED"; + const char DefaultMonitoring::DEFAULT_CSM_ENVIRONMENT_VAR_CLIENT_ID[] = "AWS_CSM_CLIENT_ID"; + const char DefaultMonitoring::DEFAULT_CSM_ENVIRONMENT_VAR_HOST[] = "AWS_CSM_HOST"; + const char DefaultMonitoring::DEFAULT_CSM_ENVIRONMENT_VAR_PORT[] = "AWS_CSM_PORT"; + + + struct DefaultContext + { + Aws::Utils::DateTime apiCallStartTime; + Aws::Utils::DateTime attemptStartTime; + int retryCount = 0; + bool lastAttemptSucceeded = false; + bool lastErrorRetryable = false; //doesn't apply if last attempt succeeded. + const Aws::Client::HttpResponseOutcome* outcome = nullptr; + }; + + static inline void FillRequiredFieldsToJson(Json::JsonValue& json, + const Aws::String& type, + const Aws::String& service, + const Aws::String& api, + const Aws::String& clientId, + const DateTime& timestamp, + int version, + const Aws::String& userAgent) + { + json.WithString("Type", type) + .WithString("Service", service) + .WithString("Api", api) + .WithString("ClientId", clientId.substr(0, CLIENT_ID_LENGTH_LIMIT)) + .WithInt64("Timestamp", timestamp.Millis()) + .WithInteger("Version", version) + .WithString("UserAgent", userAgent.substr(0, USER_AGENT_LENGTH_LIMIT)); + } + + static inline void FillRequiredApiCallFieldsToJson(Json::JsonValue& json, + int attemptCount, + int64_t apiCallLatency, + bool maxRetriesExceeded) + { + json.WithInteger("AttemptCount", attemptCount) + .WithInt64("Latency", apiCallLatency) + .WithInteger("MaxRetriesExceeded", maxRetriesExceeded ? 1 : 0); + } + + static inline void FillRequiredApiAttemptFieldsToJson(Json::JsonValue& json, + const Aws::String& domainName, + int64_t attemptLatency) + { + json.WithString("Fqdn", domainName) + .WithInt64("AttemptLatency", attemptLatency); + } + + static inline void ExportResponseHeaderToJson(Json::JsonValue& json, const Aws::Http::HeaderValueCollection& headers, + const Aws::String& headerName, const Aws::String& targetName) + { + auto iter = headers.find(headerName); + if (iter != headers.end()) + { + json.WithString(targetName, iter->second); + } + } + + static inline void ExportHttpMetricsToJson(Json::JsonValue& json, const Aws::Monitoring::HttpClientMetricsCollection& httpMetrics, Aws::Monitoring::HttpClientMetricsType type) + { + auto iter = httpMetrics.find(GetHttpClientMetricNameByType(type)); + if (iter != httpMetrics.end()) + { + json.WithInt64(GetHttpClientMetricNameByType(type), iter->second); + } + } + + static inline void FillOptionalApiCallFieldsToJson(Json::JsonValue& json, + const Aws::Http::HttpRequest* request, + const Aws::Client::HttpResponseOutcome& outcome) + { + if (!request->GetSigningRegion().empty()) + { + json.WithString("Region", request->GetSigningRegion()); + } + if (!outcome.IsSuccess()) + { + if (outcome.GetError().GetExceptionName().empty()) // Not Aws Exception + { + json.WithString("FinalSdkExceptionMessage", outcome.GetError().GetMessage().substr(0, ERROR_MESSAGE_LENGTH_LIMIT)); + } + else // Aws Exception + { + json.WithString("FinalAwsException", outcome.GetError().GetExceptionName()) + .WithString("FinalAwsExceptionMessage", outcome.GetError().GetMessage().substr(0, ERROR_MESSAGE_LENGTH_LIMIT)); + } + json.WithInteger("FinalHttpStatusCode", static_cast<int>(outcome.GetError().GetResponseCode())); + } + else + { + json.WithInteger("FinalHttpStatusCode", static_cast<int>(outcome.GetResult()->GetResponseCode())); + } + } + + static inline void FillOptionalApiAttemptFieldsToJson(Json::JsonValue& json, + const Aws::Http::HttpRequest* request, + const Aws::Client::HttpResponseOutcome& outcome, + const CoreMetricsCollection& metricsFromCore) + { + /** + *No matter request succeeded or not, these fields should be included as long as their requirements + *are met. We should be able to access response (so as to access original request) if the response has error. + */ + if (request->HasAwsSessionToken() && !request->GetAwsSessionToken().empty()) + { + json.WithString("SessionToken", request->GetAwsSessionToken()); + } + if (!request->GetSigningRegion().empty()) + { + json.WithString("Region", request->GetSigningRegion()); + } + if (!request->GetSigningAccessKey().empty()) + { + json.WithString("AccessKey", request->GetSigningAccessKey()); + } + + const auto& headers = outcome.IsSuccess() ? outcome.GetResult()->GetHeaders() : outcome.GetError().GetResponseHeaders(); + + ExportResponseHeaderToJson(json, headers, StringUtils::ToLower("x-amzn-RequestId"), "XAmznRequestId"); + ExportResponseHeaderToJson(json, headers, StringUtils::ToLower("x-amz-request-id"), "XAmzRequestId"); + ExportResponseHeaderToJson(json, headers, StringUtils::ToLower("x-amz-id-2"), "XAmzId2"); + + if (!outcome.IsSuccess()) + { + if (outcome.GetError().GetExceptionName().empty()) // Not Aws Exception + { + json.WithString("SdkExceptionMessage", outcome.GetError().GetMessage().substr(0, ERROR_MESSAGE_LENGTH_LIMIT)); + } + else // Aws Exception + { + json.WithString("AwsException", outcome.GetError().GetExceptionName()) + .WithString("AwsExceptionMessage", outcome.GetError().GetMessage().substr(0, ERROR_MESSAGE_LENGTH_LIMIT)); + } + json.WithInteger("HttpStatusCode", static_cast<int>(outcome.GetError().GetResponseCode())); + } + else + { + json.WithInteger("HttpStatusCode", static_cast<int>(outcome.GetResult()->GetResponseCode())); + } + + // Optional MetricsCollectedFromCore + ExportHttpMetricsToJson(json, metricsFromCore.httpClientMetrics, HttpClientMetricsType::AcquireConnectionLatency); + ExportHttpMetricsToJson(json, metricsFromCore.httpClientMetrics, HttpClientMetricsType::ConnectionReused); + ExportHttpMetricsToJson(json, metricsFromCore.httpClientMetrics, HttpClientMetricsType::ConnectLatency); + ExportHttpMetricsToJson(json, metricsFromCore.httpClientMetrics, HttpClientMetricsType::DestinationIp); + ExportHttpMetricsToJson(json, metricsFromCore.httpClientMetrics, HttpClientMetricsType::DnsLatency); + ExportHttpMetricsToJson(json, metricsFromCore.httpClientMetrics, HttpClientMetricsType::RequestLatency); + ExportHttpMetricsToJson(json, metricsFromCore.httpClientMetrics, HttpClientMetricsType::SslLatency); + ExportHttpMetricsToJson(json, metricsFromCore.httpClientMetrics, HttpClientMetricsType::TcpLatency); + } + + DefaultMonitoring::DefaultMonitoring(const Aws::String& clientId, const Aws::String& host, unsigned short port): + m_udp(host.c_str(), port), m_clientId(clientId) + { + } + + void* DefaultMonitoring::OnRequestStarted(const Aws::String& serviceName, const Aws::String& requestName, const std::shared_ptr<const Aws::Http::HttpRequest>& request) const + { + AWS_UNREFERENCED_PARAM(request); + + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "OnRequestStart Service: " << serviceName << "Request: " << requestName); + auto context = Aws::New<DefaultContext>(DEFAULT_MONITORING_ALLOC_TAG); + context->apiCallStartTime = Aws::Utils::DateTime::Now(); + context->attemptStartTime = context->apiCallStartTime; + context->retryCount = 0; + return context; + } + + + void DefaultMonitoring::OnRequestSucceeded(const Aws::String& serviceName, const Aws::String& requestName, const std::shared_ptr<const Aws::Http::HttpRequest>& request, + const Aws::Client::HttpResponseOutcome& outcome, const CoreMetricsCollection& metricsFromCore, void* context) const + { + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "OnRequestSucceeded Service: " << serviceName << "Request: " << requestName); + CollectAndSendAttemptData(serviceName, requestName, request, outcome, metricsFromCore, context); + } + + void DefaultMonitoring::OnRequestFailed(const Aws::String& serviceName, const Aws::String& requestName, const std::shared_ptr<const Aws::Http::HttpRequest>& request, + const Aws::Client::HttpResponseOutcome& outcome, const CoreMetricsCollection& metricsFromCore, void* context) const + { + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "OnRequestFailed Service: " << serviceName << "Request: " << requestName); + CollectAndSendAttemptData(serviceName, requestName, request, outcome, metricsFromCore, context); + } + + void DefaultMonitoring::OnRequestRetry(const Aws::String& serviceName, const Aws::String& requestName, + const std::shared_ptr<const Aws::Http::HttpRequest>& request, void* context) const + { + AWS_UNREFERENCED_PARAM(request); + + DefaultContext* defaultContext = static_cast<DefaultContext*>(context); + defaultContext->retryCount++; + defaultContext->attemptStartTime = Aws::Utils::DateTime::Now(); + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "OnRequestRetry Service: " << serviceName << "Request: " << requestName << " RetryCnt:" << defaultContext->retryCount); + } + + void DefaultMonitoring::OnFinish(const Aws::String& serviceName, const Aws::String& requestName, + const std::shared_ptr<const Aws::Http::HttpRequest>& request, void* context) const + { + AWS_UNREFERENCED_PARAM(request); + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "OnRequestFinish Service: " << serviceName << "Request: " << requestName); + + DefaultContext* defaultContext = static_cast<DefaultContext*>(context); + Aws::Utils::Json::JsonValue json; + FillRequiredFieldsToJson(json, "ApiCall", serviceName, requestName, m_clientId, defaultContext->apiCallStartTime, DEFAULT_MONITORING_VERSION, request->GetUserAgent()); + FillRequiredApiCallFieldsToJson(json, defaultContext->retryCount + 1, (DateTime::Now() - defaultContext->apiCallStartTime).count(), (!defaultContext->lastAttemptSucceeded && defaultContext->lastErrorRetryable)); + FillOptionalApiCallFieldsToJson(json, request.get(), *(defaultContext->outcome)); + Aws::String compactData = json.View().WriteCompact(); + m_udp.SendData(reinterpret_cast<const uint8_t*>(compactData.c_str()), static_cast<int>(compactData.size())); + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Send API Metrics: \n" << json.View().WriteReadable()); + Aws::Delete(defaultContext); + } + + void DefaultMonitoring::CollectAndSendAttemptData(const Aws::String& serviceName, const Aws::String& requestName, + const std::shared_ptr<const Aws::Http::HttpRequest>& request, const Aws::Client::HttpResponseOutcome& outcome, + const CoreMetricsCollection& metricsFromCore, void* context) const + { + DefaultContext* defaultContext = static_cast<DefaultContext*>(context); + defaultContext->outcome = &outcome; + defaultContext->lastAttemptSucceeded = outcome.IsSuccess() ? true : false; + defaultContext->lastErrorRetryable = (!outcome.IsSuccess() && outcome.GetError().ShouldRetry()) ? true : false; + Aws::Utils::Json::JsonValue json; + FillRequiredFieldsToJson(json, "ApiCallAttempt", serviceName, requestName, m_clientId, defaultContext->attemptStartTime, DEFAULT_MONITORING_VERSION, request->GetUserAgent()); + FillRequiredApiAttemptFieldsToJson(json, request->GetUri().GetAuthority(), (DateTime::Now() - defaultContext->attemptStartTime).count()); + FillOptionalApiAttemptFieldsToJson(json, request.get(), outcome, metricsFromCore); + Aws::String compactData = json.View().WriteCompact(); + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Send Attempt Metrics: \n" << json.View().WriteReadable()); + m_udp.SendData(reinterpret_cast<const uint8_t*>(compactData.c_str()), static_cast<int>(compactData.size())); + } + + Aws::UniquePtr<MonitoringInterface> DefaultMonitoringFactory::CreateMonitoringInstance() const + { + Aws::String clientId(DEFAULT_MONITORING_CLIENT_ID); // default to empty + Aws::String host(DEFAULT_MONITORING_HOST); // default to 127.0.0.1 + unsigned short port = DEFAULT_MONITORING_PORT; // default to 31000 + bool enable = DEFAULT_MONITORING_ENABLE; //default to false; + + //check profile_config + Aws::String tmpEnable = Aws::Config::GetCachedConfigValue(DefaultMonitoring::DEFAULT_CSM_CONFIG_ENABLED); + Aws::String tmpClientId = Aws::Config::GetCachedConfigValue(DefaultMonitoring::DEFAULT_CSM_CONFIG_CLIENT_ID); + Aws::String tmpHost = Aws::Config::GetCachedConfigValue(DefaultMonitoring::DEFAULT_CSM_CONFIG_HOST); + Aws::String tmpPort = Aws::Config::GetCachedConfigValue(DefaultMonitoring::DEFAULT_CSM_CONFIG_PORT); + + if (!tmpEnable.empty()) + { + enable = StringUtils::CaselessCompare(tmpEnable.c_str(), "true") ? true : false; + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Resolved csm_enabled from profile_config to be " << enable); + } + if (!tmpClientId.empty()) + { + clientId = tmpClientId; + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Resolved csm_client_id from profile_config to be " << clientId); + } + + if (!tmpHost.empty()) + { + host = tmpHost; + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Resolved csm_host from profile_config to be " << host); + } + + if (!tmpPort.empty()) + { + port = static_cast<short>(StringUtils::ConvertToInt32(tmpPort.c_str())); + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Resolved csm_port from profile_config to be " << port); + } + + // check environment variables + tmpEnable = Aws::Environment::GetEnv(DefaultMonitoring::DEFAULT_CSM_ENVIRONMENT_VAR_ENABLED); + tmpClientId = Aws::Environment::GetEnv(DefaultMonitoring::DEFAULT_CSM_ENVIRONMENT_VAR_CLIENT_ID); + tmpHost = Aws::Environment::GetEnv(DefaultMonitoring::DEFAULT_CSM_ENVIRONMENT_VAR_HOST); + tmpPort = Aws::Environment::GetEnv(DefaultMonitoring::DEFAULT_CSM_ENVIRONMENT_VAR_PORT); + if (!tmpEnable.empty()) + { + enable = StringUtils::CaselessCompare(tmpEnable.c_str(), "true") ? true : false; + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Resolved AWS_CSM_ENABLED from Environment variable to be " << enable); + } + if (!tmpClientId.empty()) + { + clientId = tmpClientId; + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Resolved AWS_CSM_CLIENT_ID from Environment variable to be " << clientId); + + } + if (!tmpHost.empty()) + { + host = tmpHost; + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Resolved AWS_CSM_HOST from Environment variable to be " << host); + } + if (!tmpPort.empty()) + { + port = static_cast<unsigned short>(StringUtils::ConvertToInt32(tmpPort.c_str())); + AWS_LOGSTREAM_DEBUG(DEFAULT_MONITORING_ALLOC_TAG, "Resolved AWS_CSM_PORT from Environment variable to be " << port); + } + + if (!enable) + { + return nullptr; + } + return Aws::MakeUnique<DefaultMonitoring>(DEFAULT_MONITORING_ALLOC_TAG, clientId, host, port); + } + + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/HttpClientMetrics.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/HttpClientMetrics.cpp new file mode 100644 index 0000000000..f3ef582867 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/HttpClientMetrics.cpp @@ -0,0 +1,71 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/monitoring/HttpClientMetrics.h> + +namespace Aws +{ + namespace Monitoring + { + static const char HTTP_CLIENT_METRICS_DESTINATION_IP[] = "DestinationIp"; + static const char HTTP_CLIENT_METRICS_ACQUIRE_CONNECTION_LATENCY[] = "AcquireConnectionLatency"; + static const char HTTP_CLIENT_METRICS_CONNECTION_REUSED[] = "ConnectionReused"; + static const char HTTP_CLIENT_METRICS_CONNECTION_LATENCY[] = "ConnectLatency"; + static const char HTTP_CLIENT_METRICS_REQUEST_LATENCY[] = "RequestLatency"; + static const char HTTP_CLIENT_METRICS_DNS_LATENCY[] = "DnsLatency"; + static const char HTTP_CLIENT_METRICS_TCP_LATENCY[] = "TcpLatency"; + static const char HTTP_CLIENT_METRICS_SSL_LATENCY[] = "SslLatency"; + static const char HTTP_CLIENT_METRICS_UNKNOWN[] = "Unknown"; + + using namespace Aws::Utils; + HttpClientMetricsType GetHttpClientMetricTypeByName(const Aws::String& name) + { + static std::map<int, HttpClientMetricsType> metricsNameHashToType = + { + std::pair<int, HttpClientMetricsType>(HashingUtils::HashString(HTTP_CLIENT_METRICS_DESTINATION_IP), HttpClientMetricsType::DestinationIp), + std::pair<int, HttpClientMetricsType>(HashingUtils::HashString(HTTP_CLIENT_METRICS_ACQUIRE_CONNECTION_LATENCY), HttpClientMetricsType::AcquireConnectionLatency), + std::pair<int, HttpClientMetricsType>(HashingUtils::HashString(HTTP_CLIENT_METRICS_CONNECTION_REUSED), HttpClientMetricsType::ConnectionReused), + std::pair<int, HttpClientMetricsType>(HashingUtils::HashString(HTTP_CLIENT_METRICS_CONNECTION_LATENCY), HttpClientMetricsType::ConnectLatency), + std::pair<int, HttpClientMetricsType>(HashingUtils::HashString(HTTP_CLIENT_METRICS_REQUEST_LATENCY), HttpClientMetricsType::RequestLatency), + std::pair<int, HttpClientMetricsType>(HashingUtils::HashString(HTTP_CLIENT_METRICS_DNS_LATENCY), HttpClientMetricsType::DnsLatency), + std::pair<int, HttpClientMetricsType>(HashingUtils::HashString(HTTP_CLIENT_METRICS_TCP_LATENCY), HttpClientMetricsType::TcpLatency), + std::pair<int, HttpClientMetricsType>(HashingUtils::HashString(HTTP_CLIENT_METRICS_SSL_LATENCY), HttpClientMetricsType::SslLatency) + }; + + int nameHash = HashingUtils::HashString(name.c_str()); + auto it = metricsNameHashToType.find(nameHash); + if (it == metricsNameHashToType.end()) + { + return HttpClientMetricsType::Unknown; + } + return it->second; + } + + Aws::String GetHttpClientMetricNameByType(HttpClientMetricsType type) + { + static std::map<int, std::string> metricsTypeToName = + { + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::DestinationIp), HTTP_CLIENT_METRICS_DESTINATION_IP), + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::AcquireConnectionLatency), HTTP_CLIENT_METRICS_ACQUIRE_CONNECTION_LATENCY), + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::ConnectionReused), HTTP_CLIENT_METRICS_CONNECTION_REUSED), + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::ConnectLatency), HTTP_CLIENT_METRICS_CONNECTION_LATENCY), + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::RequestLatency), HTTP_CLIENT_METRICS_REQUEST_LATENCY), + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::DnsLatency), HTTP_CLIENT_METRICS_DNS_LATENCY), + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::TcpLatency), HTTP_CLIENT_METRICS_TCP_LATENCY), + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::SslLatency), HTTP_CLIENT_METRICS_SSL_LATENCY), + std::pair<int, std::string>(static_cast<int>(HttpClientMetricsType::Unknown), HTTP_CLIENT_METRICS_UNKNOWN) + }; + + auto it = metricsTypeToName.find(static_cast<int>(type)); + if (it == metricsTypeToName.end()) + { + return HTTP_CLIENT_METRICS_UNKNOWN; + } + return Aws::String(it->second.c_str()); + } + + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/MonitoringManager.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/MonitoringManager.cpp new file mode 100644 index 0000000000..7a8d3adb41 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/monitoring/MonitoringManager.cpp @@ -0,0 +1,129 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/memory/AWSMemory.h> +#include <aws/core/monitoring/MonitoringInterface.h> +#include <aws/core/monitoring/MonitoringFactory.h> +#include <aws/core/monitoring/MonitoringManager.h> +#include <aws/core/monitoring/DefaultMonitoring.h> +#include <aws/core/Core_EXPORTS.h> + +#ifdef _MSC_VER +#pragma warning(disable : 4592) +#endif + +namespace Aws +{ + namespace Monitoring + { + typedef Aws::Vector<Aws::UniquePtr<MonitoringInterface>> Monitors; + + const char MonitoringTag[] = "MonitoringAllocTag"; + + /** + * Global factory to create global metrics instance. + */ + static Aws::UniquePtr<Monitors> s_monitors; + + Aws::Vector<void*> OnRequestStarted(const Aws::String& serviceName, const Aws::String& requestName, const std::shared_ptr<const Aws::Http::HttpRequest>& request) + { + assert(s_monitors); + Aws::Vector<void*> contexts; + contexts.reserve(s_monitors->size()); + for (const auto& interface: *s_monitors) + { + contexts.emplace_back(interface->OnRequestStarted(serviceName, requestName, request)); + } + return contexts; + } + + void OnRequestSucceeded(const Aws::String& serviceName, const Aws::String& requestName, const std::shared_ptr<const Aws::Http::HttpRequest>& request, + const Aws::Client::HttpResponseOutcome& outcome, const CoreMetricsCollection& metricsFromCore, const Aws::Vector<void*>& contexts) + { + assert(s_monitors); + assert(contexts.size() == s_monitors->size()); + size_t index = 0; + for (const auto& interface: *s_monitors) + { + interface->OnRequestSucceeded(serviceName, requestName, request, outcome, metricsFromCore, contexts[index++]); + } + } + + void OnRequestFailed(const Aws::String& serviceName, const Aws::String& requestName, const std::shared_ptr<const Aws::Http::HttpRequest>& request, + const Aws::Client::HttpResponseOutcome& outcome, const CoreMetricsCollection& metricsFromCore, const Aws::Vector<void*>& contexts) + { + assert(s_monitors); + assert(contexts.size() == s_monitors->size()); + size_t index = 0; + for (const auto& interface: *s_monitors) + { + interface->OnRequestFailed(serviceName, requestName, request, outcome, metricsFromCore, contexts[index++]); + } + } + + void OnRequestRetry(const Aws::String& serviceName, const Aws::String& requestName, + const std::shared_ptr<const Aws::Http::HttpRequest>& request, const Aws::Vector<void*>& contexts) + { + assert(s_monitors); + assert(contexts.size() == s_monitors->size()); + size_t index = 0; + for (const auto& interface: *s_monitors) + { + interface->OnRequestRetry(serviceName, requestName, request, contexts[index++]); + } + } + + void OnFinish(const Aws::String& serviceName, const Aws::String& requestName, + const std::shared_ptr<const Aws::Http::HttpRequest>& request, const Aws::Vector<void*>& contexts) + { + assert(s_monitors); + assert(contexts.size() == s_monitors->size()); + size_t index = 0; + for (const auto& interface: *s_monitors) + { + interface->OnFinish(serviceName, requestName, request, contexts[index++]); + } + } + + void InitMonitoring(const std::vector<MonitoringFactoryCreateFunction>& monitoringFactoryCreateFunctions) + { + if (s_monitors) + { + return; + } + s_monitors = Aws::MakeUnique<Monitors>(MonitoringTag); + for (const auto& function: monitoringFactoryCreateFunctions) + { + auto factory = function(); + if (factory) + { + auto instance = factory->CreateMonitoringInstance(); + if (instance) + { + s_monitors->emplace_back(std::move(instance)); + } + } + } + + auto defaultMonitoringFactory = Aws::MakeShared<DefaultMonitoringFactory>(MonitoringTag); + auto instance = defaultMonitoringFactory->CreateMonitoringInstance(); + if (instance) + { + s_monitors->emplace_back(std::move(instance)); + } + } + + void CleanupMonitoring() + { + if (!s_monitors) + { + return; + } + + s_monitors = nullptr; + } + } // namespace Monitoring + +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/net/linux-shared/Net.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/net/linux-shared/Net.cpp new file mode 100644 index 0000000000..244df21945 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/net/linux-shared/Net.cpp @@ -0,0 +1,28 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/net/Net.h> + +namespace Aws +{ + namespace Net + { + // For Posix system, currently we don't need to do anything for network stack initialization. + // But we need to do initialization for WinSock on Windows and call them in Aws.cpp. So these functions + // also exist for Posix systems. + bool IsNetworkInitiated() + { + return true; + } + + void InitNetwork() + { + } + + void CleanupNetwork() + { + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/net/linux-shared/SimpleUDP.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/net/linux-shared/SimpleUDP.cpp new file mode 100644 index 0000000000..d9e0c385fd --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/net/linux-shared/SimpleUDP.cpp @@ -0,0 +1,285 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <sys/types.h> +#include <sys/socket.h> +#include <netdb.h> +#include <arpa/inet.h> +#include <netinet/in.h> +#include <unistd.h> +#include <fcntl.h> +#include <cassert> +#include <string.h> +#include <aws/core/net/SimpleUDP.h> +#include <aws/core/utils/logging/LogMacros.h> + +namespace Aws +{ + namespace Net + { + static const char ALLOC_TAG[] = "SimpleUDP"; + static const char IPV4_LOOP_BACK_ADDRESS[] = "127.0.0.1"; + static const char IPV6_LOOP_BACK_ADDRESS[] = "::1"; + + static inline bool IsValidIPAddress(const char* ip, int addressFamily/*AF_INET or AF_INET6*/) + { + char buffer[128]; + return inet_pton(addressFamily, ip, (void*)buffer) == 1 ?true :false; + } + + static bool GetASockAddrFromHostName(const char* hostName, void* sockAddrBuffer, size_t& addrLength, int& addressFamily) + { + struct addrinfo hints, *res; + + memset(&hints, 0, sizeof(hints)); + + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_DGRAM; + if (getaddrinfo(hostName, nullptr, &hints, &res)) + { + return false; + } + + memcpy(sockAddrBuffer, res->ai_addr, res->ai_addrlen); + addrLength = res->ai_addrlen; + addressFamily = res->ai_family; + freeaddrinfo(res); + return true; + } + + static sockaddr_in BuildAddrInfoIPV4(const char* hostIP, short port) + { +#if (__GNUC__ == 4) && !defined(__clang__) + AWS_SUPPRESS_WARNING("-Wmissing-field-initializers", + sockaddr_in addrinfo {}; + ); +#else + sockaddr_in addrinfo {}; +#endif + addrinfo.sin_family = AF_INET; + addrinfo.sin_port = htons(port); + inet_pton(AF_INET, hostIP, &addrinfo.sin_addr); + return addrinfo; + } + + static sockaddr_in6 BuildAddrInfoIPV6(const char* hostIP, short port) + { +#if (__GNUC__ == 4) && !defined(__clang__) + AWS_SUPPRESS_WARNING("-Wmissing-field-initializers", + sockaddr_in6 addrinfo {}; + ); +#else + sockaddr_in6 addrinfo {}; +#endif + addrinfo.sin6_family = AF_INET6; + addrinfo.sin6_port = htons(port); + inet_pton(AF_INET6, hostIP, &addrinfo.sin6_addr); + return addrinfo; + } + + SimpleUDP::SimpleUDP(int addressFamily, size_t sendBufSize, size_t receiveBufSize, bool nonBlocking): + m_addressFamily(addressFamily), m_connected(false), m_socket(-1), m_port(0) + { + CreateSocket(addressFamily, sendBufSize, receiveBufSize, nonBlocking); + } + + SimpleUDP::SimpleUDP(bool IPV4, size_t sendBufSize, size_t receiveBufSize, bool nonBlocking) : + m_addressFamily(IPV4 ? AF_INET : AF_INET6), m_connected(false), m_socket(-1), m_port(0) + { + CreateSocket(m_addressFamily, sendBufSize, receiveBufSize, nonBlocking); + } + + SimpleUDP::SimpleUDP(const char* host, unsigned short port, size_t sendBufSize, size_t receiveBufSize, bool nonBlocking) : + m_addressFamily(AF_INET), m_connected(false), m_socket(-1), m_port(port) + { + if (IsValidIPAddress(host, AF_INET)) + { + m_addressFamily = AF_INET; + m_hostIP = Aws::String(host); + } + else if (IsValidIPAddress(host, AF_INET6)) + { + m_addressFamily = AF_INET6; + m_hostIP = Aws::String(host); + } + else + { + char sockAddrBuffer[100]; + char hostBuffer[100]; + size_t addrLength = 0; + if (GetASockAddrFromHostName(host, (void*)sockAddrBuffer, addrLength, m_addressFamily)) + { + if (m_addressFamily == AF_INET) + { + struct sockaddr_in* sockaddr = reinterpret_cast<struct sockaddr_in*>(sockAddrBuffer); + inet_ntop(m_addressFamily, &(sockaddr->sin_addr), hostBuffer, sizeof(hostBuffer)); + } + else + { + struct sockaddr_in6* sockaddr = reinterpret_cast<struct sockaddr_in6*>(sockAddrBuffer); + inet_ntop(m_addressFamily, &(sockaddr->sin6_addr), hostBuffer, sizeof(hostBuffer)); + } + m_hostIP = Aws::String(hostBuffer); + } + else + { + AWS_LOGSTREAM_ERROR(ALLOC_TAG, "Can't retrieve a valid ip address based on provided host: " << host); + } + } + CreateSocket(m_addressFamily, sendBufSize, receiveBufSize, nonBlocking); + } + + SimpleUDP::~SimpleUDP() + { + close(GetUnderlyingSocket()); + } + + void SimpleUDP::CreateSocket(int addressFamily, size_t sendBufSize, size_t receiveBufSize, bool nonBlocking) + { + int sock = socket(addressFamily, SOCK_DGRAM, IPPROTO_UDP); + assert(sock != -1); + + // Try to set sock to nonblocking mode. + if (nonBlocking) + { + int flags = fcntl(sock, F_GETFL, 0); + if (flags != -1) + { + flags |= O_NONBLOCK; + fcntl(sock, F_SETFL, flags); + } + } + + // if sendBufSize is not zero, try to set send buffer size + if (sendBufSize) + { + int ret = setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &sendBufSize, sizeof(sendBufSize)); + if (ret) + { + AWS_LOGSTREAM_WARN(ALLOC_TAG, "Failed to set UDP send buffer size to " << sendBufSize << " for socket " << sock << " error message: " << strerror(errno)); + } + assert(ret == 0); + } + + // if receiveBufSize is not zero, try to set receive buffer size + if (receiveBufSize) + { + int ret = setsockopt(sock, SOL_SOCKET, SO_RCVBUF, &receiveBufSize, sizeof(receiveBufSize)); + if (ret) + { + AWS_LOGSTREAM_WARN(ALLOC_TAG, "Failed to set UDP receive buffer size to " << receiveBufSize << " for socket " << sock << " error message: " << strerror(errno)); + } + assert(ret == 0); + } + + SetUnderlyingSocket(sock); + } + + int SimpleUDP::Connect(const sockaddr* address, size_t addressLength) + { + int ret = connect(GetUnderlyingSocket(), address, static_cast<socklen_t>(addressLength)); + m_connected = ret ? false : true; + return ret; + } + + int SimpleUDP::ConnectToHost(const char* hostIP, unsigned short port) const + { + int ret; + if (m_addressFamily == AF_INET6) + { + sockaddr_in6 addrinfo = BuildAddrInfoIPV6(hostIP, port); + ret = connect(GetUnderlyingSocket(), reinterpret_cast<sockaddr*>(&addrinfo), sizeof(sockaddr_in6)); + } + else + { + sockaddr_in addrinfo = BuildAddrInfoIPV4(hostIP, port); + ret = connect(GetUnderlyingSocket(), reinterpret_cast<sockaddr*>(&addrinfo), sizeof(sockaddr_in)); + } + m_connected = ret ? false : true; + return ret; + } + + int SimpleUDP::ConnectToLocalHost(unsigned short port) const + { + if (m_addressFamily == AF_INET6) + { + return ConnectToHost(IPV6_LOOP_BACK_ADDRESS, port); + } + else + { + return ConnectToHost(IPV4_LOOP_BACK_ADDRESS, port); + } + } + + int SimpleUDP::Bind(const sockaddr* address, size_t addressLength) const + { + return bind(GetUnderlyingSocket(), address, static_cast<socklen_t>(addressLength)); + } + + int SimpleUDP::BindToLocalHost(unsigned short port) const + { + if (m_addressFamily == AF_INET6) + { + sockaddr_in6 addrinfo = BuildAddrInfoIPV6(IPV6_LOOP_BACK_ADDRESS, port); + return bind(GetUnderlyingSocket(), reinterpret_cast<sockaddr*>(&addrinfo), sizeof(sockaddr_in6)); + } + else + { + sockaddr_in addrinfo = BuildAddrInfoIPV4(IPV4_LOOP_BACK_ADDRESS, port); + return bind(GetUnderlyingSocket(), reinterpret_cast<sockaddr*>(&addrinfo), sizeof(sockaddr_in)); + } + } + + int SimpleUDP::SendData(const uint8_t* data, size_t dataLen) const + { + if (!m_connected) + { + ConnectToHost(m_hostIP.c_str(), m_port); + } + return send(GetUnderlyingSocket(), data, dataLen, 0); + } + + int SimpleUDP::SendDataTo(const sockaddr* address, size_t addressLength, const uint8_t* data, size_t dataLen) const + { + if (m_connected) + { + return send(GetUnderlyingSocket(), data, dataLen, 0); + } + else + { + return sendto(GetUnderlyingSocket(), data, dataLen, 0, address, static_cast<socklen_t>(addressLength)); + } + } + + int SimpleUDP::SendDataToLocalHost(const uint8_t* data, size_t dataLen, unsigned short port) const + { + if (m_connected) + { + return send(GetUnderlyingSocket(), data, dataLen, 0); + } + else if (m_addressFamily == AF_INET6) + { + sockaddr_in6 addrinfo = BuildAddrInfoIPV6(IPV6_LOOP_BACK_ADDRESS, port); + return sendto(GetUnderlyingSocket(), data, dataLen, 0, reinterpret_cast<sockaddr*>(&addrinfo), sizeof(sockaddr_in6)); + } + else + { + sockaddr_in addrinfo = BuildAddrInfoIPV4(IPV4_LOOP_BACK_ADDRESS, port); + return sendto(GetUnderlyingSocket(), data, dataLen, 0, reinterpret_cast<sockaddr*>(&addrinfo), sizeof(sockaddr_in)); + } + } + + int SimpleUDP::ReceiveData(uint8_t* buffer, size_t bufferLen) const + { + return recv(GetUnderlyingSocket(), buffer, static_cast<int>(bufferLen), 0); + } + + + int SimpleUDP::ReceiveDataFrom(sockaddr* address, size_t* addressLength, uint8_t* buffer, size_t bufferLen) const + { + return recvfrom(GetUnderlyingSocket(), buffer, static_cast<int>(bufferLen), 0, address, reinterpret_cast<socklen_t*>(addressLength)); + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Environment.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Environment.cpp new file mode 100644 index 0000000000..ee627340bb --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Environment.cpp @@ -0,0 +1,23 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/platform/Environment.h> + +//#include <aws/core/utils/memory/stl/AWSStringStream.h> +//#include <sys/utsname.h> + +namespace Aws +{ +namespace Environment +{ + +Aws::String GetEnv(const char* variableName) +{ + auto variableValue = std::getenv(variableName); + return Aws::String( variableValue ? variableValue : "" ); +} + +} // namespace Environment +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/FileSystem.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/FileSystem.cpp new file mode 100644 index 0000000000..c1ad818911 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/FileSystem.cpp @@ -0,0 +1,292 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/platform/FileSystem.h> + +#include <aws/core/platform/Environment.h> +#include <aws/core/platform/Platform.h> +#include <aws/core/utils/DateTime.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/UUID.h> + +#include <unistd.h> +#include <pwd.h> +#include <sys/stat.h> +#include <dirent.h> +#include <errno.h> +#include <climits> + +#include <cassert> +#ifdef __APPLE__ +#include <mach-o/dyld.h> +#endif +namespace Aws +{ +namespace FileSystem +{ + +static const char* FILE_SYSTEM_UTILS_LOG_TAG = "FileSystemUtils"; + + class PosixDirectory : public Directory + { + public: + PosixDirectory(const Aws::String& path, const Aws::String& relativePath) : Directory(path, relativePath), m_dir(nullptr) + { + m_dir = opendir(m_directoryEntry.path.c_str()); + AWS_LOGSTREAM_TRACE(FILE_SYSTEM_UTILS_LOG_TAG, "Entering directory " << m_directoryEntry.path); + + if(m_dir) + { + AWS_LOGSTREAM_TRACE(FILE_SYSTEM_UTILS_LOG_TAG, "Successfully opened directory " << m_directoryEntry.path); + m_directoryEntry.fileType = FileType::Directory; + } + else + { + AWS_LOGSTREAM_ERROR(FILE_SYSTEM_UTILS_LOG_TAG, "Could not load directory " << m_directoryEntry.path << " with error code " << errno); + } + } + + ~PosixDirectory() + { + if (m_dir) + { + closedir(m_dir); + } + } + + operator bool() const override { return m_directoryEntry.operator bool() && m_dir != nullptr; } + + DirectoryEntry Next() override + { + assert(m_dir); + DirectoryEntry entry; + + dirent* dirEntry; + bool invalidEntry(true); + + while(invalidEntry) + { + if ((dirEntry = readdir(m_dir))) + { + Aws::String entryName = dirEntry->d_name; + if(entryName != ".." && entryName != ".") + { + entry = ParseFileInfo(dirEntry, true); + invalidEntry = false; + } + else + { + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "skipping . or .."); + } + } + else + { + break; + } + } + + return entry; + } + + private: + DirectoryEntry ParseFileInfo(dirent* dirEnt, bool computePath) + { + DirectoryEntry entry; + + if(computePath) + { + Aws::StringStream ss; + ss << m_directoryEntry.path << PATH_DELIM << dirEnt->d_name; + entry.path = ss.str(); + + ss.str(""); + if(m_directoryEntry.relativePath.empty()) + { + ss << dirEnt->d_name; + } + else + { + ss << m_directoryEntry.relativePath << PATH_DELIM << dirEnt->d_name; + } + entry.relativePath = ss.str(); + } + else + { + entry.path = m_directoryEntry.path; + entry.relativePath = m_directoryEntry.relativePath; + } + + AWS_LOGSTREAM_TRACE(FILE_SYSTEM_UTILS_LOG_TAG, "Calling stat on path " << entry.path); + + struct stat dirInfo; + if(!lstat(entry.path.c_str(), &dirInfo)) + { + if(S_ISDIR(dirInfo.st_mode)) + { + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "type directory detected"); + entry.fileType = FileType::Directory; + } + else if(S_ISLNK(dirInfo.st_mode)) + { + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "type symlink detected"); + entry.fileType = FileType::Symlink; + } + else if(S_ISREG(dirInfo.st_mode)) + { + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "type file detected"); + entry.fileType = FileType::File; + } + + entry.fileSize = static_cast<int64_t>(dirInfo.st_size); + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "file size detected as " << entry.fileSize); + } + else + { + AWS_LOGSTREAM_ERROR(FILE_SYSTEM_UTILS_LOG_TAG, "Failed to stat file path " << entry.path << " with error code " << errno); + } + + return entry; + } + + DIR* m_dir; + }; + +Aws::String GetHomeDirectory() +{ + static const char* HOME_DIR_ENV_VAR = "HOME"; + + AWS_LOGSTREAM_TRACE(FILE_SYSTEM_UTILS_LOG_TAG, "Checking " << HOME_DIR_ENV_VAR << " for the home directory."); + + Aws::String homeDir = Aws::Environment::GetEnv(HOME_DIR_ENV_VAR); + + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "Environment value for variable " << HOME_DIR_ENV_VAR << " is " << homeDir); + + if(homeDir.empty()) + { + AWS_LOGSTREAM_WARN(FILE_SYSTEM_UTILS_LOG_TAG, "Home dir not stored in environment, trying to fetch manually from the OS."); + + passwd pw; + passwd *p_pw = nullptr; + char pw_buffer[4096]; + getpwuid_r(getuid(), &pw, pw_buffer, sizeof(pw_buffer), &p_pw); + if(p_pw && p_pw->pw_dir) + { + homeDir = p_pw->pw_dir; + } + + AWS_LOGSTREAM_INFO(FILE_SYSTEM_UTILS_LOG_TAG, "Pulled " << homeDir << " as home directory from the OS."); + } + + Aws::String retVal = homeDir.size() > 0 ? Aws::Utils::StringUtils::Trim(homeDir.c_str()) : ""; + if(!retVal.empty()) + { + if(retVal.at(retVal.length() - 1) != PATH_DELIM) + { + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "Home directory is missing the final " << PATH_DELIM << " appending one to normalize"); + retVal += PATH_DELIM; + } + } + + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "Final Home Directory is " << retVal); + + return retVal; +} + +bool CreateDirectoryIfNotExists(const char* path, bool createParentDirs) +{ + Aws::String directoryName = path; + AWS_LOGSTREAM_INFO(FILE_SYSTEM_UTILS_LOG_TAG, "Creating directory " << directoryName); + + for (size_t i = (createParentDirs ? 0 : directoryName.size() - 1); i < directoryName.size(); i++) + { + // Create the parent directory if we find a delimiter and the delimiter is not the first char, or if this is the target directory. + if (i != 0 && (directoryName[i] == FileSystem::PATH_DELIM || i == directoryName.size() - 1)) + { + if (directoryName[i] == FileSystem::PATH_DELIM) + { + directoryName[i] = '\0'; + } + int errorCode = mkdir(directoryName.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); + if (errorCode != 0 && errno != EEXIST) + { + AWS_LOGSTREAM_ERROR(FILE_SYSTEM_UTILS_LOG_TAG, "Creation of directory " << directoryName.c_str() << " returned code: " << errno); + return false; + } + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "Creation of directory " << directoryName.c_str() << " returned code: " << errno); + directoryName[i] = FileSystem::PATH_DELIM; + } + } + return true; +} + +bool RemoveFileIfExists(const char* path) +{ + AWS_LOGSTREAM_INFO(FILE_SYSTEM_UTILS_LOG_TAG, "Deleting file: " << path); + + int errorCode = unlink(path); + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "Deletion of file: " << path << " Returned error code: " << errno); + return errorCode == 0 || errno == ENOENT; +} + +bool RemoveDirectoryIfExists(const char* path) +{ + AWS_LOGSTREAM_INFO(FILE_SYSTEM_UTILS_LOG_TAG, "Deleting directory: " << path); + int errorCode = rmdir(path); + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "Deletion of directory: " << path << " Returned error code: " << errno); + return errorCode == 0 || errno == ENOTDIR || errno == ENOENT; +} + +bool RelocateFileOrDirectory(const char* from, const char* to) +{ + AWS_LOGSTREAM_INFO(FILE_SYSTEM_UTILS_LOG_TAG, "Moving file at " << from << " to " << to); + + int errorCode = std::rename(from, to); + + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "The moving operation of file at " << from << " to " << to << " Returned error code of " << errno); + return errorCode == 0; +} + +Aws::String CreateTempFilePath() +{ + Aws::StringStream ss; + auto dt = Aws::Utils::DateTime::Now(); + + ss << dt.ToGmtString("%Y%m%dT%H%M%S") << dt.Millis() << Aws::String(Aws::Utils::UUID::RandomUUID()); + Aws::String tempFile(ss.str()); + + AWS_LOGSTREAM_DEBUG(FILE_SYSTEM_UTILS_LOG_TAG, "CreateTempFilePath generated: " << tempFile); + + return tempFile; +} + +Aws::String GetExecutableDirectory() +{ + char dest[PATH_MAX]; + memset(dest, 0, PATH_MAX); +#ifdef __APPLE__ + uint32_t destSize = PATH_MAX; + if (_NSGetExecutablePath(dest, &destSize) == 0) +#else + size_t destSize = PATH_MAX; + if (readlink("/proc/self/exe", dest, destSize)) +#endif + { + Aws::String executablePath(dest); + auto lastSlash = executablePath.find_last_of('/'); + if(lastSlash != std::string::npos) + { + return executablePath.substr(0, lastSlash); + } + } + return "./"; +} + +Aws::UniquePtr<Directory> OpenDirectory(const Aws::String& path, const Aws::String& relativePath) +{ + return Aws::MakeUnique<PosixDirectory>(FILE_SYSTEM_UTILS_LOG_TAG, path, relativePath); +} + +} // namespace FileSystem +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/OSVersionInfo.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/OSVersionInfo.cpp new file mode 100644 index 0000000000..040173a2e5 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/OSVersionInfo.cpp @@ -0,0 +1,59 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/platform/OSVersionInfo.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <aws/core/utils/StringUtils.h> +#include <sys/utsname.h> + +namespace Aws +{ +namespace OSVersionInfo +{ + +Aws::String GetSysCommandOutput(const char* command) +{ + Aws::String outputStr; + FILE* outputStream; + const int maxBufferSize = 256; + char outputBuffer[maxBufferSize]; + + outputStream = popen(command, "r"); + + if (outputStream) + { + while (!feof(outputStream)) + { + if (fgets(outputBuffer, maxBufferSize, outputStream) != nullptr) + { + outputStr.append(outputBuffer); + } + } + + pclose(outputStream); + + return Aws::Utils::StringUtils::Trim(outputStr.c_str()); + } + + return {}; +} + + +Aws::String ComputeOSVersionString() +{ + utsname name; + int32_t success = uname(&name); + if(success >= 0) + { + Aws::StringStream ss; + ss << name.sysname << "/" << name.release << " " << name.machine; + return ss.str(); + } + + return "non-windows/unknown"; +} + +} // namespace OSVersionInfo +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Security.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Security.cpp new file mode 100644 index 0000000000..286de1a948 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Security.cpp @@ -0,0 +1,26 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/platform/Security.h> + +#include <string.h> + +namespace Aws +{ +namespace Security +{ + +void SecureMemClear(unsigned char *data, size_t length) +{ +#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || defined(__bsdi__) || defined(__DragonFly__) + memset_s(data, length, 0, length); +#else + memset(data, 0, length); + asm volatile("" : "+m" (data)); +#endif +} + +} // namespace Security +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Time.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Time.cpp new file mode 100644 index 0000000000..7a0d3d1c0a --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/platform/linux-shared/Time.cpp @@ -0,0 +1,31 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/platform/Time.h> + +#include <time.h> + +namespace Aws +{ +namespace Time +{ + +time_t TimeGM(struct tm* const t) +{ + return timegm(t); +} + +void LocalTime(tm* t, std::time_t time) +{ + localtime_r(&time, t); +} + +void GMTime(tm* t, std::time_t time) +{ + gmtime_r(&time, t); +} + +} // namespace Time +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/ARN.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/ARN.cpp new file mode 100644 index 0000000000..dac358c09d --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/ARN.cpp @@ -0,0 +1,46 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/ARN.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/logging/LogMacros.h> + +namespace Aws +{ + namespace Utils + { + ARN::ARN(const Aws::String& arnString) + { + m_valid = false; + + // An ARN can be identified as any string starting with arn: with 6 defined segments each separated by a : + const auto result = StringUtils::Split(arnString, ':', StringUtils::SplitOptions::INCLUDE_EMPTY_ENTRIES); + + if (result.size() < 6) + { + return; + } + + if (result[0] != "arn") + { + return; + } + + m_arnString = arnString; + m_partition = result[1]; + m_service = result[2]; + m_region = result[3]; + m_accountId = result[4]; + m_resource = result[5]; + + for (size_t i = 6; i < result.size(); i++) + { + m_resource += ":" + result[i]; + } + + m_valid = true; + } + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/Array.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/Array.cpp new file mode 100644 index 0000000000..43e7863421 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/Array.cpp @@ -0,0 +1,65 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/Array.h> + +#include <aws/core/platform/Security.h> + +namespace Aws +{ + namespace Utils + { + Array<CryptoBuffer> CryptoBuffer::Slice(size_t sizeOfSlice) const + { + assert(sizeOfSlice <= GetLength()); + + size_t numberOfSlices = (GetLength() + sizeOfSlice - 1) / sizeOfSlice; + size_t currentSliceIndex = 0; + Array<CryptoBuffer> slices(numberOfSlices); + + for (size_t i = 0; i < numberOfSlices - 1; ++i) + { + CryptoBuffer newArray(sizeOfSlice); + for (size_t cpyIdx = 0; cpyIdx < newArray.GetLength(); ++cpyIdx) + { + newArray[cpyIdx] = GetItem(cpyIdx + currentSliceIndex); + } + currentSliceIndex += sizeOfSlice; + slices[i] = std::move(newArray); + } + + CryptoBuffer lastArray(GetLength() % sizeOfSlice == 0 ? sizeOfSlice : GetLength() % sizeOfSlice ); + for (size_t cpyIdx = 0; cpyIdx < lastArray.GetLength(); ++cpyIdx) + { + lastArray[cpyIdx] = GetItem(cpyIdx + currentSliceIndex); + } + slices[slices.GetLength() - 1] = std::move(lastArray); + + return slices; + } + + CryptoBuffer& CryptoBuffer::operator^(const CryptoBuffer& operand) + { + size_t smallestSize = std::min<size_t>(GetLength(), operand.GetLength()); + for (size_t i = 0; i < smallestSize; ++i) + { + (*this)[i] ^= operand[i]; + } + + return *this; + } + + /** + * Zero out the array securely + */ + void CryptoBuffer::Zero() + { + if (GetUnderlyingData()) + { + Aws::Security::SecureMemClear(GetUnderlyingData(), GetLength()); + } + } + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/DNS.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/DNS.cpp new file mode 100644 index 0000000000..ce588150e2 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/DNS.cpp @@ -0,0 +1,55 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/DNS.h> +#include <aws/core/utils/StringUtils.h> + +namespace Aws +{ + namespace Utils + { + bool IsValidDnsLabel(const Aws::String& label) + { + // Valid DNS hostnames are composed of valid DNS labels separated by a period. + // Valid DNS labels are characterized by the following: + // 1- Only contains alphanumeric characters and/or dashes + // 2- Cannot start or end with a dash + // 3- Have a maximum length of 63 characters (the entirety of the domain name should be less than 255 bytes) + + if (label.empty()) + return false; + + if (label.size() > 63) + return false; + + if (!StringUtils::IsAlnum(label.front())) + return false; // '-' is not acceptable as the first character + + if (!StringUtils::IsAlnum(label.back())) + return false; // '-' is not acceptable as the last character + + for (size_t i = 1, e = label.size() - 1; i < e; ++i) + { + auto c = label[i]; + if (c != '-' && !StringUtils::IsAlnum(c)) + return false; + } + + return true; + } + + bool IsValidHost(const Aws::String& host) + { + // Valid DNS hostnames are composed of valid DNS labels separated by a period. + auto labels = StringUtils::Split(host, '.'); + if (labels.empty()) + { + return false; + } + + return !std::any_of(labels.begin(), labels.end(), [](const Aws::String& label){ return !IsValidDnsLabel(label); }); + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/DateTimeCommon.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/DateTimeCommon.cpp new file mode 100644 index 0000000000..b690c90c2d --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/DateTimeCommon.cpp @@ -0,0 +1,1502 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/DateTime.h> + +#include <aws/core/platform/Time.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <time.h> +#include <cassert> +#include <iostream> +#include <cstring> + +static const char* CLASS_TAG = "DateTime"; +static const char* RFC822_DATE_FORMAT_STR_MINUS_Z = "%a, %d %b %Y %H:%M:%S"; +static const char* RFC822_DATE_FORMAT_STR_WITH_Z = "%a, %d %b %Y %H:%M:%S %Z"; +static const char* ISO_8601_LONG_DATE_FORMAT_STR = "%Y-%m-%dT%H:%M:%SZ"; +static const char* ISO_8601_LONG_BASIC_DATE_FORMAT_STR = "%Y%m%dT%H%M%SZ"; + +using namespace Aws::Utils; + + +std::tm CreateZeroedTm() +{ + std::tm timeStruct; + timeStruct.tm_hour = 0; + timeStruct.tm_isdst = -1; + timeStruct.tm_mday = 0; + timeStruct.tm_min = 0; + timeStruct.tm_mon = 0; + timeStruct.tm_sec = 0; + timeStruct.tm_wday = 0; + timeStruct.tm_yday = 0; + timeStruct.tm_year = 0; + + return timeStruct; +} + +//Get the 0-6 week day number from a string representing WeekDay. Case insensitive and will stop on abbreviation +static int GetWeekDayNumberFromStr(const char* timeString, size_t startIndex, size_t stopIndex) +{ + if(stopIndex - startIndex < 3) + { + return -1; + } + + size_t index = startIndex; + + char c = timeString[index]; + char next = 0; + + //it's ugly but this should compile down to EXACTLY 3 comparisons and no memory allocations + switch(c) + { + case 'S': + case 's': + next = timeString[++index]; + switch(next) + { + case 'A': + case 'a': + next = timeString[++index]; + switch (next) + { + case 'T': + case 't': + return 6; + default: + return -1; + } + case 'U': + case 'u': + next = timeString[++index]; + switch (next) + { + case 'N': + case 'n': + return 0; + default: + return -1; + } + default: + return -1; + } + case 'T': + case 't': + next = timeString[++index]; + switch (next) + { + case 'H': + case 'h': + next = timeString[++index]; + switch(next) + { + case 'U': + case 'u': + return 4; + default: + return -1; + } + case 'U': + case 'u': + next = timeString[++index]; + switch(next) + { + case 'E': + case 'e': + return 2; + default: + return -1; + } + default: + return -1; + } + case 'M': + case 'm': + next = timeString[++index]; + switch(next) + { + case 'O': + case 'o': + next = timeString[++index]; + switch (next) + { + case 'N': + case 'n': + return 1; + default: + return -1; + } + default: + return -1; + } + case 'W': + case 'w': + next = timeString[++index]; + switch (next) + { + case 'E': + case 'e': + next = timeString[++index]; + switch (next) + { + case 'D': + case 'd': + return 3; + default: + return -1; + } + default: + return -1; + } + case 'F': + case 'f': + next = timeString[++index]; + switch (next) + { + case 'R': + case 'r': + next = timeString[++index]; + switch (next) + { + case 'I': + case 'i': + return 5; + default: + return -1; + } + default: + return -1; + } + default: + return -1; + } +} + +//Get the 0-11 monthy number from a string representing Month. Case insensitive and will stop on abbreviation +static int GetMonthNumberFromStr(const char* timeString, size_t startIndex, size_t stopIndex) +{ + if (stopIndex - startIndex < 3) + { + return -1; + } + + size_t index = startIndex; + + char c = timeString[index]; + char next = 0; + + //it's ugly but this should compile down to EXACTLY 3 comparisons and no memory allocations + switch (c) + { + case 'M': + case 'm': + next = timeString[++index]; + switch (next) + { + case 'A': + case 'a': + next = timeString[++index]; + switch (next) + { + case 'Y': + case 'y': + return 4; + case 'R': + case 'r': + return 2; + default: + return -1; + } + default: + return -1; + } + case 'A': + case 'a': + next = timeString[++index]; + switch (next) + { + case 'P': + case 'p': + next = timeString[++index]; + switch (next) + { + case 'R': + case 'r': + return 3; + default: + return -1; + } + case 'U': + case 'u': + next = timeString[++index]; + switch (next) + { + case 'G': + case 'g': + return 7; + default: + return -1; + } + default: + return -1; + } + case 'J': + case 'j': + next = timeString[++index]; + switch (next) + { + case 'A': + case 'a': + next = timeString[++index]; + switch (next) + { + case 'N': + case 'n': + return 0; + default: + return -1; + } + case 'U': + case 'u': + next = timeString[++index]; + switch (next) + { + case 'N': + case 'n': + return 5; + case 'L': + case 'l': + return 6; + default: + return -1; + } + default: + return -1; + } + case 'F': + case 'f': + next = timeString[++index]; + switch (next) + { + case 'E': + case 'e': + next = timeString[++index]; + switch (next) + { + case 'B': + case 'b': + return 1; + default: + return -1; + } + default: + return -1; + } + case 'S': + case 's': + next = timeString[++index]; + switch (next) + { + case 'E': + case 'e': + next = timeString[++index]; + switch (next) + { + case 'P': + case 'p': + return 8; + default: + return -1; + } + default: + return -1; + } + case 'O': + case 'o': + next = timeString[++index]; + switch (next) + { + case 'C': + case 'c': + next = timeString[++index]; + switch (next) + { + case 'T': + case 't': + return 9; + default: + return -1; + } + default: + return -1; + } + case 'N': + case 'n': + next = timeString[++index]; + switch (next) + { + case 'O': + case 'o': + next = timeString[++index]; + switch (next) + { + case 'V': + case 'v': + return 10; + default: + return -1; + } + default: + return -1; + } + case 'D': + case 'd': + next = timeString[++index]; + switch (next) + { + case 'E': + case 'e': + next = timeString[++index]; + switch (next) + { + case 'C': + case 'c': + return 11; + default: + return -1; + } + default: + return -1; + } + default: + return -1; + } +} +// Ensure local classes with generic names have internal linkage +namespace { + +class DateParser +{ +public: + DateParser(const char* toParse) : m_error(false), m_toParse(toParse), m_utcAssumed(true) + { + m_parsedTimestamp = CreateZeroedTm(); + memset(m_tz, 0, 7); + } + + virtual ~DateParser() = default; + + virtual void Parse() = 0; + bool WasParseSuccessful() const { return !m_error; } + std::tm& GetParsedTimestamp() { return m_parsedTimestamp; } + bool ShouldIAssumeThisIsUTC() const { return m_utcAssumed; } + const char* GetParsedTimezone() const { return m_tz; } + +protected: + bool m_error; + const char* m_toParse; + std::tm m_parsedTimestamp; + bool m_utcAssumed; + // The size should be at least one byte greater than the maximum possible size so that we could use the last char to indicate the end of the string. + char m_tz[7]; +}; + +static const int MAX_LEN = 100; + +//Before you send me hate mail because I'm doing this manually, I encourage you to try using std::get_time on all platforms and getting +//uniform results. Timezone information doesn't parse on Windows and it hardly even works on GCC 4.9.x. This is the only way to make sure +//the standard is parsed correctly. strptime isn't available one Windows. This code gets hit pretty hard during http serialization/deserialization +//as a result I'm going for no dynamic allocations and linear complexity +class RFC822DateParser : public DateParser +{ +public: + RFC822DateParser(const char* toParse) : DateParser(toParse), m_state(0) + { + } + + /** + * Really simple state machine for the format %a, %d %b %Y %H:%M:%S %Z + */ + void Parse() override + { + size_t len = strlen(m_toParse); + + //DOS check + if (len > MAX_LEN) + { + AWS_LOGSTREAM_WARN(CLASS_TAG, "Incoming String to parse too long with length: " << len) + m_error = true; + return; + } + + size_t index = 0; + size_t stateStartIndex = 0; + int finalState = 8; + + while(m_state <= finalState && !m_error && index < len) + { + char c = m_toParse[index]; + + switch (m_state) + { + case 0: + if(c == ',') + { + int weekNumber = GetWeekDayNumberFromStr(m_toParse, stateStartIndex, index + 1); + + if (weekNumber > -1) + { + m_state = 1; + stateStartIndex = index + 1; + m_parsedTimestamp.tm_wday = weekNumber; + } + else + { + m_error = true; + } + } + else if(!isalpha(c)) + { + m_error = true; + } + break; + case 1: + if (isspace(c)) + { + m_state = 2; + stateStartIndex = index + 1; + } + else + { + m_error = true; + } + break; + case 2: + if (isdigit(c)) + { + m_parsedTimestamp.tm_mday = m_parsedTimestamp.tm_mday * 10 + (c - '0'); + } + else if(isspace(c)) + { + m_state = 3; + stateStartIndex = index + 1; + } + else + { + m_error = true; + } + break; + case 3: + if (isspace(c)) + { + int monthNumber = GetMonthNumberFromStr(m_toParse, stateStartIndex, index + 1); + + if (monthNumber > -1) + { + m_state = 4; + stateStartIndex = index + 1; + m_parsedTimestamp.tm_mon = monthNumber; + } + else + { + m_error = true; + } + } + else if (!isalpha(c)) + { + m_error = true; + } + break; + case 4: + if (isspace(c) && index - stateStartIndex == 4) + { + m_state = 5; + stateStartIndex = index + 1; + m_parsedTimestamp.tm_year -= 1900; + } + else if (isspace(c) && index - stateStartIndex == 2) + { + m_state = 5; + stateStartIndex = index + 1; + m_parsedTimestamp.tm_year += 2000 - 1900; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_year = m_parsedTimestamp.tm_year * 10 + (c - '0'); + } + else + { + m_error = true; + } + break; + case 5: + if(c == ':' && index - stateStartIndex == 2) + { + m_state = 6; + stateStartIndex = index + 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_hour = m_parsedTimestamp.tm_hour * 10 + (c - '0'); + } + else + { + m_error = true; + } + break; + case 6: + if (c == ':' && index - stateStartIndex == 2) + { + m_state = 7; + stateStartIndex = index + 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_min = m_parsedTimestamp.tm_min * 10 + (c - '0'); + } + else + { + m_error = true; + } + break; + case 7: + if (isspace(c) && index - stateStartIndex == 2) + { + m_state = 8; + stateStartIndex = index + 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_sec = m_parsedTimestamp.tm_sec * 10 + (c - '0'); + } + else + { + m_error = true; + } + break; + case 8: + if ((isalnum(c) || c == '+' || c == '-') && (index - stateStartIndex < 5)) + { + m_tz[index - stateStartIndex] = c; + } + else + { + m_error = true; + } + break; + default: + m_error = true; + break; + } + + index++; + } + + if (m_tz[0] != 0) + { + m_utcAssumed = IsUTCTimeZoneDesignator(m_tz); + } + + m_error = (m_error || m_state != finalState); + } + + int GetState() const { return m_state; } + +private: + //Detects whether or not the passed in timezone string is a UTC zone. + static bool IsUTCTimeZoneDesignator(const char* str) + { + size_t len = strlen(str); + if (len < 3) + { + return false; + } + + int index = 0; + char c = str[index]; + switch (c) + { + case 'U': + case 'u': + c = str[++index]; + switch(c) + { + case 'T': + case 't': + c = str[++index]; + switch(c) + { + case 'C': + case 'c': + return true; + default: + return false; + } + + case 'C': + case 'c': + c = str[++index]; + switch (c) + { + case 'T': + case 't': + return true; + default: + return false; + } + default: + return false; + } + case 'G': + case 'g': + c = str[++index]; + switch (c) + { + case 'M': + case 'm': + c = str[++index]; + switch (c) + { + case 'T': + case 't': + return true; + default: + return false; + } + default: + return false; + } + case '+': + case '-': + c = str[++index]; + switch (c) + { + case '0': + c = str[++index]; + switch (c) + { + case '0': + c = str[++index]; + switch (c) + { + case '0': + return true; + default: + return false; + } + default: + return false; + } + default: + return false; + } + case 'Z': + return true; + default: + return false; + } + + } + + int m_state; +}; + +//Before you send me hate mail because I'm doing this manually, I encourage you to try using std::get_time on all platforms and getting +//uniform results. Timezone information doesn't parse on Windows and it hardly even works on GCC 4.9.x. This is the only way to make sure +//the standard is parsed correctly. strptime isn't available one Windows. This code gets hit pretty hard during http serialization/deserialization +//as a result I'm going for no dynamic allocations and linear complexity +class ISO_8601DateParser : public DateParser +{ +public: + ISO_8601DateParser(const char* stringToParse) : DateParser(stringToParse), m_state(0) + { + } + + //parses "%Y-%m-%dT%H:%M:%SZ or "%Y-%m-%dT%H:%M:%S.000Z" + void Parse() override + { + size_t len = strlen(m_toParse); + + //DOS check + if (len > MAX_LEN) + { + AWS_LOGSTREAM_WARN(CLASS_TAG, "Incoming String to parse too long with length: " << len) + m_error = true; + return; + } + + size_t index = 0; + size_t stateStartIndex = 0; + const int finalState = 7; + + while (m_state <= finalState && !m_error && index < len) + { + char c = m_toParse[index]; + switch (m_state) + { + case 0: + if (c == '-' && index - stateStartIndex == 4) + { + m_state = 1; + stateStartIndex = index + 1; + m_parsedTimestamp.tm_year -= 1900; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_year = m_parsedTimestamp.tm_year * 10 + (c - '0'); + } + else + { + m_error = true; + } + break; + case 1: + if (c == '-' && index - stateStartIndex == 2) + { + m_state = 2; + stateStartIndex = index + 1; + m_parsedTimestamp.tm_mon -= 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_mon = m_parsedTimestamp.tm_mon * 10 + (c - '0'); + } + else + { + m_error = true; + } + + break; + case 2: + if (c == 'T' && index - stateStartIndex == 2) + { + m_state = 3; + stateStartIndex = index + 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_mday = m_parsedTimestamp.tm_mday * 10 + (c - '0'); + } + else + { + m_error = true; + } + + break; + case 3: + if (c == ':' && index - stateStartIndex == 2) + { + m_state = 4; + stateStartIndex = index + 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_hour = m_parsedTimestamp.tm_hour * 10 + (c - '0'); + } + else + { + m_error = true; + } + + break; + case 4: + if (c == ':' && index - stateStartIndex == 2) + { + m_state = 5; + stateStartIndex = index + 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_min = m_parsedTimestamp.tm_min * 10 + (c - '0'); + } + else + { + m_error = true; + } + + break; + case 5: + if ((c == 'Z' || c == '+' || c == '-' ) && (index - stateStartIndex == 2)) + { + m_tz[0] = c; + m_state = 7; + stateStartIndex = index + 1; + } + else if (c == '.' && index - stateStartIndex == 2) + { + m_state = 6; + stateStartIndex = index + 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_sec = m_parsedTimestamp.tm_sec * 10 + (c - '0'); + } + else + { + m_error = true; + } + + break; + case 6: + if ((c == 'Z' || c == '+' || c == '-' ) && (index - stateStartIndex == 3)) + { + m_tz[0] = c; + m_state = 7; + stateStartIndex = index + 1; + } + else if(!isdigit(c)) + { + m_error = true; + } + break; + case 7: + if ((isdigit(c) || c == ':') && (index - stateStartIndex < 5)) + { + m_tz[1 + index - stateStartIndex] = c; + } + else + { + m_error = true; + } + break; + default: + m_error = true; + break; + } + index++; + } + + if (m_tz[0] != 0) + { + m_utcAssumed = IsUTCTimeZoneDesignator(m_tz); + } + + m_error = (m_error || m_state != finalState); + } + +private: + //Detects whether or not the passed in timezone string is a UTC zone. + static bool IsUTCTimeZoneDesignator(const char* str) + { + size_t len = strlen(str); + + if (len > 0) + { + if (len == 1 && str[0] == 'Z') + { + return true; + } + + if (len == 6 && str[0] == '+' + && str[1] == '0' + && str[2] == '0' + && str[3] == ':' + && str[4] == '0' + && str[5] == '0') + { + return true; + } + + return false; + } + + return false; + } + + int m_state; +}; + +class ISO_8601BasicDateParser : public DateParser +{ +public: + ISO_8601BasicDateParser(const char* stringToParse) : DateParser(stringToParse), m_state(0) + { + } + + //parses "%Y%m%dT%H%M%SZ or "%Y%m%dT%H%M%S000Z" + void Parse() override + { + size_t len = strlen(m_toParse); + + //DOS check + if (len > MAX_LEN) + { + AWS_LOGSTREAM_WARN(CLASS_TAG, "Incoming String to parse too long with length: " << len) + m_error = true; + return; + } + + size_t index = 0; + size_t stateStartIndex = 0; + const int finalState = 7; + + while (m_state <= finalState && !m_error && index < len) + { + char c = m_toParse[index]; + switch (m_state) + { + // On year: %Y + case 0: + if (isdigit(c)) + { + m_parsedTimestamp.tm_year = m_parsedTimestamp.tm_year * 10 + (c - '0'); + if (index - stateStartIndex == 3) + { + m_state = 1; + stateStartIndex = index + 1; + m_parsedTimestamp.tm_year -= 1900; + } + } + else + { + m_error = true; + } + break; + // On month: %m + case 1: + if (isdigit(c)) + { + m_parsedTimestamp.tm_mon = m_parsedTimestamp.tm_mon * 10 + (c - '0'); + if (index - stateStartIndex == 1) + { + m_state = 2; + stateStartIndex = index + 1; + m_parsedTimestamp.tm_mon -= 1; + } + } + else + { + m_error = true; + } + break; + // On month day: %d + case 2: + if (c == 'T' && index - stateStartIndex == 2) + { + m_state = 3; + stateStartIndex = index + 1; + } + else if (isdigit(c)) + { + m_parsedTimestamp.tm_mday = m_parsedTimestamp.tm_mday * 10 + (c - '0'); + } + else + { + m_error = true; + } + break; + // On hour: %H + case 3: + if (isdigit(c)) + { + m_parsedTimestamp.tm_hour = m_parsedTimestamp.tm_hour * 10 + (c - '0'); + if (index - stateStartIndex == 1) + { + m_state = 4; + stateStartIndex = index + 1; + } + } + else + { + m_error = true; + } + break; + // On minute: %M + case 4: + if (isdigit(c)) + { + m_parsedTimestamp.tm_min = m_parsedTimestamp.tm_min * 10 + (c - '0'); + if (index - stateStartIndex == 1) + { + m_state = 5; + stateStartIndex = index + 1; + } + } + else + { + m_error = true; + } + break; + // On second: %S + case 5: + if (isdigit(c)) + { + m_parsedTimestamp.tm_sec = m_parsedTimestamp.tm_sec * 10 + (c - '0'); + if (index - stateStartIndex == 1) + { + m_state = 6; + stateStartIndex = index + 1; + } + } + else + { + m_error = true; + } + break; + // On TZ: Z or 000Z + case 6: + if ((c == 'Z' || c == '+' || c == '-' ) && (index - stateStartIndex == 0 || index - stateStartIndex == 3)) + { + m_tz[0] = c; + m_state = 7; + stateStartIndex = index + 1; + } + else if (!isdigit(c) || index - stateStartIndex > 3) + { + m_error = true; + } + break; + case 7: + if ((isdigit(c) || c == ':') && (index - stateStartIndex < 5)) + { + m_tz[1 + index - stateStartIndex] = c; + } + else + { + m_error = true; + } + break; + default: + m_error = true; + break; + } + index++; + } + + if (m_tz[0] != 0) + { + m_utcAssumed = IsUTCTimeZoneDesignator(m_tz); + } + + m_error = (m_error || m_state != finalState); + } + +private: + //Detects whether or not the passed in timezone string is a UTC zone. + static bool IsUTCTimeZoneDesignator(const char* str) + { + size_t len = strlen(str); + + if (len > 0) + { + if (len == 1 && str[0] == 'Z') + { + return true; + } + + if (len == 5 && str[0] == '+' + && str[1] == '0' + && str[2] == '0' + && str[3] == '0' + && str[4] == '0') + { + return true; + } + + return false; + } + + return false; + } + + int m_state; +}; + +} // namespace + +DateTime::DateTime(const std::chrono::system_clock::time_point& timepointToAssign) : m_time(timepointToAssign), m_valid(true) +{ +} + +DateTime::DateTime(int64_t millisSinceEpoch) : m_valid(true) +{ + std::chrono::duration<int64_t, std::chrono::milliseconds::period> timestamp(millisSinceEpoch); + m_time = std::chrono::system_clock::time_point(timestamp); +} + +DateTime::DateTime(double epoch_millis) : m_valid(true) +{ + std::chrono::duration<double, std::chrono::seconds::period> timestamp(epoch_millis); + m_time = std::chrono::system_clock::time_point(std::chrono::duration_cast<std::chrono::milliseconds>(timestamp)); +} + +DateTime::DateTime(const Aws::String& timestamp, DateFormat format) : m_valid(true) +{ + ConvertTimestampStringToTimePoint(timestamp.c_str(), format); +} + +DateTime::DateTime(const char* timestamp, DateFormat format) : m_valid(true) +{ + ConvertTimestampStringToTimePoint(timestamp, format); +} + +DateTime::DateTime() : m_valid(true) +{ + //init time_point to default by doing nothing. +} + +DateTime& DateTime::operator=(const Aws::String& timestamp) +{ + *this = DateTime(timestamp, DateFormat::AutoDetect); + return *this; +} + +DateTime& DateTime::operator=(double secondsMillis) +{ + *this = DateTime(secondsMillis); + return *this; +} + +DateTime& DateTime::operator=(int64_t millisSinceEpoch) +{ + *this = DateTime(millisSinceEpoch); + return *this; +} + +DateTime& DateTime::operator=(const std::chrono::system_clock::time_point& timepointToAssign) +{ + *this = DateTime(timepointToAssign); + return *this; +} + +bool DateTime::operator == (const DateTime& other) const +{ + return m_time == other.m_time; +} + +bool DateTime::operator < (const DateTime& other) const +{ + return m_time < other.m_time; +} + +bool DateTime::operator > (const DateTime& other) const +{ + return m_time > other.m_time; +} + +bool DateTime::operator != (const DateTime& other) const +{ + return m_time != other.m_time; +} + +bool DateTime::operator <= (const DateTime& other) const +{ + return m_time <= other.m_time; +} + +bool DateTime::operator >= (const DateTime& other) const +{ + return m_time >= other.m_time; +} + +DateTime DateTime::operator +(const std::chrono::milliseconds& a) const +{ + auto timepointCpy = m_time; + timepointCpy += a; + return DateTime(timepointCpy); +} + +DateTime DateTime::operator -(const std::chrono::milliseconds& a) const +{ + auto timepointCpy = m_time; + timepointCpy -= a; + return DateTime(timepointCpy); +} + +Aws::String DateTime::ToLocalTimeString(DateFormat format) const +{ + switch (format) + { + case DateFormat::ISO_8601: + return ToLocalTimeString(ISO_8601_LONG_DATE_FORMAT_STR); + case DateFormat::ISO_8601_BASIC: + return ToLocalTimeString(ISO_8601_LONG_BASIC_DATE_FORMAT_STR); + case DateFormat::RFC822: + return ToLocalTimeString(RFC822_DATE_FORMAT_STR_WITH_Z); + default: + assert(0); + return ""; + } +} + +Aws::String DateTime::ToLocalTimeString(const char* formatStr) const +{ + struct tm localTimeStamp = ConvertTimestampToLocalTimeStruct(); + + char formattedString[100]; + std::strftime(formattedString, sizeof(formattedString), formatStr, &localTimeStamp); + return formattedString; +} + +Aws::String DateTime::ToGmtString(DateFormat format) const +{ + switch (format) + { + case DateFormat::ISO_8601: + return ToGmtString(ISO_8601_LONG_DATE_FORMAT_STR); + case DateFormat::ISO_8601_BASIC: + return ToGmtString(ISO_8601_LONG_BASIC_DATE_FORMAT_STR); + case DateFormat::RFC822: + { + //Windows erroneously drops the local timezone in for %Z + Aws::String rfc822GmtString = ToGmtString(RFC822_DATE_FORMAT_STR_MINUS_Z); + rfc822GmtString += " GMT"; + return rfc822GmtString; + } + default: + assert(0); + return ""; + } +} + +Aws::String DateTime::ToGmtString(const char* formatStr) const +{ + struct tm gmtTimeStamp = ConvertTimestampToGmtStruct(); + + char formattedString[100]; + std::strftime(formattedString, sizeof(formattedString), formatStr, &gmtTimeStamp); + return formattedString; +} + +double DateTime::SecondsWithMSPrecision() const +{ + std::chrono::duration<double, std::chrono::seconds::period> timestamp(m_time.time_since_epoch()); + return timestamp.count(); +} + +int64_t DateTime::Millis() const +{ + auto timestamp = std::chrono::duration_cast<std::chrono::milliseconds>(m_time.time_since_epoch()); + return timestamp.count(); +} + +std::chrono::system_clock::time_point DateTime::UnderlyingTimestamp() const +{ + return m_time; +} + +int DateTime::GetYear(bool localTime) const +{ + return GetTimeStruct(localTime).tm_year + 1900; +} + +Month DateTime::GetMonth(bool localTime) const +{ + return static_cast<Aws::Utils::Month>(GetTimeStruct(localTime).tm_mon); +} + +int DateTime::GetDay(bool localTime) const +{ + return GetTimeStruct(localTime).tm_mday; +} + +DayOfWeek DateTime::GetDayOfWeek(bool localTime) const +{ + return static_cast<Aws::Utils::DayOfWeek>(GetTimeStruct(localTime).tm_wday); +} + +int DateTime::GetHour(bool localTime) const +{ + return GetTimeStruct(localTime).tm_hour; +} + +int DateTime::GetMinute(bool localTime) const +{ + return GetTimeStruct(localTime).tm_min; +} + +int DateTime::GetSecond(bool localTime) const +{ + return GetTimeStruct(localTime).tm_sec; +} + +bool DateTime::IsDST(bool localTime) const +{ + return GetTimeStruct(localTime).tm_isdst == 0 ? false : true; +} + +DateTime DateTime::Now() +{ + DateTime dateTime; + dateTime.m_time = std::chrono::system_clock::now(); + return dateTime; +} + +int64_t DateTime::CurrentTimeMillis() +{ + return Now().Millis(); +} + +Aws::String DateTime::CalculateLocalTimestampAsString(const char* formatStr) +{ + DateTime now = Now(); + return now.ToLocalTimeString(formatStr); +} + +Aws::String DateTime::CalculateGmtTimestampAsString(const char* formatStr) +{ + DateTime now = Now(); + return now.ToGmtString(formatStr); +} + +Aws::String DateTime::CalculateGmtTimeWithMsPrecision() +{ + auto now = DateTime::Now(); + struct tm gmtTimeStamp = now.ConvertTimestampToGmtStruct(); + + char formattedString[100]; + auto len = std::strftime(formattedString, sizeof(formattedString), "%Y-%m-%d %H:%M:%S", &gmtTimeStamp); + if (len) + { + auto ms = now.Millis(); + ms = ms - ms / 1000 * 1000; // calculate the milliseconds as fraction. + formattedString[len++] = '.'; + int divisor = 100; + while(divisor) + { + auto digit = ms / divisor; + formattedString[len++] = char('0' + digit); + ms = ms - divisor * digit; + divisor /= 10; + } + formattedString[len] = '\0'; + } + return formattedString; +} + +int DateTime::CalculateCurrentHour() +{ + return Now().GetHour(true); +} + +double DateTime::ComputeCurrentTimestampInAmazonFormat() +{ + return Now().SecondsWithMSPrecision(); +} + +std::chrono::milliseconds DateTime::Diff(const DateTime& a, const DateTime& b) +{ + auto diff = a.m_time - b.m_time; + return std::chrono::duration_cast<std::chrono::milliseconds>(diff); +} + +std::chrono::milliseconds DateTime::operator-(const DateTime& other) const +{ + auto diff = this->m_time - other.m_time; + return std::chrono::duration_cast<std::chrono::milliseconds>(diff); +} + +void DateTime::ConvertTimestampStringToTimePoint(const char* timestamp, DateFormat format) +{ + std::tm timeStruct; + bool isUtc = true; + + switch (format) + { + case DateFormat::RFC822: + { + RFC822DateParser parser(timestamp); + parser.Parse(); + m_valid = parser.WasParseSuccessful(); + isUtc = parser.ShouldIAssumeThisIsUTC(); + timeStruct = parser.GetParsedTimestamp(); + break; + } + case DateFormat::ISO_8601: + { + ISO_8601DateParser parser(timestamp); + parser.Parse(); + m_valid = parser.WasParseSuccessful(); + isUtc = parser.ShouldIAssumeThisIsUTC(); + timeStruct = parser.GetParsedTimestamp(); + break; + } + case DateFormat::ISO_8601_BASIC: + { + ISO_8601BasicDateParser parser(timestamp); + parser.Parse(); + m_valid = parser.WasParseSuccessful(); + isUtc = parser.ShouldIAssumeThisIsUTC(); + timeStruct = parser.GetParsedTimestamp(); + break; + } + case DateFormat::AutoDetect: + { + RFC822DateParser rfcParser(timestamp); + rfcParser.Parse(); + if(rfcParser.WasParseSuccessful()) + { + m_valid = true; + isUtc = rfcParser.ShouldIAssumeThisIsUTC(); + timeStruct = rfcParser.GetParsedTimestamp(); + break; + } + ISO_8601DateParser isoParser(timestamp); + isoParser.Parse(); + if (isoParser.WasParseSuccessful()) + { + m_valid = true; + isUtc = isoParser.ShouldIAssumeThisIsUTC(); + timeStruct = isoParser.GetParsedTimestamp(); + break; + } + ISO_8601BasicDateParser isoBasicParser(timestamp); + isoBasicParser.Parse(); + if (isoBasicParser.WasParseSuccessful()) + { + m_valid = true; + isUtc = isoBasicParser.ShouldIAssumeThisIsUTC(); + timeStruct = isoBasicParser.GetParsedTimestamp(); + break; + } + m_valid = false; + break; + } + default: + assert(0); + } + + if (m_valid) + { + std::time_t tt; + if(isUtc) + { + tt = Aws::Time::TimeGM(&timeStruct); + } + else + { + assert(0); + AWS_LOGSTREAM_WARN(CLASS_TAG, "Non-UTC timestamp detected. This is always a bug. Make the world a better place and fix whatever sent you this timestamp: " << timestamp) + tt = std::mktime(&timeStruct); + } + m_time = std::chrono::system_clock::from_time_t(tt); + } +} + +tm DateTime::GetTimeStruct(bool localTime) const +{ + return localTime ? ConvertTimestampToLocalTimeStruct() : ConvertTimestampToGmtStruct(); +} + +tm DateTime::ConvertTimestampToLocalTimeStruct() const +{ + std::time_t time = std::chrono::system_clock::to_time_t(m_time); + struct tm localTimeStamp; + + Aws::Time::LocalTime(&localTimeStamp, time); + + return localTimeStamp; +} + +tm DateTime::ConvertTimestampToGmtStruct() const +{ + std::time_t time = std::chrono::system_clock::to_time_t(m_time); + struct tm gmtTimeStamp; + Aws::Time::GMTime(&gmtTimeStamp, time); + + return gmtTimeStamp; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/Directory.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/Directory.cpp new file mode 100644 index 0000000000..49ca56b280 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/Directory.cpp @@ -0,0 +1,323 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/platform/FileSystem.h> +#include <aws/core/utils/memory/stl/AWSQueue.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <aws/core/utils/memory/stl/AWSStreamFwd.h> +#include <aws/core/utils/StringUtils.h> +#include <fstream> +#include <cassert> + +namespace Aws +{ + namespace FileSystem + { + Aws::String Join(const Aws::String& leftSegment, const Aws::String& rightSegment) + { + return Join(PATH_DELIM, leftSegment, rightSegment); + } + + Aws::String Join(char delimiter, const Aws::String& leftSegment, const Aws::String& rightSegment) + { + Aws::StringStream ss; + + if (!leftSegment.empty()) + { + if (leftSegment.back() == delimiter) + { + ss << leftSegment.substr(0, leftSegment.length() - 1); + } + else + { + ss << leftSegment; + } + } + + ss << delimiter; + + if (!rightSegment.empty()) + { + if (rightSegment.front() == delimiter) + { + ss << rightSegment.substr(1); + } + else + { + ss << rightSegment; + } + } + + return ss.str(); + } + + bool DeepCopyDirectory(const char* from, const char* to) + { + if (!from || !to) return false; + + DirectoryTree fromDir(from); + + if (!fromDir) return false; + + CreateDirectoryIfNotExists(to); + DirectoryTree toDir(to); + + if (!toDir) return false; + + bool success(true); + + auto visitor = [to,&success](const DirectoryTree*, const DirectoryEntry& entry) + { + auto newPath = Aws::FileSystem::Join(to, entry.relativePath); + + if (entry.fileType == Aws::FileSystem::FileType::File) + { + Aws::OFStream copyOutStream(newPath.c_str()); + Aws::IFStream originalStream(entry.path.c_str()); + + if(!copyOutStream.good() || !originalStream.good()) + { + success = false; + return false; + } + + std::copy(std::istreambuf_iterator<char>(originalStream), + std::istreambuf_iterator<char>(), std::ostreambuf_iterator<char>(copyOutStream)); + } + else if (entry.fileType == Aws::FileSystem::FileType::Directory) + { + success = CreateDirectoryIfNotExists(newPath.c_str()); + return success; + } + + return success; + }; + + fromDir.TraverseDepthFirst(visitor); + return success; + } + + bool DeepDeleteDirectory(const char* toDelete) + { + bool success(true); + + //scope this to a new stack frame, because we won't be able to delete the root directory + //unless the directory handle has closed. + { + DirectoryTree delDir(toDelete); + + if (!delDir) return false; + + auto visitor = [&success](const DirectoryTree*, const DirectoryEntry& entry) + { + if (entry.fileType == FileType::File) + { + success = RemoveFileIfExists(entry.path.c_str()); + } + else + { + success = RemoveDirectoryIfExists(entry.path.c_str()); + } + + return success; + }; + + delDir.TraverseDepthFirst(visitor, true); + } + + if (success) + { + success = RemoveDirectoryIfExists(toDelete); + } + + return success; + } + + Directory::Directory(const Aws::String& path, const Aws::String& relativePath) + { + auto trimmedPath = Utils::StringUtils::Trim(path.c_str()); + auto trimmedRelativePath = Utils::StringUtils::Trim(relativePath.c_str()); + + if (!trimmedPath.empty() && trimmedPath[trimmedPath.length() - 1] == PATH_DELIM) + { + m_directoryEntry.path = trimmedPath.substr(0, trimmedPath.length() - 1); + } + else + { + m_directoryEntry.path = trimmedPath; + } + + if (!trimmedRelativePath.empty() && trimmedRelativePath[trimmedRelativePath.length() - 1] == PATH_DELIM) + { + m_directoryEntry.relativePath = trimmedRelativePath.substr(0, trimmedRelativePath.length() - 1); + } + else + { + m_directoryEntry.relativePath = trimmedRelativePath; + } + } + + Aws::UniquePtr<Directory> Directory::Descend(const DirectoryEntry& directoryEntry) + { + assert(directoryEntry.fileType != FileType::File); + return OpenDirectory(directoryEntry.path, directoryEntry.relativePath); + } + + Aws::Vector<Aws::String> Directory::GetAllFilePathsInDirectory(const Aws::String& path) + { + Aws::FileSystem::DirectoryTree tree(path); + Aws::Vector<Aws::String> filesVector; + auto visitor = [&](const Aws::FileSystem::DirectoryTree*, const Aws::FileSystem::DirectoryEntry& entry) + { + if (entry.fileType == Aws::FileSystem::FileType::File) + { + filesVector.push_back(entry.path); + } + return true; + }; + tree.TraverseBreadthFirst(visitor); + return filesVector; + } + + DirectoryTree::DirectoryTree(const Aws::String& path) + { + m_dir = OpenDirectory(path); + } + + DirectoryTree::operator bool() const + { + return m_dir->operator bool(); + } + + bool DirectoryTree::operator==(DirectoryTree& other) + { + return Diff(other).size() == 0; + } + + bool DirectoryTree::operator==(const Aws::String& path) + { + return *this == DirectoryTree(path); + } + + Aws::Map<Aws::String, DirectoryEntry> DirectoryTree::Diff(DirectoryTree& other) + { + Aws::Map<Aws::String, DirectoryEntry> thisEntries; + auto thisTraversal = [&thisEntries](const DirectoryTree*, const DirectoryEntry& entry) + { + thisEntries[entry.relativePath] = entry; + return true; + }; + + Aws::Map<Aws::String, DirectoryEntry> otherEntries; + auto otherTraversal = [&thisEntries, &otherEntries](const DirectoryTree*, const DirectoryEntry& entry) + { + auto thisEntry = thisEntries.find(entry.relativePath); + if (thisEntry != thisEntries.end()) + { + thisEntries.erase(entry.relativePath); + } + else + { + otherEntries[entry.relativePath] = entry; + } + + return true; + }; + + TraverseDepthFirst(thisTraversal); + other.TraverseDepthFirst(otherTraversal); + + thisEntries.insert(otherEntries.begin(), otherEntries.end()); + return thisEntries; + } + + void DirectoryTree::TraverseDepthFirst(const DirectoryEntryVisitor& visitor, bool postOrderTraversal) + { + TraverseDepthFirst(*m_dir, visitor, postOrderTraversal); + m_dir = OpenDirectory(m_dir->GetPath()); + } + + void DirectoryTree::TraverseBreadthFirst(const DirectoryEntryVisitor& visitor) + { + TraverseBreadthFirst(*m_dir, visitor); + m_dir = OpenDirectory(m_dir->GetPath()); + } + + void DirectoryTree::TraverseBreadthFirst(Directory& dir, const DirectoryEntryVisitor& visitor) + { + if (!dir) + { + return; + } + + Aws::Queue<DirectoryEntry> queue; + while (DirectoryEntry&& entry = dir.Next()) + { + queue.push(std::move(entry)); + } + + while (queue.size() > 0) + { + auto entry = queue.front(); + queue.pop(); + if(visitor(this, entry)) + { + if(entry.fileType == FileType::Directory) + { + auto currentDir = dir.Descend(entry); + + while (DirectoryEntry&& dirEntry = currentDir->Next()) + { + queue.push(std::move(dirEntry)); + } + } + } + else + { + return; + } + } + } + + bool DirectoryTree::TraverseDepthFirst(Directory& dir, const DirectoryEntryVisitor& visitor, bool postOrder) + { + if (!dir) + { + return true; + } + + bool exitTraversal(false); + DirectoryEntry entry; + + while ((entry = dir.Next()) && !exitTraversal) + { + if(!postOrder) + { + if(!visitor(this, entry)) + { + return false; + } + } + + if (entry.fileType == FileType::Directory) + { + auto subDir = dir.Descend(entry); + exitTraversal = !TraverseDepthFirst(*subDir, visitor, postOrder); + } + + if (postOrder) + { + if (!visitor(this, entry)) + { + return false; + } + } + } + + return !exitTraversal; + } + + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/EnumParseOverflowContainer.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/EnumParseOverflowContainer.cpp new file mode 100644 index 0000000000..eaeba1d910 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/EnumParseOverflowContainer.cpp @@ -0,0 +1,33 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/EnumParseOverflowContainer.h> +#include <aws/core/utils/logging/LogMacros.h> + +using namespace Aws::Utils; +using namespace Aws::Utils::Threading; + +static const char LOG_TAG[] = "EnumParseOverflowContainer"; + +const Aws::String& EnumParseOverflowContainer::RetrieveOverflow(int hashCode) const +{ + ReaderLockGuard guard(m_overflowLock); + auto foundIter = m_overflowMap.find(hashCode); + if (foundIter != m_overflowMap.end()) + { + AWS_LOGSTREAM_DEBUG(LOG_TAG, "Found value " << foundIter->second << " for hash " << hashCode << " from enum overflow container."); + return foundIter->second; + } + + AWS_LOGSTREAM_ERROR(LOG_TAG, "Could not find a previously stored overflow value for hash " << hashCode << ". This will likely break some requests."); + return m_emptyString; +} + +void EnumParseOverflowContainer::StoreOverflow(int hashCode, const Aws::String& value) +{ + WriterLockGuard guard(m_overflowLock); + AWS_LOGSTREAM_WARN(LOG_TAG, "Encountered enum member " << value << " which is not modeled in your clients. You should update your clients when you get a chance."); + m_overflowMap[hashCode] = value; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/FileSystemUtils.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/FileSystemUtils.cpp new file mode 100644 index 0000000000..c47f750960 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/FileSystemUtils.cpp @@ -0,0 +1,51 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/FileSystemUtils.h> + +using namespace Aws::Utils; + +Aws::String PathUtils::GetFileNameFromPathWithoutExt(const Aws::String& path) +{ + Aws::String fileName = Aws::Utils::PathUtils::GetFileNameFromPathWithExt(path); + size_t endPos = fileName.find_last_of('.'); + if (endPos == std::string::npos) + { + return fileName; + } + if (endPos == 0) // fileName is "." + { + return {}; + } + + return fileName.substr(0, endPos); +} + +Aws::String PathUtils::GetFileNameFromPathWithExt(const Aws::String& path) +{ + if (path.size() == 0) + { + return path; + } + + size_t startPos = path.find_last_of(Aws::FileSystem::PATH_DELIM); + if (startPos == path.size() - 1) + { + return {}; + } + + if (startPos == std::string::npos) + { + startPos = 0; + } + else + { + startPos += 1; + } + + size_t endPos = path.size() - 1; + + return path.substr(startPos, endPos - startPos + 1); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/GetTheLights.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/GetTheLights.cpp new file mode 100644 index 0000000000..6e78b546ab --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/GetTheLights.cpp @@ -0,0 +1,36 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/GetTheLights.h> +#include <cassert> + +namespace Aws +{ + namespace Utils + { + GetTheLights::GetTheLights() : m_value(0) + { + } + + void GetTheLights::EnterRoom(std::function<void()> &&callable) + { + int cpy = ++m_value; + assert(cpy > 0); + if(cpy == 1) + { + callable(); + } + } + + void GetTheLights::LeaveRoom(std::function<void()> &&callable) + { + int cpy = --m_value; + assert(cpy >= 0); + if(cpy == 0) + { + callable(); + } + } + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/HashingUtils.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/HashingUtils.cpp new file mode 100644 index 0000000000..147bddf33e --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/HashingUtils.cpp @@ -0,0 +1,236 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/base64/Base64.h> +#include <aws/core/utils/crypto/Sha256.h> +#include <aws/core/utils/crypto/Sha256HMAC.h> +#include <aws/core/utils/crypto/MD5.h> +#include <aws/core/utils/Outcome.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <aws/core/utils/memory/stl/AWSList.h> + +#include <iomanip> + +using namespace Aws::Utils; +using namespace Aws::Utils::Base64; +using namespace Aws::Utils::Crypto; + +// internal buffers are fixed-size arrays, so this is harmless memory-management wise +static Aws::Utils::Base64::Base64 s_base64; + +// Aws Glacier Tree Hash calculates hash value for each 1MB data +const static size_t TREE_HASH_ONE_MB = 1024 * 1024; + +Aws::String HashingUtils::Base64Encode(const ByteBuffer& message) +{ + return s_base64.Encode(message); +} + +ByteBuffer HashingUtils::Base64Decode(const Aws::String& encodedMessage) +{ + return s_base64.Decode(encodedMessage); +} + +ByteBuffer HashingUtils::CalculateSHA256HMAC(const ByteBuffer& toSign, const ByteBuffer& secret) +{ + Sha256HMAC hash; + return hash.Calculate(toSign, secret).GetResult(); +} + +ByteBuffer HashingUtils::CalculateSHA256(const Aws::String& str) +{ + Sha256 hash; + return hash.Calculate(str).GetResult(); +} + +ByteBuffer HashingUtils::CalculateSHA256(Aws::IOStream& stream) +{ + Sha256 hash; + return hash.Calculate(stream).GetResult(); +} + +/** + * This function is only used by HashingUtils::CalculateSHA256TreeHash() in this cpp file + * It's a helper function be used to compute the TreeHash defined at: + * http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html + */ +static ByteBuffer TreeHashFinalCompute(Aws::List<ByteBuffer>& input) +{ + Sha256 hash; + assert(input.size() != 0); + + // O(n) time complexity of merging (n + n/2 + n/4 + n/8 +...+ 1) + while (input.size() > 1) + { + auto iter = input.begin(); + // if only one element left, just left it there + while (std::next(iter) != input.end()) + { + // if >= two elements + Aws::String str(reinterpret_cast<char*>(iter->GetUnderlyingData()), iter->GetLength()); + // list erase returns iterator of next element next to the erased element or end() if erased the last one + // list insert inserts element before pos, here we erase two elements, and insert a new element + iter = input.erase(iter); + str.append(reinterpret_cast<char*>(iter->GetUnderlyingData()), iter->GetLength()); + iter = input.erase(iter); + input.insert(iter, hash.Calculate(str).GetResult()); + + if (iter == input.end()) break; + } // while process to the last element + } // while the list has only one element left + + return *(input.begin()); +} + +ByteBuffer HashingUtils::CalculateSHA256TreeHash(const Aws::String& str) +{ + Sha256 hash; + if (str.size() == 0) + { + return hash.Calculate(str).GetResult(); + } + + Aws::List<ByteBuffer> input; + size_t pos = 0; + while (pos < str.size()) + { + input.push_back(hash.Calculate(Aws::String(str, pos, TREE_HASH_ONE_MB)).GetResult()); + pos += TREE_HASH_ONE_MB; + } + + return TreeHashFinalCompute(input); +} + +ByteBuffer HashingUtils::CalculateSHA256TreeHash(Aws::IOStream& stream) +{ + Sha256 hash; + Aws::List<ByteBuffer> input; + auto currentPos = stream.tellg(); + if (currentPos == std::ios::pos_type(-1)) + { + currentPos = 0; + stream.clear(); + } + stream.seekg(0, stream.beg); + Array<char> streamBuffer(TREE_HASH_ONE_MB); + while (stream.good()) + { + stream.read(streamBuffer.GetUnderlyingData(), TREE_HASH_ONE_MB); + auto bytesRead = stream.gcount(); + if (bytesRead > 0) + { + input.push_back(hash.Calculate(Aws::String(reinterpret_cast<char*>(streamBuffer.GetUnderlyingData()), static_cast<size_t>(bytesRead))).GetResult()); + } + } + stream.clear(); + stream.seekg(currentPos, stream.beg); + + if (input.size() == 0) + { + return hash.Calculate("").GetResult(); + } + return TreeHashFinalCompute(input); +} + +Aws::String HashingUtils::HexEncode(const ByteBuffer& message) +{ + Aws::String encoded; + encoded.reserve(2 * message.GetLength()); + + for (unsigned i = 0; i < message.GetLength(); ++i) + { + encoded.push_back("0123456789abcdef"[message[i] >> 4]); + encoded.push_back("0123456789abcdef"[message[i] & 0x0f]); + } + + return encoded; +} + +ByteBuffer HashingUtils::HexDecode(const Aws::String& str) +{ + //number of characters should be even + assert(str.length() % 2 == 0); + assert(str.length() >= 2); + + if(str.length() < 2 || str.length() % 2 != 0) + { + return ByteBuffer(); + } + + size_t strLength = str.length(); + size_t readIndex = 0; + + if(str[0] == '0' && (str[1] == 'x' || str[1] == 'X')) + { + strLength -= 2; + readIndex = 2; + } + + ByteBuffer hexBuffer(strLength / 2); + size_t bufferIndex = 0; + + for (size_t i = readIndex; i < str.length(); i += 2) + { + if(!StringUtils::IsAlnum(str[i]) || !StringUtils::IsAlnum(str[i + 1])) + { + //contains non-hex characters + assert(0); + } + + char firstChar = str[i]; + uint8_t distance = firstChar - '0'; + + if(isalpha(firstChar)) + { + firstChar = static_cast<char>(toupper(firstChar)); + distance = firstChar - 'A' + 10; + } + + unsigned char val = distance * 16; + + char secondChar = str[i + 1]; + distance = secondChar - '0'; + + if(isalpha(secondChar)) + { + secondChar = static_cast<char>(toupper(secondChar)); + distance = secondChar - 'A' + 10; + } + + val += distance; + hexBuffer[bufferIndex++] = val; + } + + return hexBuffer; +} + +ByteBuffer HashingUtils::CalculateMD5(const Aws::String& str) +{ + MD5 hash; + return hash.Calculate(str).GetResult(); +} + +ByteBuffer HashingUtils::CalculateMD5(Aws::IOStream& stream) +{ + MD5 hash; + return hash.Calculate(stream).GetResult(); +} + +int HashingUtils::HashString(const char* strToHash) +{ + if (!strToHash) + return 0; + + unsigned hash = 0; + while (char charValue = *strToHash++) + { + hash = charValue + 31 * hash; + } + + return hash; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/StringUtils.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/StringUtils.cpp new file mode 100644 index 0000000000..e1deb3f046 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/StringUtils.cpp @@ -0,0 +1,421 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <algorithm> +#include <iomanip> +#include <cstdlib> +#include <cstdio> +#include <cstring> +#include <functional> + +#ifdef _WIN32 +#include <Windows.h> +#endif + +using namespace Aws::Utils; + +void StringUtils::Replace(Aws::String& s, const char* search, const char* replace) +{ + if(!search || !replace) + { + return; + } + + size_t replaceLength = strlen(replace); + size_t searchLength = strlen(search); + + for (std::size_t pos = 0;; pos += replaceLength) + { + pos = s.find(search, pos); + if (pos == Aws::String::npos) + break; + + s.erase(pos, searchLength); + s.insert(pos, replace); + } +} + + +Aws::String StringUtils::ToLower(const char* source) +{ + Aws::String copy; + size_t sourceLength = strlen(source); + copy.resize(sourceLength); + //appease the latest whims of the VC++ 2017 gods + std::transform(source, source + sourceLength, copy.begin(), [](unsigned char c) { return (char)::tolower(c); }); + + return copy; +} + + +Aws::String StringUtils::ToUpper(const char* source) +{ + Aws::String copy; + size_t sourceLength = strlen(source); + copy.resize(sourceLength); + //appease the latest whims of the VC++ 2017 gods + std::transform(source, source + sourceLength, copy.begin(), [](unsigned char c) { return (char)::toupper(c); }); + + return copy; +} + + +bool StringUtils::CaselessCompare(const char* value1, const char* value2) +{ + Aws::String value1Lower = ToLower(value1); + Aws::String value2Lower = ToLower(value2); + + return value1Lower == value2Lower; +} + +Aws::Vector<Aws::String> StringUtils::Split(const Aws::String& toSplit, char splitOn) +{ + return Split(toSplit, splitOn, SIZE_MAX, SplitOptions::NOT_SET); +} + +Aws::Vector<Aws::String> StringUtils::Split(const Aws::String& toSplit, char splitOn, SplitOptions option) +{ + return Split(toSplit, splitOn, SIZE_MAX, option); +} + +Aws::Vector<Aws::String> StringUtils::Split(const Aws::String& toSplit, char splitOn, size_t numOfTargetParts) +{ + return Split(toSplit, splitOn, numOfTargetParts, SplitOptions::NOT_SET); +} + +Aws::Vector<Aws::String> StringUtils::Split(const Aws::String& toSplit, char splitOn, size_t numOfTargetParts, SplitOptions option) +{ + Aws::Vector<Aws::String> returnValues; + Aws::StringStream input(toSplit); + Aws::String item; + + while(returnValues.size() < numOfTargetParts - 1 && std::getline(input, item, splitOn)) + { + if (!item.empty() || option == SplitOptions::INCLUDE_EMPTY_ENTRIES) + { + returnValues.emplace_back(std::move(item)); + } + } + + if (std::getline(input, item, static_cast<char>(EOF))) + { + if (option != SplitOptions::INCLUDE_EMPTY_ENTRIES) + { + // Trim all leading delimiters. + item.erase(item.begin(), std::find_if(item.begin(), item.end(), [splitOn](int ch) { return ch != splitOn; })); + if (!item.empty()) + { + returnValues.emplace_back(std::move(item)); + } + } + else + { + returnValues.emplace_back(std::move(item)); + } + + } + // To handle the case when there are trailing delimiters. + else if (!toSplit.empty() && toSplit.back() == splitOn && option == SplitOptions::INCLUDE_EMPTY_ENTRIES) + { + returnValues.emplace_back(); + } + + return returnValues; +} + +Aws::Vector<Aws::String> StringUtils::SplitOnLine(const Aws::String& toSplit) +{ + Aws::StringStream input(toSplit); + Aws::Vector<Aws::String> returnValues; + Aws::String item; + + while (std::getline(input, item)) + { + if (item.size() > 0) + { + returnValues.push_back(item); + } + } + + return returnValues; +} + + +Aws::String StringUtils::URLEncode(const char* unsafe) +{ + Aws::StringStream escaped; + escaped.fill('0'); + escaped << std::hex << std::uppercase; + + size_t unsafeLength = strlen(unsafe); + for (auto i = unsafe, n = unsafe + unsafeLength; i != n; ++i) + { + char c = *i; + if (IsAlnum(c) || c == '-' || c == '_' || c == '.' || c == '~') + { + escaped << (char)c; + } + else + { + //this unsigned char cast allows us to handle unicode characters. + escaped << '%' << std::setw(2) << int((unsigned char)c) << std::setw(0); + } + } + + return escaped.str(); +} + +Aws::String StringUtils::UTF8Escape(const char* unicodeString, const char* delimiter) +{ + Aws::StringStream escaped; + escaped.fill('0'); + escaped << std::hex << std::uppercase; + + size_t unsafeLength = strlen(unicodeString); + for (auto i = unicodeString, n = unicodeString + unsafeLength; i != n; ++i) + { + int c = *i; + if (c >= ' ' && c < 127 ) + { + escaped << (char)c; + } + else + { + //this unsigned char cast allows us to handle unicode characters. + escaped << delimiter << std::setw(2) << int((unsigned char)c) << std::setw(0); + } + } + + return escaped.str(); +} + +Aws::String StringUtils::URLEncode(double unsafe) +{ + char buffer[32]; +#if defined(_MSC_VER) && _MSC_VER < 1900 + _snprintf_s(buffer, sizeof(buffer), _TRUNCATE, "%g", unsafe); +#else + snprintf(buffer, sizeof(buffer), "%g", unsafe); +#endif + + return StringUtils::URLEncode(buffer); +} + + +Aws::String StringUtils::URLDecode(const char* safe) +{ + Aws::String unescaped; + + for (; *safe; safe++) + { + switch(*safe) + { + case '%': + { + int hex = 0; + auto ch = *++safe; + if (ch >= '0' && ch <= '9') + { + hex = (ch - '0') * 16; + } + else if (ch >= 'A' && ch <= 'F') + { + hex = (ch - 'A' + 10) * 16; + } + else if (ch >= 'a' && ch <= 'f') + { + hex = (ch - 'a' + 10) * 16; + } + else + { + unescaped.push_back('%'); + if (ch == 0) + { + return unescaped; + } + unescaped.push_back(ch); + break; + } + + ch = *++safe; + if (ch >= '0' && ch <= '9') + { + hex += (ch - '0'); + } + else if (ch >= 'A' && ch <= 'F') + { + hex += (ch - 'A' + 10); + } + else if (ch >= 'a' && ch <= 'f') + { + hex += (ch - 'a' + 10); + } + else + { + unescaped.push_back('%'); + unescaped.push_back(*(safe - 1)); + if (ch == 0) + { + return unescaped; + } + unescaped.push_back(ch); + break; + } + + unescaped.push_back(char(hex)); + break; + } + case '+': + unescaped.push_back(' '); + break; + default: + unescaped.push_back(*safe); + break; + } + } + + return unescaped; +} + +static bool IsSpace(int ch) +{ + if (ch < -1 || ch > 255) + { + return false; + } + + return ::isspace(ch) != 0; +} + +Aws::String StringUtils::LTrim(const char* source) +{ + Aws::String copy(source); + copy.erase(copy.begin(), std::find_if(copy.begin(), copy.end(), [](int ch) { return !IsSpace(ch); })); + return copy; +} + +// trim from end +Aws::String StringUtils::RTrim(const char* source) +{ + Aws::String copy(source); + copy.erase(std::find_if(copy.rbegin(), copy.rend(), [](int ch) { return !IsSpace(ch); }).base(), copy.end()); + return copy; +} + +// trim from both ends +Aws::String StringUtils::Trim(const char* source) +{ + return LTrim(RTrim(source).c_str()); +} + +long long StringUtils::ConvertToInt64(const char* source) +{ + if(!source) + { + return 0; + } + +#ifdef __ANDROID__ + return atoll(source); +#else + return std::atoll(source); +#endif // __ANDROID__ +} + + +long StringUtils::ConvertToInt32(const char* source) +{ + if (!source) + { + return 0; + } + + return std::atol(source); +} + + +bool StringUtils::ConvertToBool(const char* source) +{ + if(!source) + { + return false; + } + + Aws::String strValue = ToLower(source); + if(strValue == "true" || strValue == "1") + { + return true; + } + + return false; +} + + +double StringUtils::ConvertToDouble(const char* source) +{ + if(!source) + { + return 0.0; + } + + return std::strtod(source, NULL); +} + +#ifdef _WIN32 + +Aws::WString StringUtils::ToWString(const char* source) +{ + const auto len = static_cast<int>(std::strlen(source)); + Aws::WString outString; + outString.resize(len); // there is no way UTF-16 would require _more_ code-points than UTF-8 for the _same_ string + const auto result = MultiByteToWideChar(CP_UTF8 /*CodePage*/, + 0 /*dwFlags*/, + source /*lpMultiByteStr*/, + len /*cbMultiByte*/, + &outString[0] /*lpWideCharStr*/, + static_cast<int>(outString.length())/*cchWideChar*/); + if (!result) + { + return L""; + } + outString.resize(result); + return outString; +} + +Aws::String StringUtils::FromWString(const wchar_t* source) +{ + const auto len = static_cast<int>(std::wcslen(source)); + Aws::String output; + if (int requiredSizeInBytes = WideCharToMultiByte(CP_UTF8 /*CodePage*/, + 0 /*dwFlags*/, + source /*lpWideCharStr*/, + len /*cchWideChar*/, + nullptr /*lpMultiByteStr*/, + 0 /*cbMultiByte*/, + nullptr /*lpDefaultChar*/, + nullptr /*lpUsedDefaultChar*/)) + { + output.resize(requiredSizeInBytes); + } + const auto result = WideCharToMultiByte(CP_UTF8 /*CodePage*/, + 0 /*dwFlags*/, + source /*lpWideCharStr*/, + len /*cchWideChar*/, + &output[0] /*lpMultiByteStr*/, + static_cast<int>(output.length()) /*cbMultiByte*/, + nullptr /*lpDefaultChar*/, + nullptr /*lpUsedDefaultChar*/); + if (!result) + { + return ""; + } + output.resize(result); + return output; +} + +#endif diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/TempFile.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/TempFile.cpp new file mode 100644 index 0000000000..7bc07266c9 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/TempFile.cpp @@ -0,0 +1,54 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/FileSystemUtils.h> + +#include <aws/core/platform/FileSystem.h> + +namespace Aws +{ + namespace Utils + { + static Aws::String ComputeTempFileName(const char* prefix, const char* suffix) + { + Aws::String prefixStr; + + if (prefix) + { + prefixStr = prefix; + } + + Aws::String suffixStr; + + if (suffix) + { + suffixStr = suffix; + } + + return prefixStr + Aws::FileSystem::CreateTempFilePath() + suffixStr; + } + + TempFile::TempFile(const char* prefix, const char* suffix, std::ios_base::openmode openFlags) : + FStreamWithFileName(ComputeTempFileName(prefix, suffix).c_str(), openFlags) + { + } + + TempFile::TempFile(const char* prefix, std::ios_base::openmode openFlags) : + FStreamWithFileName(ComputeTempFileName(prefix, nullptr).c_str(), openFlags) + { + } + + TempFile::TempFile(std::ios_base::openmode openFlags) : + FStreamWithFileName(ComputeTempFileName(nullptr, nullptr).c_str(), openFlags) + { + } + + + TempFile::~TempFile() + { + Aws::FileSystem::RemoveFileIfExists(m_fileName.c_str()); + } + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/UUID.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/UUID.cpp new file mode 100644 index 0000000000..862f3eacdd --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/UUID.cpp @@ -0,0 +1,90 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/UUID.h> +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/utils/StringUtils.h> +#include <aws/core/utils/crypto/Factories.h> +#include <aws/core/utils/crypto/SecureRandom.h> +#include <iomanip> + +namespace Aws +{ + namespace Utils + { + static const size_t UUID_STR_SIZE = 0x24u; // 36 characters + static const size_t VERSION_LOCATION = 0x06u; + static const size_t VARIANT_LOCATION = 0x08u; + static const unsigned char VERSION = 0x40u; + static const unsigned char VERSION_MASK = 0x0Fu; + static const unsigned char VARIANT = 0x80u; + static const unsigned char VARIANT_MASK = 0x3Fu; + + static void hexify(Aws::String& ss, const unsigned char* toWrite, size_t min, size_t max) + { + for (size_t i = min; i < max; ++i) + { + ss.push_back("0123456789ABCDEF"[toWrite[i] >> 4]); + ss.push_back("0123456789ABCDEF"[toWrite[i] & 0x0F]); + } + } + + UUID::UUID(const Aws::String& uuidToConvert) + { + //GUID has 2 characters per byte + 4 dashes = 36 bytes + assert(uuidToConvert.length() == UUID_STR_SIZE); + memset(m_uuid, 0, sizeof(m_uuid)); + Aws::String escapedHexStr(uuidToConvert); + StringUtils::Replace(escapedHexStr, "-", ""); + assert(escapedHexStr.length() == UUID_BINARY_SIZE * 2); + ByteBuffer&& rawUuid = HashingUtils::HexDecode(escapedHexStr); + memcpy(m_uuid, rawUuid.GetUnderlyingData(), rawUuid.GetLength()); + } + + UUID::UUID(const unsigned char toCopy[UUID_BINARY_SIZE]) + { + memcpy(m_uuid, toCopy, sizeof(m_uuid)); + } + + UUID::operator Aws::String() const + { + Aws::String ss; + ss.reserve(UUID_STR_SIZE); + hexify(ss, m_uuid, 0, 4); + ss.push_back('-'); + + hexify(ss, m_uuid, 4, 6); + ss.push_back('-'); + + hexify(ss, m_uuid, 6, 8); + ss.push_back('-'); + + hexify(ss, m_uuid, 8, 10); + ss.push_back('-'); + + hexify(ss, m_uuid, 10, 16); + + return ss; + } + + UUID UUID::RandomUUID() + { + auto secureRandom = Crypto::CreateSecureRandomBytesImplementation(); + assert(secureRandom); + + unsigned char randomBytes[UUID_BINARY_SIZE]; + memset(randomBytes, 0, UUID_BINARY_SIZE); + secureRandom->GetBytes(randomBytes, UUID_BINARY_SIZE); + //Set version bits to 0100 + //https://tools.ietf.org/html/rfc4122#section-4.1.3 + randomBytes[VERSION_LOCATION] = (randomBytes[VERSION_LOCATION] & VERSION_MASK) | VERSION; + //set variant bits to 10 + //https://tools.ietf.org/html/rfc4122#section-4.1.1 + randomBytes[VARIANT_LOCATION] = (randomBytes[VARIANT_LOCATION] & VARIANT_MASK) | VARIANT; + + return UUID(randomBytes); + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/base64/Base64.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/base64/Base64.cpp new file mode 100644 index 0000000000..2103d6d5a6 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/base64/Base64.cpp @@ -0,0 +1,148 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/base64/Base64.h> +#include <cstring> + +using namespace Aws::Utils::Base64; + +static const uint8_t SENTINEL_VALUE = 255; +static const char BASE64_ENCODING_TABLE_MIME[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +namespace Aws +{ +namespace Utils +{ +namespace Base64 +{ + +Base64::Base64(const char *encodingTable) +{ + if(encodingTable == nullptr) + { + encodingTable = BASE64_ENCODING_TABLE_MIME; + } + + size_t encodingTableLength = strlen(encodingTable); + if(encodingTableLength != 64) + { + encodingTable = BASE64_ENCODING_TABLE_MIME; + encodingTableLength = 64; + } + + memcpy(m_mimeBase64EncodingTable, encodingTable, encodingTableLength); + + memset((void *)m_mimeBase64DecodingTable, 0, 256); + + for(uint32_t i = 0; i < encodingTableLength; ++i) + { + uint32_t index = static_cast<uint32_t>(m_mimeBase64EncodingTable[i]); + m_mimeBase64DecodingTable[index] = static_cast<uint8_t>(i); + } + + m_mimeBase64DecodingTable[(uint32_t)'='] = SENTINEL_VALUE; +} + +Aws::String Base64::Encode(const Aws::Utils::ByteBuffer& buffer) const +{ + size_t bufferLength = buffer.GetLength(); + size_t blockCount = (bufferLength + 2) / 3; + size_t remainderCount = (bufferLength % 3); + + Aws::String outputString; + outputString.reserve(CalculateBase64EncodedLength(buffer)); + + for(size_t i = 0; i < bufferLength; i += 3 ) + { + uint32_t block = buffer[ i ]; + + block <<= 8; + if (i + 1 < bufferLength) + { + block = block | buffer[ i + 1 ]; + } + + block <<= 8; + if (i + 2 < bufferLength) + { + block = block | buffer[ i + 2 ]; + } + + outputString.push_back(m_mimeBase64EncodingTable[(block >> 18) & 0x3F]); + outputString.push_back(m_mimeBase64EncodingTable[(block >> 12) & 0x3F]); + outputString.push_back(m_mimeBase64EncodingTable[(block >> 6) & 0x3F]); + outputString.push_back(m_mimeBase64EncodingTable[block & 0x3F]); + } + + if(remainderCount > 0) + { + outputString[blockCount * 4 - 1] = '='; + if(remainderCount == 1) + { + outputString[blockCount * 4 - 2] = '='; + } + } + + return outputString; +} + +Aws::Utils::ByteBuffer Base64::Decode(const Aws::String& str) const +{ + size_t decodedLength = CalculateBase64DecodedLength(str); + + Aws::Utils::ByteBuffer buffer(decodedLength); + + const char* rawString = str.c_str(); + size_t blockCount = str.length() / 4; + for(size_t i = 0; i < blockCount; ++i) + { + size_t stringIndex = i * 4; + + uint32_t value1 = m_mimeBase64DecodingTable[uint32_t(rawString[stringIndex])]; + uint32_t value2 = m_mimeBase64DecodingTable[uint32_t(rawString[++stringIndex])]; + uint32_t value3 = m_mimeBase64DecodingTable[uint32_t(rawString[++stringIndex])]; + uint32_t value4 = m_mimeBase64DecodingTable[uint32_t(rawString[++stringIndex])]; + + size_t bufferIndex = i * 3; + buffer[bufferIndex] = static_cast<uint8_t>((value1 << 2) | ((value2 >> 4) & 0x03)); + if(value3 != SENTINEL_VALUE) + { + buffer[++bufferIndex] = static_cast<uint8_t>(((value2 << 4) & 0xF0) | ((value3 >> 2) & 0x0F)); + if(value4 != SENTINEL_VALUE) + { + buffer[++bufferIndex] = static_cast<uint8_t>((value3 & 0x03) << 6 | value4); + } + } + } + + return buffer; +} + +size_t Base64::CalculateBase64DecodedLength(const Aws::String& b64input) +{ + const size_t len = b64input.length(); + if(len < 2) + { + return 0; + } + + size_t padding = 0; + + if (b64input[len - 1] == '=' && b64input[len - 2] == '=') //last two chars are = + padding = 2; + else if (b64input[len - 1] == '=') //last char is = + padding = 1; + + return (len * 3 / 4 - padding); +} + +size_t Base64::CalculateBase64EncodedLength(const Aws::Utils::ByteBuffer& buffer) +{ + return 4 * ((buffer.GetLength() + 2) / 3); +} + +} // namespace Base64 +} // namespace Utils +} // namespace Aws
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Cipher.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Cipher.cpp new file mode 100644 index 0000000000..1c844273f4 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Cipher.cpp @@ -0,0 +1,123 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/crypto/Cipher.h> +#include <aws/core/utils/crypto/Factories.h> +#include <aws/core/utils/crypto/SecureRandom.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <cstdlib> +#include <climits> + +//if you are reading this, you are witnessing pure brilliance. +#define IS_BIG_ENDIAN (*(uint16_t*)"\0\xff" < 0x100) + +using namespace Aws::Utils::Crypto; +using namespace Aws::Utils; + +namespace Aws +{ + namespace Utils + { + namespace Crypto + { + static const char* LOG_TAG = "Cipher"; + + //swap byte ordering + template<class T> + typename std::enable_if<std::is_unsigned<T>::value, T>::type + bswap(T i, T j = 0u, std::size_t n = 0u) + { + return n == sizeof(T) ? j : + bswap<T>(i >> CHAR_BIT, (j << CHAR_BIT) | (i & (T)(unsigned char)(-1)), n + 1); + } + + CryptoBuffer IncrementCTRCounter(const CryptoBuffer& counter, uint32_t numberOfBlocks) + { + // minium counter size is 12 bytes. This isn't a variable because some compilers + // are stupid and thing that variable is unused. + assert(counter.GetLength() >= 12); + + CryptoBuffer incrementedCounter(counter); + + //get the last 4 bytes and manipulate them as an integer. + uint32_t* ctrPtr = (uint32_t*)(incrementedCounter.GetUnderlyingData() + incrementedCounter.GetLength() - sizeof(int32_t)); + if(IS_BIG_ENDIAN) + { + //you likely are not Big Endian, but + //if it's big endian, just go ahead and increment it... done + *ctrPtr += numberOfBlocks; + } + else + { + //otherwise, swap the byte ordering of the integer we loaded from the buffer (because it is backwards). However, the number of blocks is already properly + //aligned. Once we compute the new value, swap it back so that the mirroring operation goes back to the actual buffer. + *ctrPtr = bswap<uint32_t>(bswap<uint32_t>(*ctrPtr) + numberOfBlocks); + } + + return incrementedCounter; + } + + CryptoBuffer GenerateXRandomBytes(size_t lengthBytes, bool ctrMode) + { + std::shared_ptr<SecureRandomBytes> rng = CreateSecureRandomBytesImplementation(); + + CryptoBuffer bytes(lengthBytes); + size_t lengthToGenerate = ctrMode ? (3 * bytes.GetLength()) / 4 : bytes.GetLength(); + + rng->GetBytes(bytes.GetUnderlyingData(), lengthToGenerate); + + if(!*rng) + { + AWS_LOGSTREAM_FATAL(LOG_TAG, "Random Number generation failed. Abort all crypto operations."); + assert(false); + abort(); + } + + return bytes; + } + + /** + * Generate random number per 4 bytes and use each byte for the byte in the iv + */ + CryptoBuffer SymmetricCipher::GenerateIV(size_t ivLengthBytes, bool ctrMode) + { + CryptoBuffer iv(GenerateXRandomBytes(ivLengthBytes, ctrMode)); + + if(iv.GetLength() == 0) + { + AWS_LOGSTREAM_ERROR(LOG_TAG, "Unable to generate iv of length " << ivLengthBytes); + return iv; + } + + if(ctrMode) + { + //init the counter + size_t length = iv.GetLength(); + //[ nonce 1/4] [ iv 1/2 ] [ ctr 1/4 ] + size_t ctrStart = (length / 2) + (length / 4); + for(; ctrStart < iv.GetLength() - 1; ++ ctrStart) + { + iv[ctrStart] = 0; + } + iv[length - 1] = 1; + } + + return iv; + } + + CryptoBuffer SymmetricCipher::GenerateKey(size_t keyLengthBytes) + { + CryptoBuffer const& key = GenerateXRandomBytes(keyLengthBytes, false); + + if(key.GetLength() == 0) + { + AWS_LOGSTREAM_ERROR(LOG_TAG, "Unable to generate key of length " << keyLengthBytes); + } + + return key; + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/ContentCryptoMaterial.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/ContentCryptoMaterial.cpp new file mode 100644 index 0000000000..3036bd70eb --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/ContentCryptoMaterial.cpp @@ -0,0 +1,34 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/crypto/ContentCryptoMaterial.h> +#include <aws/core/utils/crypto/Cipher.h> + +using namespace Aws::Utils::Crypto; + +namespace Aws +{ + namespace Utils + { + namespace Crypto + { + ContentCryptoMaterial::ContentCryptoMaterial() : + m_cryptoTagLength(0), m_keyWrapAlgorithm(KeyWrapAlgorithm::NONE), m_contentCryptoScheme(ContentCryptoScheme::NONE) + { + } + + ContentCryptoMaterial::ContentCryptoMaterial(ContentCryptoScheme contentCryptoScheme) : + m_contentEncryptionKey(SymmetricCipher::GenerateKey()), m_cryptoTagLength(0), m_keyWrapAlgorithm(KeyWrapAlgorithm::NONE), m_contentCryptoScheme(contentCryptoScheme) + { + + } + + ContentCryptoMaterial::ContentCryptoMaterial(const Aws::Utils::CryptoBuffer & cek, ContentCryptoScheme contentCryptoScheme) : + m_contentEncryptionKey(cek), m_cryptoTagLength(0), m_keyWrapAlgorithm(KeyWrapAlgorithm::NONE), m_contentCryptoScheme(contentCryptoScheme) + { + + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/ContentCryptoScheme.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/ContentCryptoScheme.cpp new file mode 100644 index 0000000000..f39a75df2c --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/ContentCryptoScheme.cpp @@ -0,0 +1,61 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/crypto/ContentCryptoScheme.h> +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/utils/EnumParseOverflowContainer.h> +#include <aws/core/Globals.h> + +using namespace Aws::Utils; + +namespace Aws +{ + namespace Utils + { + namespace Crypto + { + namespace ContentCryptoSchemeMapper + { + static const int cryptoScheme_CBC_HASH = HashingUtils::HashString("AES/CBC/PKCS5Padding"); + static const int cryptoScheme_CTR_HASH = HashingUtils::HashString("AES/CTR/NoPadding"); + static const int cryptoScheme_GCM_HASH = HashingUtils::HashString("AES/GCM/NoPadding"); + + ContentCryptoScheme GetContentCryptoSchemeForName(const Aws::String& name) + { + int hashcode = HashingUtils::HashString(name.c_str()); + if (hashcode == cryptoScheme_CBC_HASH) + { + return ContentCryptoScheme::CBC; + } + else if (hashcode == cryptoScheme_CTR_HASH) + { + return ContentCryptoScheme::CTR; + } + else if (hashcode == cryptoScheme_GCM_HASH) + { + return ContentCryptoScheme::GCM; + } + assert(0); + return ContentCryptoScheme::NONE; + } + + Aws::String GetNameForContentCryptoScheme(ContentCryptoScheme enumValue) + { + switch (enumValue) + { + case ContentCryptoScheme::CBC: + return "AES/CBC/PKCS5Padding"; + case ContentCryptoScheme::CTR: + return "AES/CTR/NoPadding"; + case ContentCryptoScheme::GCM: + return "AES/GCM/NoPadding"; + default: + assert(0); + return ""; + } + } + }//namespace ContentCryptoSchemeMapper + } //namespace Crypto + }//namespace Utils +}//namespace Aws
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/CryptoBuf.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/CryptoBuf.cpp new file mode 100644 index 0000000000..2b47097679 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/CryptoBuf.cpp @@ -0,0 +1,348 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/crypto/CryptoBuf.h> + +namespace Aws +{ + namespace Utils + { + namespace Crypto + { + SymmetricCryptoBufSrc::SymmetricCryptoBufSrc(Aws::IStream& stream, SymmetricCipher& cipher, CipherMode cipherMode, size_t bufferSize) + : + m_isBuf(PUT_BACK_SIZE), m_cipher(cipher), m_stream(stream), m_cipherMode(cipherMode), m_isFinalized(false), + m_bufferSize(bufferSize), m_putBack(PUT_BACK_SIZE) + { + char* end = reinterpret_cast<char*>(m_isBuf.GetUnderlyingData() + m_isBuf.GetLength()); + setg(end, end, end); + } + + SymmetricCryptoBufSrc::pos_type SymmetricCryptoBufSrc::seekoff(off_type off, std::ios_base::seekdir dir, std::ios_base::openmode which) + { + if(which == std::ios_base::in) + { + auto curPos = m_stream.tellg(); + //error on seek we may have read past the end already. Try resetting and seeking to the end first + if (curPos == pos_type(-1)) + { + m_stream.clear(); + m_stream.seekg(0, std::ios_base::end); + curPos = m_stream.tellg(); + } + + auto absPosition = ComputeAbsSeekPosition(off, dir, curPos); + size_t seekTo = static_cast<size_t>(absPosition); + size_t index = static_cast<size_t>(curPos); + + if(index == seekTo) + { + return curPos; + } + else if (seekTo < index) + { + m_cipher.Reset(); + m_stream.clear(); + m_stream.seekg(0); + m_isFinalized = false; + index = 0; + } + + CryptoBuffer cryptoBuffer; + while (m_cipher && index < seekTo && !m_isFinalized) + { + size_t max_read = std::min<size_t>(static_cast<size_t>(seekTo - index), m_bufferSize); + + Aws::Utils::Array<char> buf(max_read); + size_t readSize(0); + if(m_stream) + { + m_stream.read(buf.GetUnderlyingData(), max_read); + readSize = static_cast<size_t>(m_stream.gcount()); + } + + if (readSize > 0) + { + if (m_cipherMode == CipherMode::Encrypt) + { + cryptoBuffer = m_cipher.EncryptBuffer(CryptoBuffer(reinterpret_cast<unsigned char*>(buf.GetUnderlyingData()), readSize)); + } + else + { + cryptoBuffer = m_cipher.DecryptBuffer(CryptoBuffer(reinterpret_cast<unsigned char*>(buf.GetUnderlyingData()), readSize)); + } + } + else + { + if (m_cipherMode == CipherMode::Encrypt) + { + cryptoBuffer = m_cipher.FinalizeEncryption(); + } + else + { + cryptoBuffer = m_cipher.FinalizeDecryption(); + } + + m_isFinalized = true; + } + + index += cryptoBuffer.GetLength(); + } + + if (cryptoBuffer.GetLength() && m_cipher) + { + CryptoBuffer putBackArea(m_putBack); + + m_isBuf = CryptoBuffer({&putBackArea, &cryptoBuffer}); + //in the very unlikely case that the cipher had less output than the source stream. + assert(seekTo <= index); + size_t newBufferPos = index > seekTo ? cryptoBuffer.GetLength() - (index - seekTo) : cryptoBuffer.GetLength(); + char* baseBufPtr = reinterpret_cast<char*>(m_isBuf.GetUnderlyingData()); + setg(baseBufPtr, baseBufPtr + m_putBack + newBufferPos, baseBufPtr + m_isBuf.GetLength()); + + return pos_type(seekTo); + } + else if (seekTo == 0) + { + m_isBuf = CryptoBuffer(m_putBack); + char* end = reinterpret_cast<char*>(m_isBuf.GetUnderlyingData() + m_isBuf.GetLength()); + setg(end, end, end); + return pos_type(seekTo); + } + } + + return pos_type(off_type(-1)); + } + + SymmetricCryptoBufSrc::pos_type SymmetricCryptoBufSrc::seekpos(pos_type pos, std::ios_base::openmode which) + { + return seekoff(pos, std::ios_base::beg, which); + } + + SymmetricCryptoBufSrc::int_type SymmetricCryptoBufSrc::underflow() + { + if (!m_cipher || (m_isFinalized && gptr() >= egptr())) + { + return traits_type::eof(); + } + + if (gptr() < egptr()) + { + return traits_type::to_int_type(*gptr()); + } + + char* baseBufPtr = reinterpret_cast<char*>(m_isBuf.GetUnderlyingData()); + CryptoBuffer putBackArea(m_putBack); + + //eback is properly set after the first fill. So this guarantees we are on the second or later fill. + if (eback() == baseBufPtr) + { + //just fill in the last bit of the previous buffer into the put back area so that it has some data in it + memcpy(putBackArea.GetUnderlyingData(), egptr() - m_putBack, m_putBack); + } + + CryptoBuffer newDataBuf; + + while(!newDataBuf.GetLength() && !m_isFinalized) + { + Aws::Utils::Array<char> buf(m_bufferSize); + m_stream.read(buf.GetUnderlyingData(), m_bufferSize); + size_t readSize = static_cast<size_t>(m_stream.gcount()); + + if (readSize > 0) + { + if (m_cipherMode == CipherMode::Encrypt) + { + newDataBuf = m_cipher.EncryptBuffer(CryptoBuffer(reinterpret_cast<unsigned char*>(buf.GetUnderlyingData()), readSize)); + } + else + { + newDataBuf = m_cipher.DecryptBuffer(CryptoBuffer(reinterpret_cast<unsigned char*>(buf.GetUnderlyingData()), readSize)); + } + } + else + { + if (m_cipherMode == CipherMode::Encrypt) + { + newDataBuf = m_cipher.FinalizeEncryption(); + } + else + { + newDataBuf = m_cipher.FinalizeDecryption(); + } + + m_isFinalized = true; + } + } + + + if(newDataBuf.GetLength() > 0) + { + m_isBuf = CryptoBuffer({&putBackArea, &newDataBuf}); + + baseBufPtr = reinterpret_cast<char*>(m_isBuf.GetUnderlyingData()); + setg(baseBufPtr, baseBufPtr + m_putBack, baseBufPtr + m_isBuf.GetLength()); + + return traits_type::to_int_type(*gptr()); + } + + return traits_type::eof(); + } + + SymmetricCryptoBufSrc::off_type SymmetricCryptoBufSrc::ComputeAbsSeekPosition(off_type pos, std::ios_base::seekdir dir, std::fpos<FPOS_TYPE> curPos) + { + switch(dir) + { + case std::ios_base::beg: + return pos; + case std::ios_base::cur: + return m_stream.tellg() + pos; + case std::ios_base::end: + { + off_type absPos = m_stream.seekg(0, std::ios_base::end).tellg() - pos; + m_stream.seekg(curPos); + return absPos; + } + default: + assert(0); + return off_type(-1); + } + } + + void SymmetricCryptoBufSrc::FinalizeCipher() + { + if(m_cipher && !m_isFinalized) + { + if(m_cipherMode == CipherMode::Encrypt) + { + m_cipher.FinalizeEncryption(); + } + else + { + m_cipher.FinalizeDecryption(); + } + } + } + + SymmetricCryptoBufSink::SymmetricCryptoBufSink(Aws::OStream& stream, SymmetricCipher& cipher, CipherMode cipherMode, size_t bufferSize, int16_t blockOffset) + : + m_osBuf(bufferSize), m_cipher(cipher), m_stream(stream), m_cipherMode(cipherMode), m_isFinalized(false), m_blockOffset(blockOffset) + { + assert(m_blockOffset < 16 && m_blockOffset >= 0); + char* outputBase = reinterpret_cast<char*>(m_osBuf.GetUnderlyingData()); + setp(outputBase, outputBase + bufferSize - 1); + } + + SymmetricCryptoBufSink::~SymmetricCryptoBufSink() + { + FinalizeCiphersAndFlushSink(); + } + + void SymmetricCryptoBufSink::FinalizeCiphersAndFlushSink() + { + if(m_cipher && !m_isFinalized) + { + writeOutput(true); + } + } + + bool SymmetricCryptoBufSink::writeOutput(bool finalize) + { + if(!m_isFinalized) + { + CryptoBuffer cryptoBuf; + if (pptr() > pbase()) + { + if (m_cipherMode == CipherMode::Encrypt) + { + cryptoBuf = m_cipher.EncryptBuffer(CryptoBuffer(reinterpret_cast<unsigned char*>(pbase()), pptr() - pbase())); + } + else + { + cryptoBuf = m_cipher.DecryptBuffer(CryptoBuffer(reinterpret_cast<unsigned char*>(pbase()), pptr() - pbase())); + } + + pbump(-(static_cast<int>(pptr() - pbase()))); + } + if(finalize) + { + CryptoBuffer finalBuffer; + if (m_cipherMode == CipherMode::Encrypt) + { + finalBuffer = m_cipher.FinalizeEncryption(); + } + else + { + finalBuffer = m_cipher.FinalizeDecryption(); + } + if(cryptoBuf.GetLength()) + { + cryptoBuf = CryptoBuffer({&cryptoBuf, &finalBuffer}); + } + else + { + cryptoBuf = std::move(finalBuffer); + } + + m_isFinalized = true; + } + + if (m_cipher) + { + if(cryptoBuf.GetLength()) + { + //allow mid block decryption. We have to decrypt it, but we don't have to write it to the stream. + //the assumption here is that tellp() will always be 0 or >= 16 bytes. The block offset should only + //be the offset of the first block read. + size_t len = cryptoBuf.GetLength(); + size_t blockOffset = m_stream.tellp() > m_blockOffset ? 0 : m_blockOffset; + if (len > blockOffset) + { + m_stream.write(reinterpret_cast<char*>(cryptoBuf.GetUnderlyingData() + blockOffset), len - blockOffset); + m_blockOffset = 0; + } + else + { + m_blockOffset -= static_cast<int16_t>(len); + } + } + return true; + } + } + + return false; + } + + SymmetricCryptoBufSink::int_type SymmetricCryptoBufSink::overflow(int_type ch) + { + if(m_cipher && m_stream) + { + if(ch != traits_type::eof()) + { + *pptr() = (char)ch; + pbump(1); + } + + if(writeOutput(ch == traits_type::eof())) + { + return ch; + } + } + + return traits_type::eof(); + } + + int SymmetricCryptoBufSink::sync() + { + if(m_cipher && m_stream) + { + return writeOutput(false) ? 0 : -1; + } + + return -1; + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/CryptoStream.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/CryptoStream.cpp new file mode 100644 index 0000000000..2d645f7427 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/CryptoStream.cpp @@ -0,0 +1,52 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/crypto/CryptoStream.h> + +namespace Aws +{ + namespace Utils + { + namespace Crypto + { + static const char* CLASS_TAG = "Aws::Utils::Crypto::SymmetricCryptoStream"; + + SymmetricCryptoStream::SymmetricCryptoStream(Aws::IStream& src, CipherMode mode, SymmetricCipher& cipher, size_t bufSize) : + Aws::IOStream(m_cryptoBuf = Aws::New<SymmetricCryptoBufSrc>(CLASS_TAG, src, cipher, mode, bufSize)), m_hasOwnership(true) + { + } + + SymmetricCryptoStream::SymmetricCryptoStream(Aws::OStream& sink, CipherMode mode, SymmetricCipher& cipher, size_t bufSize, int16_t blockOffset) : + Aws::IOStream(m_cryptoBuf = Aws::New<SymmetricCryptoBufSink>(CLASS_TAG, sink, cipher, mode, bufSize, blockOffset)), m_hasOwnership(true) + { + } + + SymmetricCryptoStream::SymmetricCryptoStream(Aws::Utils::Crypto::SymmetricCryptoBufSrc& bufSrc) : + Aws::IOStream(&bufSrc), m_cryptoBuf(&bufSrc), m_hasOwnership(false) + { + } + + SymmetricCryptoStream::SymmetricCryptoStream(Aws::Utils::Crypto::SymmetricCryptoBufSink& bufSink) : + Aws::IOStream(&bufSink), m_cryptoBuf(&bufSink), m_hasOwnership(false) + { + } + + SymmetricCryptoStream::~SymmetricCryptoStream() + { + Finalize(); + + if(m_hasOwnership && m_cryptoBuf) + { + Aws::Delete(m_cryptoBuf); + } + } + + void SymmetricCryptoStream::Finalize() + { + assert(m_cryptoBuf); + m_cryptoBuf->Finalize(); + } + } + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/EncryptionMaterials.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/EncryptionMaterials.cpp new file mode 100644 index 0000000000..d000c86baa --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/EncryptionMaterials.cpp @@ -0,0 +1,19 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/crypto/EncryptionMaterials.h> + +namespace Aws +{ + namespace Utils + { + namespace Crypto + { + //this is here to force the linker to behave correctly since this is an interface that will need to cross the dll + //boundary. + EncryptionMaterials::~EncryptionMaterials() + {} + } + } +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/KeyWrapAlgorithm.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/KeyWrapAlgorithm.cpp new file mode 100644 index 0000000000..b9e098775c --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/KeyWrapAlgorithm.cpp @@ -0,0 +1,68 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/crypto/KeyWrapAlgorithm.h> +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/utils/EnumParseOverflowContainer.h> +#include <aws/core/Globals.h> + +using namespace Aws::Utils; + +namespace Aws +{ + namespace Utils + { + namespace Crypto + { + namespace KeyWrapAlgorithmMapper + { + static const int keyWrapAlgorithm_KMS_HASH = HashingUtils::HashString("kms"); + static const int keyWrapAlgorithm_KMS_CONTEXT_HASH = HashingUtils::HashString("kms+context"); + static const int keyWrapAlgorithm_KeyWrap_HASH = HashingUtils::HashString("AESWrap"); + static const int keyWrapAlgorithm_AES_GCM_HASH = HashingUtils::HashString("AES/GCM"); + + KeyWrapAlgorithm GetKeyWrapAlgorithmForName(const Aws::String& name) + { + int hashcode = HashingUtils::HashString(name.c_str()); + if (hashcode == keyWrapAlgorithm_KMS_HASH) + { + return KeyWrapAlgorithm::KMS; + } + else if (hashcode == keyWrapAlgorithm_KMS_CONTEXT_HASH) + { + return KeyWrapAlgorithm::KMS_CONTEXT; + } + else if (hashcode == keyWrapAlgorithm_KeyWrap_HASH) + { + return KeyWrapAlgorithm::AES_KEY_WRAP; + } + else if (hashcode == keyWrapAlgorithm_AES_GCM_HASH) + { + return KeyWrapAlgorithm::AES_GCM; + } + assert(0); + return KeyWrapAlgorithm::NONE; + } + + Aws::String GetNameForKeyWrapAlgorithm(KeyWrapAlgorithm enumValue) + { + switch (enumValue) + { + case KeyWrapAlgorithm::KMS: + return "kms"; + case KeyWrapAlgorithm::KMS_CONTEXT: + return "kms+context"; + case KeyWrapAlgorithm::AES_KEY_WRAP: + return "AESWrap"; + case KeyWrapAlgorithm::AES_GCM: + return "AES/GCM"; + default: + assert(0); + } + return ""; + } + }//namespace KeyWrapAlgorithmMapper + }//namespace Crypto + }//namespace Utils +}//namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/MD5.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/MD5.cpp new file mode 100644 index 0000000000..bf14ace1ad --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/MD5.cpp @@ -0,0 +1,31 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/crypto/MD5.h> +#include <aws/core/utils/Outcome.h> +#include <aws/core/utils/crypto/Factories.h> + +using namespace Aws::Utils::Crypto; + + +MD5::MD5() : + m_hashImpl(CreateMD5Implementation()) +{ +} + +MD5::~MD5() +{ +} + +HashResult MD5::Calculate(const Aws::String& str) +{ + return m_hashImpl->Calculate(str); +} + +HashResult MD5::Calculate(Aws::IStream& stream) +{ + return m_hashImpl->Calculate(stream); +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Sha256.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Sha256.cpp new file mode 100644 index 0000000000..178df00d37 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Sha256.cpp @@ -0,0 +1,30 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/crypto/Sha256.h> +#include <aws/core/utils/Outcome.h> +#include <aws/core/utils/crypto/Factories.h> + +using namespace Aws::Utils::Crypto; + +Sha256::Sha256() : + m_hashImpl(CreateSha256Implementation()) +{ +} + +Sha256::~Sha256() +{ +} + +HashResult Sha256::Calculate(const Aws::String& str) +{ + return m_hashImpl->Calculate(str); +} + +HashResult Sha256::Calculate(Aws::IStream& stream) +{ + return m_hashImpl->Calculate(stream); +}
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Sha256HMAC.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Sha256HMAC.cpp new file mode 100644 index 0000000000..ecc1f06529 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/Sha256HMAC.cpp @@ -0,0 +1,34 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/crypto/Sha256HMAC.h> +#include <aws/core/utils/crypto/Factories.h> +#include <aws/core/utils/Outcome.h> + +namespace Aws +{ +namespace Utils +{ +namespace Crypto +{ + +Sha256HMAC::Sha256HMAC() : + m_hmacImpl(CreateSha256HMACImplementation()) +{ +} + +Sha256HMAC::~Sha256HMAC() +{ +} + +HashResult Sha256HMAC::Calculate(const Aws::Utils::ByteBuffer& toSign, const Aws::Utils::ByteBuffer& secret) +{ + return m_hmacImpl->Calculate(toSign, secret); +} + +} // namespace Crypto +} // namespace Utils +} // namespace Aws
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/factory/Factories.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/factory/Factories.cpp new file mode 100644 index 0000000000..bff0382241 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/factory/Factories.cpp @@ -0,0 +1,895 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/crypto/Factories.h> +#include <aws/core/utils/crypto/Hash.h> +#include <aws/core/utils/crypto/HMAC.h> + +#if ENABLE_BCRYPT_ENCRYPTION + #error #include <aws/core/utils/crypto/bcrypt/CryptoImpl.h> +#elif ENABLE_OPENSSL_ENCRYPTION + #include <aws/core/utils/crypto/openssl/CryptoImpl.h> +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + #error #include <aws/core/utils/crypto/commoncrypto/CryptoImpl.h> + #include <aws/core/utils/logging/LogMacros.h> +#else + // if you don't have any encryption you still need to pull in the interface definitions + #include <aws/core/utils/crypto/Hash.h> + #include <aws/core/utils/crypto/HMAC.h> + #include <aws/core/utils/crypto/Cipher.h> + #include <aws/core/utils/crypto/SecureRandom.h> + #define NO_ENCRYPTION +#endif + +using namespace Aws::Utils; +using namespace Aws::Utils::Crypto; + +static const char *s_allocationTag = "CryptoFactory"; + +static std::shared_ptr<HashFactory>& GetMD5Factory() +{ + static std::shared_ptr<HashFactory> s_MD5Factory(nullptr); + return s_MD5Factory; +} + +static std::shared_ptr<HashFactory>& GetSha256Factory() +{ + static std::shared_ptr<HashFactory> s_Sha256Factory(nullptr); + return s_Sha256Factory; +} + +static std::shared_ptr<HMACFactory>& GetSha256HMACFactory() +{ + static std::shared_ptr<HMACFactory> s_Sha256HMACFactory(nullptr); + return s_Sha256HMACFactory; +} + +static std::shared_ptr<SymmetricCipherFactory>& GetAES_CBCFactory() +{ + static std::shared_ptr<SymmetricCipherFactory> s_AES_CBCFactory(nullptr); + return s_AES_CBCFactory; +} + +static std::shared_ptr<SymmetricCipherFactory>& GetAES_CTRFactory() +{ + static std::shared_ptr<SymmetricCipherFactory> s_AES_CTRFactory(nullptr); + return s_AES_CTRFactory; +} + +static std::shared_ptr<SymmetricCipherFactory>& GetAES_GCMFactory() +{ + static std::shared_ptr<SymmetricCipherFactory> s_AES_GCMFactory(nullptr); + return s_AES_GCMFactory; +} + +static std::shared_ptr<SymmetricCipherFactory>& GetAES_KeyWrapFactory() +{ + static std::shared_ptr<SymmetricCipherFactory> s_AES_KeyWrapFactory(nullptr); + return s_AES_KeyWrapFactory; +} + +static std::shared_ptr<SecureRandomFactory>& GetSecureRandomFactory() +{ + static std::shared_ptr<SecureRandomFactory> s_SecureRandomFactory(nullptr); + return s_SecureRandomFactory; +} + +static std::shared_ptr<SecureRandomBytes>& GetSecureRandom() +{ + static std::shared_ptr<SecureRandomBytes> s_SecureRandom(nullptr); + return s_SecureRandom; +} + +static bool s_InitCleanupOpenSSLFlag(false); + +class DefaultMD5Factory : public HashFactory +{ +public: + std::shared_ptr<Hash> CreateImplementation() const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<MD5BcryptImpl>(s_allocationTag); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<MD5OpenSSLImpl>(s_allocationTag); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<MD5CommonCryptoImpl>(s_allocationTag); +#else + return nullptr; +#endif + } + + /** + * Opportunity to make any static initialization calls you need to make. + * Will only be called once. + */ + void InitStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.EnterRoom(&OpenSSL::init_static_state); + } +#endif + } + + /** + * Opportunity to make any static cleanup calls you need to make. + * will only be called at the end of the application. + */ + void CleanupStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.LeaveRoom(&OpenSSL::cleanup_static_state); + } +#endif + } +}; + +class DefaultSHA256Factory : public HashFactory +{ +public: + std::shared_ptr<Hash> CreateImplementation() const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<Sha256BcryptImpl>(s_allocationTag); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<Sha256OpenSSLImpl>(s_allocationTag); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<Sha256CommonCryptoImpl>(s_allocationTag); +#else + return nullptr; +#endif + } + + /** + * Opportunity to make any static initialization calls you need to make. + * Will only be called once. + */ + void InitStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.EnterRoom(&OpenSSL::init_static_state); + } +#endif + } + + /** + * Opportunity to make any static cleanup calls you need to make. + * will only be called at the end of the application. + */ + void CleanupStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.LeaveRoom(&OpenSSL::cleanup_static_state); + } +#endif + } +}; + +class DefaultSHA256HmacFactory : public HMACFactory +{ +public: + std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation() const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<Sha256HMACBcryptImpl>(s_allocationTag); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<Sha256HMACOpenSSLImpl>(s_allocationTag); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<Sha256HMACCommonCryptoImpl>(s_allocationTag); +#else + return nullptr; +#endif + } + + /** + * Opportunity to make any static initialization calls you need to make. + * Will only be called once. + */ + void InitStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.EnterRoom(&OpenSSL::init_static_state); + } +#endif + } + + /** + * Opportunity to make any static cleanup calls you need to make. + * will only be called at the end of the application. + */ + void CleanupStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.LeaveRoom(&OpenSSL::cleanup_static_state); + } +#endif + } +}; + + +class DefaultAES_CBCFactory : public SymmetricCipherFactory +{ +public: + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_BCrypt>(s_allocationTag, key); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_OpenSSL>(s_allocationTag, key); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_CommonCrypto>(s_allocationTag, key); +#else + AWS_UNREFERENCED_PARAM(key); + return nullptr; +#endif + } + /** + * Factory method. Returns cipher implementation. See the SymmetricCipher class for more details. + */ + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key, const CryptoBuffer& iv, const CryptoBuffer&, const CryptoBuffer&) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_BCrypt>(s_allocationTag, key, iv); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_OpenSSL>(s_allocationTag, key, iv); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_CommonCrypto>(s_allocationTag, key, iv); +#else + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(iv); + + return nullptr; +#endif + } + /** + * Factory method. Returns cipher implementation. See the SymmetricCipher class for more details. + */ + std::shared_ptr<SymmetricCipher> CreateImplementation(CryptoBuffer&& key, CryptoBuffer&& iv, CryptoBuffer&&, CryptoBuffer&&) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_BCrypt>(s_allocationTag, key, iv); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_OpenSSL>(s_allocationTag, key, iv); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_CBC_Cipher_CommonCrypto>(s_allocationTag, key, iv); +#else + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(iv); + + return nullptr; +#endif + } + + /** + * Opportunity to make any static initialization calls you need to make. + * Will only be called once. + */ + void InitStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.EnterRoom(&OpenSSL::init_static_state); + } +#endif + } + + /** + * Opportunity to make any static cleanup calls you need to make. + * will only be called at the end of the application. + */ + void CleanupStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.LeaveRoom(&OpenSSL::cleanup_static_state); + } +#endif + } +}; + +class DefaultAES_CTRFactory : public SymmetricCipherFactory +{ +public: + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_BCrypt>(s_allocationTag, key); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_OpenSSL>(s_allocationTag, key); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_CommonCrypto>(s_allocationTag, key); +#else + AWS_UNREFERENCED_PARAM(key); + return nullptr; +#endif + } + /** + * Factory method. Returns cipher implementation. See the SymmetricCipher class for more details. + */ + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key, const CryptoBuffer& iv, const CryptoBuffer&, const CryptoBuffer&) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_BCrypt>(s_allocationTag, key, iv); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_OpenSSL>(s_allocationTag, key, iv); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_CommonCrypto>(s_allocationTag, key, iv); +#else + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(iv); + + return nullptr; +#endif + } + /** + * Factory method. Returns cipher implementation. See the SymmetricCipher class for more details. + */ + std::shared_ptr<SymmetricCipher> CreateImplementation(CryptoBuffer&& key, CryptoBuffer&& iv, CryptoBuffer&&, CryptoBuffer&&) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_BCrypt>(s_allocationTag, key, iv); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_OpenSSL>(s_allocationTag, key, iv); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_CTR_Cipher_CommonCrypto>(s_allocationTag, key, iv); +#else + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(iv); + + return nullptr; +#endif + } + + /** + * Opportunity to make any static initialization calls you need to make. + * Will only be called once. + */ + void InitStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.EnterRoom(&OpenSSL::init_static_state); + } +#endif + } + + /** + * Opportunity to make any static cleanup calls you need to make. + * will only be called at the end of the application. + */ + void CleanupStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.LeaveRoom(&OpenSSL::cleanup_static_state); + } +#endif + } +}; + +class DefaultAES_GCMFactory : public SymmetricCipherFactory +{ +public: + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_BCrypt>(s_allocationTag, key); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_OpenSSL>(s_allocationTag, key); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_CommonCrypto>(s_allocationTag, key); +#else + AWS_UNREFERENCED_PARAM(key); + + return nullptr; +#endif + } + + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key, const CryptoBuffer* aad) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_BCrypt>(s_allocationTag, key, aad); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_OpenSSL>(s_allocationTag, key, aad); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_CommonCrypto>(s_allocationTag, key, aad); +#else + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(aad); + return nullptr; +#endif + } + + /** + * Factory method. Returns cipher implementation. See the SymmetricCipher class for more details. + */ + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key, const CryptoBuffer& iv, const CryptoBuffer& tag, const CryptoBuffer& aad) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_BCrypt>(s_allocationTag, key, iv, tag, aad); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_OpenSSL>(s_allocationTag, key, iv, tag, aad); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_CommonCrypto>(s_allocationTag, key, iv, tag, aad); +#else + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(iv); + AWS_UNREFERENCED_PARAM(tag); + AWS_UNREFERENCED_PARAM(aad); + return nullptr; +#endif + } + /** + * Factory method. Returns cipher implementation. See the SymmetricCipher class for more details. + */ + std::shared_ptr<SymmetricCipher> CreateImplementation(CryptoBuffer&& key, CryptoBuffer&& iv, CryptoBuffer&& tag, CryptoBuffer&& aad) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_BCrypt>(s_allocationTag, std::move(key), std::move(iv), std::move(tag), std::move(aad)); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_OpenSSL>(s_allocationTag, std::move(key), std::move(iv), std::move(tag), std::move(aad)); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_GCM_Cipher_CommonCrypto>(s_allocationTag, std::move(key), std::move(iv), std::move(tag), std::move(aad)); +#else + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(iv); + AWS_UNREFERENCED_PARAM(tag); + AWS_UNREFERENCED_PARAM(aad); + return nullptr; +#endif + } + + /** + * Opportunity to make any static initialization calls you need to make. + * Will only be called once. + */ + void InitStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.EnterRoom(&OpenSSL::init_static_state); + } +#endif + } + + /** + * Opportunity to make any static cleanup calls you need to make. + * will only be called at the end of the application. + */ + void CleanupStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.LeaveRoom(&OpenSSL::cleanup_static_state); + } +#endif + } +}; + +class DefaultAES_KeyWrapFactory : public SymmetricCipherFactory +{ +public: + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key) const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<AES_KeyWrap_Cipher_BCrypt>(s_allocationTag, key); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<AES_KeyWrap_Cipher_OpenSSL>(s_allocationTag, key); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<AES_KeyWrap_Cipher_CommonCrypto>(s_allocationTag, key); +#else + AWS_UNREFERENCED_PARAM(key); + return nullptr; +#endif + } + /** + * Factory method. Returns cipher implementation. See the SymmetricCipher class for more details. + */ + std::shared_ptr<SymmetricCipher> CreateImplementation(const CryptoBuffer& key, const CryptoBuffer& iv, const CryptoBuffer& tag, const CryptoBuffer&) const override + { + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(iv); + AWS_UNREFERENCED_PARAM(tag); + return nullptr; + } + /** + * Factory method. Returns cipher implementation. See the SymmetricCipher class for more details. + */ + std::shared_ptr<SymmetricCipher> CreateImplementation(CryptoBuffer&& key, CryptoBuffer&& iv, CryptoBuffer&& tag, CryptoBuffer&&) const override + { + AWS_UNREFERENCED_PARAM(key); + AWS_UNREFERENCED_PARAM(iv); + AWS_UNREFERENCED_PARAM(tag); + return nullptr; + } + + /** + * Opportunity to make any static initialization calls you need to make. + * Will only be called once. + */ + void InitStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if (s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.EnterRoom(&OpenSSL::init_static_state); + } +#endif + } + + /** + * Opportunity to make any static cleanup calls you need to make. + * will only be called at the end of the application. + */ + void CleanupStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if (s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.LeaveRoom(&OpenSSL::cleanup_static_state); + } +#endif + } +}; + +class DefaultSecureRandFactory : public SecureRandomFactory +{ + /** + * Factory method. Returns SecureRandom implementation. + */ + std::shared_ptr<SecureRandomBytes> CreateImplementation() const override + { +#if ENABLE_BCRYPT_ENCRYPTION + return Aws::MakeShared<SecureRandomBytes_BCrypt>(s_allocationTag); +#elif ENABLE_OPENSSL_ENCRYPTION + return Aws::MakeShared<SecureRandomBytes_OpenSSLImpl>(s_allocationTag); +#elif ENABLE_COMMONCRYPTO_ENCRYPTION + return Aws::MakeShared<SecureRandomBytes_CommonCrypto>(s_allocationTag); +#else + return nullptr; +#endif + } + + /** + * Opportunity to make any static initialization calls you need to make. + * Will only be called once. + */ + void InitStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.EnterRoom(&OpenSSL::init_static_state); + } +#endif + } + + /** + * Opportunity to make any static cleanup calls you need to make. + * will only be called at the end of the application. + */ + void CleanupStaticState() override + { +#if ENABLE_OPENSSL_ENCRYPTION + if(s_InitCleanupOpenSSLFlag) + { + OpenSSL::getTheLights.LeaveRoom(&OpenSSL::cleanup_static_state); + } +#endif + } +}; + +void Aws::Utils::Crypto::SetInitCleanupOpenSSLFlag(bool initCleanupFlag) +{ + s_InitCleanupOpenSSLFlag = initCleanupFlag; +} + +void Aws::Utils::Crypto::InitCrypto() +{ + if(GetMD5Factory()) + { + GetMD5Factory()->InitStaticState(); + } + else + { + GetMD5Factory() = Aws::MakeShared<DefaultMD5Factory>(s_allocationTag); + GetMD5Factory()->InitStaticState(); + } + + if(GetSha256Factory()) + { + GetSha256Factory()->InitStaticState(); + } + else + { + GetSha256Factory() = Aws::MakeShared<DefaultSHA256Factory>(s_allocationTag); + GetSha256Factory()->InitStaticState(); + } + + if(GetSha256HMACFactory()) + { + GetSha256HMACFactory()->InitStaticState(); + } + else + { + GetSha256HMACFactory() = Aws::MakeShared<DefaultSHA256HmacFactory>(s_allocationTag); + GetSha256HMACFactory()->InitStaticState(); + } + + if(GetAES_CBCFactory()) + { + GetAES_CBCFactory()->InitStaticState(); + } + else + { + GetAES_CBCFactory() = Aws::MakeShared<DefaultAES_CBCFactory>(s_allocationTag); + GetAES_CBCFactory()->InitStaticState(); + } + + if(GetAES_CTRFactory()) + { + GetAES_CTRFactory()->InitStaticState(); + } + else + { + GetAES_CTRFactory() = Aws::MakeShared<DefaultAES_CTRFactory>(s_allocationTag); + GetAES_CTRFactory()->InitStaticState(); + } + + if(GetAES_GCMFactory()) + { + GetAES_GCMFactory()->InitStaticState(); + } + else + { + GetAES_GCMFactory() = Aws::MakeShared<DefaultAES_GCMFactory>(s_allocationTag); + GetAES_GCMFactory()->InitStaticState(); + } + + if (!GetAES_KeyWrapFactory()) + { + GetAES_KeyWrapFactory() = Aws::MakeShared<DefaultAES_KeyWrapFactory>(s_allocationTag); + } + GetAES_KeyWrapFactory()->InitStaticState(); + + if(GetSecureRandomFactory()) + { + GetSecureRandomFactory()->InitStaticState(); + } + else + { + GetSecureRandomFactory() = Aws::MakeShared<DefaultSecureRandFactory>(s_allocationTag); + GetSecureRandomFactory()->InitStaticState(); + } + + GetSecureRandom() = GetSecureRandomFactory()->CreateImplementation(); +} + +void Aws::Utils::Crypto::CleanupCrypto() +{ + if(GetMD5Factory()) + { + GetMD5Factory()->CleanupStaticState(); + GetMD5Factory() = nullptr; + } + + if(GetSha256Factory()) + { + GetSha256Factory()->CleanupStaticState(); + GetSha256Factory() = nullptr; + } + + if(GetSha256HMACFactory()) + { + GetSha256HMACFactory()->CleanupStaticState(); + GetSha256HMACFactory() = nullptr; + } + + if(GetAES_CBCFactory()) + { + GetAES_CBCFactory()->CleanupStaticState(); + GetAES_CBCFactory() = nullptr; + } + + if(GetAES_CTRFactory()) + { + GetAES_CTRFactory()->CleanupStaticState(); + GetAES_CTRFactory() = nullptr; + } + + if(GetAES_GCMFactory()) + { + GetAES_GCMFactory()->CleanupStaticState(); + GetAES_GCMFactory() = nullptr; + } + + if(GetAES_KeyWrapFactory()) + { + GetAES_KeyWrapFactory()->CleanupStaticState(); + GetAES_KeyWrapFactory() = nullptr; + } + + if(GetSecureRandomFactory()) + { + GetSecureRandom() = nullptr; + GetSecureRandomFactory()->CleanupStaticState(); + GetSecureRandomFactory() = nullptr; + } +} + +void Aws::Utils::Crypto::SetMD5Factory(const std::shared_ptr<HashFactory>& factory) +{ + GetMD5Factory() = factory; +} + +void Aws::Utils::Crypto::SetSha256Factory(const std::shared_ptr<HashFactory>& factory) +{ + GetSha256Factory() = factory; +} + +void Aws::Utils::Crypto::SetSha256HMACFactory(const std::shared_ptr<HMACFactory>& factory) +{ + GetSha256HMACFactory() = factory; +} + +void Aws::Utils::Crypto::SetAES_CBCFactory(const std::shared_ptr<SymmetricCipherFactory>& factory) +{ + GetAES_CBCFactory() = factory; +} + +void Aws::Utils::Crypto::SetAES_CTRFactory(const std::shared_ptr<SymmetricCipherFactory>& factory) +{ + GetAES_CTRFactory() = factory; +} + +void Aws::Utils::Crypto::SetAES_GCMFactory(const std::shared_ptr<SymmetricCipherFactory>& factory) +{ + GetAES_GCMFactory() = factory; +} + +void Aws::Utils::Crypto::SetAES_KeyWrapFactory(const std::shared_ptr<SymmetricCipherFactory>& factory) +{ + GetAES_KeyWrapFactory() = factory; +} + +void Aws::Utils::Crypto::SetSecureRandomFactory(const std::shared_ptr<SecureRandomFactory>& factory) +{ + GetSecureRandomFactory() = factory; +} + +std::shared_ptr<Hash> Aws::Utils::Crypto::CreateMD5Implementation() +{ + return GetMD5Factory()->CreateImplementation(); +} + +std::shared_ptr<Hash> Aws::Utils::Crypto::CreateSha256Implementation() +{ + return GetSha256Factory()->CreateImplementation(); +} + +std::shared_ptr<Aws::Utils::Crypto::HMAC> Aws::Utils::Crypto::CreateSha256HMACImplementation() +{ + return GetSha256HMACFactory()->CreateImplementation(); +} + +#ifdef _WIN32 +#pragma warning( push ) +#pragma warning( disable : 4702 ) +#endif + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_CBCImplementation(const CryptoBuffer& key) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_CBCFactory()->CreateImplementation(key); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_CBCImplementation(const CryptoBuffer& key, const CryptoBuffer& iv) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_CBCFactory()->CreateImplementation(key, iv); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_CBCImplementation(CryptoBuffer&& key, CryptoBuffer&& iv) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_CBCFactory()->CreateImplementation(std::move(key), std::move(iv)); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_CTRImplementation(const CryptoBuffer& key) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_CTRFactory()->CreateImplementation(key); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_CTRImplementation(const CryptoBuffer& key, const CryptoBuffer& iv) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_CTRFactory()->CreateImplementation(key, iv); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_CTRImplementation(CryptoBuffer&& key, CryptoBuffer&& iv) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_CTRFactory()->CreateImplementation(std::move(key), std::move(iv)); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_GCMImplementation(const CryptoBuffer& key) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_GCMFactory()->CreateImplementation(key); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_GCMImplementation(const CryptoBuffer& key, const CryptoBuffer* aad) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_GCMFactory()->CreateImplementation(key, aad); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_GCMImplementation(const CryptoBuffer& key, const CryptoBuffer& iv, const CryptoBuffer& tag, const CryptoBuffer& aad) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_GCMFactory()->CreateImplementation(key, iv, tag, aad); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_GCMImplementation(CryptoBuffer&& key, CryptoBuffer&& iv, CryptoBuffer&& tag, CryptoBuffer&& aad) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_GCMFactory()->CreateImplementation(std::move(key), std::move(iv), std::move(tag), std::move(aad)); +} + +std::shared_ptr<SymmetricCipher> Aws::Utils::Crypto::CreateAES_KeyWrapImplementation(const CryptoBuffer& key) +{ +#ifdef NO_SYMMETRIC_ENCRYPTION + return nullptr; +#endif + return GetAES_KeyWrapFactory()->CreateImplementation(key); +} + +#ifdef _WIN32 +#pragma warning(pop) +#endif + +std::shared_ptr<SecureRandomBytes> Aws::Utils::Crypto::CreateSecureRandomBytesImplementation() +{ + return GetSecureRandom(); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/openssl/CryptoImpl.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/openssl/CryptoImpl.cpp new file mode 100644 index 0000000000..911838864b --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/crypto/openssl/CryptoImpl.cpp @@ -0,0 +1,987 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <cstring> + +#include <aws/core/utils/memory/AWSMemory.h> +#include <aws/core/utils/crypto/openssl/CryptoImpl.h> +#include <aws/core/utils/Outcome.h> +#include <openssl/md5.h> + +#ifdef OPENSSL_IS_BORINGSSL +#ifdef _MSC_VER +AWS_SUPPRESS_WARNING_PUSH(4201) +#else +AWS_SUPPRESS_WARNING_PUSH("-Wpedantic") +#endif +#endif + +#include <openssl/sha.h> + +#ifdef OPENSSL_IS_BORINGSSL +AWS_SUPPRESS_WARNING_POP +#endif + +#include <openssl/err.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <thread> + +using namespace Aws::Utils; +using namespace Aws::Utils::Crypto; + +namespace Aws +{ + namespace Utils + { + namespace Crypto + { + namespace OpenSSL + { +/** + * openssl with OPENSSL_VERSION_NUMBER < 0x10100003L made data type details unavailable + * libressl use openssl with data type details available, but mandatorily set + * OPENSSL_VERSION_NUMBER = 0x20000000L, insane! + * https://github.com/aws/aws-sdk-cpp/pull/507/commits/2c99f1fe0c4b4683280caeb161538d4724d6a179 + */ +#if defined(LIBRESSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER == 0x20000000L) +#undef OPENSSL_VERSION_NUMBER +#define OPENSSL_VERSION_NUMBER 0x1000107fL +#endif +#define OPENSSL_VERSION_LESS_1_1 (OPENSSL_VERSION_NUMBER < 0x10100003L) + +#if OPENSSL_VERSION_LESS_1_1 + static const char* OPENSSL_INTERNALS_TAG = "OpenSSLCallbackState"; + static std::mutex* locks(nullptr); +#endif + + GetTheLights getTheLights; + + void init_static_state() + { +#if OPENSSL_VERSION_LESS_1_1 || defined(OPENSSL_IS_BORINGSSL) + ERR_load_crypto_strings(); +#else + OPENSSL_init_crypto(OPENSSL_INIT_LOAD_CRYPTO_STRINGS /*options*/ ,NULL /* OpenSSL init settings*/ ); +#endif +#if !defined(OPENSSL_IS_BORINGSSL) + OPENSSL_add_all_algorithms_noconf(); +#endif +#if OPENSSL_VERSION_LESS_1_1 + if (!CRYPTO_get_locking_callback()) + { + locks = Aws::NewArray<std::mutex>(static_cast<size_t>(CRYPTO_num_locks()), + OPENSSL_INTERNALS_TAG); + CRYPTO_set_locking_callback(&locking_fn); + } + + if (!CRYPTO_get_id_callback()) + { + CRYPTO_set_id_callback(&id_fn); + } +#endif + RAND_poll(); + } + + void cleanup_static_state() + { +#if OPENSSL_VERSION_LESS_1_1 + if (CRYPTO_get_locking_callback() == &locking_fn) + { + CRYPTO_set_locking_callback(nullptr); + assert(locks); + Aws::DeleteArray(locks); + locks = nullptr; + } + + if (CRYPTO_get_id_callback() == &id_fn) + { + CRYPTO_set_id_callback(nullptr); + } +#endif + } + +#if OPENSSL_VERSION_LESS_1_1 + void locking_fn(int mode, int n, const char*, int) + { + if (mode & CRYPTO_LOCK) + { + locks[n].lock(); + } + else + { + locks[n].unlock(); + } + } + + unsigned long id_fn() + { + return static_cast<unsigned long>(std::hash<std::thread::id>()(std::this_thread::get_id())); + } +#endif + } + + static const char* OPENSSL_LOG_TAG = "OpenSSLCipher"; + + void SecureRandomBytes_OpenSSLImpl::GetBytes(unsigned char* buffer, size_t bufferSize) + { + if (!bufferSize) + { + return; + } + + if (!buffer) + { + AWS_LOGSTREAM_FATAL(OPENSSL_LOG_TAG, "Secure Random Bytes generator can't generate: " << bufferSize << " bytes with nullptr buffer."); + assert(buffer); + return; + } + + int success = RAND_bytes(buffer, static_cast<int>(bufferSize)); + if (success != 1) + { + m_failure = true; + } + } + + class OpensslCtxRAIIGuard + { + public: + OpensslCtxRAIIGuard() + { + m_ctx = EVP_MD_CTX_create(); + assert(m_ctx != nullptr); + } + + ~OpensslCtxRAIIGuard() + { + EVP_MD_CTX_destroy(m_ctx); + m_ctx = nullptr; + } + + EVP_MD_CTX* getResource() + { + return m_ctx; + } + private: + EVP_MD_CTX *m_ctx; + }; + + HashResult MD5OpenSSLImpl::Calculate(const Aws::String& str) + { + OpensslCtxRAIIGuard guard; + auto ctx = guard.getResource(); +#if !defined(OPENSSL_IS_BORINGSSL) + EVP_MD_CTX_set_flags(ctx, EVP_MD_CTX_FLAG_NON_FIPS_ALLOW); +#endif + EVP_DigestInit_ex(ctx, EVP_md5(), nullptr); + EVP_DigestUpdate(ctx, str.c_str(), str.size()); + + ByteBuffer hash(EVP_MD_size(EVP_md5())); + EVP_DigestFinal(ctx, hash.GetUnderlyingData(), nullptr); + + return HashResult(std::move(hash)); + } + + HashResult MD5OpenSSLImpl::Calculate(Aws::IStream& stream) + { + OpensslCtxRAIIGuard guard; + auto ctx = guard.getResource(); +#if !defined(OPENSSL_IS_BORINGSSL) + EVP_MD_CTX_set_flags(ctx, EVP_MD_CTX_FLAG_NON_FIPS_ALLOW); +#endif + EVP_DigestInit_ex(ctx, EVP_md5(), nullptr); + + auto currentPos = stream.tellg(); + if (currentPos == -1) + { + currentPos = 0; + stream.clear(); + } + stream.seekg(0, stream.beg); + + char streamBuffer[Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE]; + while (stream.good()) + { + stream.read(streamBuffer, Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE); + auto bytesRead = stream.gcount(); + + if (bytesRead > 0) + { + EVP_DigestUpdate(ctx, streamBuffer, static_cast<size_t>(bytesRead)); + } + } + + stream.clear(); + stream.seekg(currentPos, stream.beg); + + ByteBuffer hash(EVP_MD_size(EVP_md5())); + EVP_DigestFinal(ctx, hash.GetUnderlyingData(), nullptr); + + return HashResult(std::move(hash)); + } + + HashResult Sha256OpenSSLImpl::Calculate(const Aws::String& str) + { + OpensslCtxRAIIGuard guard; + auto ctx = guard.getResource(); + EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr); + EVP_DigestUpdate(ctx, str.c_str(), str.size()); + + ByteBuffer hash(EVP_MD_size(EVP_sha256())); + EVP_DigestFinal(ctx, hash.GetUnderlyingData(), nullptr); + + return HashResult(std::move(hash)); + } + + HashResult Sha256OpenSSLImpl::Calculate(Aws::IStream& stream) + { + OpensslCtxRAIIGuard guard; + auto ctx = guard.getResource(); + + EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr); + + auto currentPos = stream.tellg(); + if (currentPos == -1) + { + currentPos = 0; + stream.clear(); + } + + stream.seekg(0, stream.beg); + + char streamBuffer[Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE]; + while (stream.good()) + { + stream.read(streamBuffer, Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE); + auto bytesRead = stream.gcount(); + + if (bytesRead > 0) + { + EVP_DigestUpdate(ctx, streamBuffer, static_cast<size_t>(bytesRead)); + } + } + + stream.clear(); + stream.seekg(currentPos, stream.beg); + + ByteBuffer hash(EVP_MD_size(EVP_sha256())); + EVP_DigestFinal(ctx, hash.GetUnderlyingData(), nullptr); + + return HashResult(std::move(hash)); + } + + class HMACRAIIGuard { + public: + HMACRAIIGuard() { +#if OPENSSL_VERSION_LESS_1_1 + m_ctx = Aws::New<HMAC_CTX>("AllocSha256HAMCOpenSSLContext"); +#else + m_ctx = HMAC_CTX_new(); +#endif + assert(m_ctx != nullptr); + } + + ~HMACRAIIGuard() { +#if OPENSSL_VERSION_LESS_1_1 + Aws::Delete<HMAC_CTX>(m_ctx); +#else + HMAC_CTX_free(m_ctx); +#endif + m_ctx = nullptr; + } + + HMAC_CTX* getResource() { + return m_ctx; + } + private: + HMAC_CTX *m_ctx; + }; + + HashResult Sha256HMACOpenSSLImpl::Calculate(const ByteBuffer& toSign, const ByteBuffer& secret) + { + unsigned int length = SHA256_DIGEST_LENGTH; + ByteBuffer digest(length); + memset(digest.GetUnderlyingData(), 0, length); + + HMACRAIIGuard guard; + HMAC_CTX* m_ctx = guard.getResource(); + +#if OPENSSL_VERSION_LESS_1_1 + HMAC_CTX_init(m_ctx); +#endif + + HMAC_Init_ex(m_ctx, secret.GetUnderlyingData(), static_cast<int>(secret.GetLength()), EVP_sha256(), + NULL); + HMAC_Update(m_ctx, toSign.GetUnderlyingData(), toSign.GetLength()); + HMAC_Final(m_ctx, digest.GetUnderlyingData(), &length); + +#if OPENSSL_VERSION_LESS_1_1 + HMAC_CTX_cleanup(m_ctx); +#else + HMAC_CTX_reset(m_ctx); +#endif + return HashResult(std::move(digest)); + } + + void LogErrors(const char* logTag = OPENSSL_LOG_TAG) + { + unsigned long errorCode = ERR_get_error(); + char errStr[256]; + ERR_error_string_n(errorCode, errStr, 256); + + AWS_LOGSTREAM_ERROR(logTag, errStr); + } + + OpenSSLCipher::OpenSSLCipher(const CryptoBuffer& key, size_t blockSizeBytes, bool ctrMode) : + SymmetricCipher(key, blockSizeBytes, ctrMode), m_encryptor_ctx(nullptr), m_decryptor_ctx(nullptr) + { + Init(); + } + + OpenSSLCipher::OpenSSLCipher(OpenSSLCipher&& toMove) : SymmetricCipher(std::move(toMove)), + m_encryptor_ctx(nullptr), m_decryptor_ctx(nullptr) + { + Init(); + EVP_CIPHER_CTX_copy(m_encryptor_ctx, toMove.m_encryptor_ctx); + EVP_CIPHER_CTX_copy(m_decryptor_ctx, toMove.m_decryptor_ctx); + EVP_CIPHER_CTX_cleanup(toMove.m_encryptor_ctx); + EVP_CIPHER_CTX_cleanup(toMove.m_decryptor_ctx); + } + + OpenSSLCipher::OpenSSLCipher(CryptoBuffer&& key, CryptoBuffer&& initializationVector, CryptoBuffer&& tag) : + SymmetricCipher(std::move(key), std::move(initializationVector), std::move(tag)), + m_encryptor_ctx(nullptr), m_decryptor_ctx(nullptr) + { + Init(); + } + + OpenSSLCipher::OpenSSLCipher(const CryptoBuffer& key, const CryptoBuffer& initializationVector, + const CryptoBuffer& tag) : + SymmetricCipher(key, initializationVector, tag), m_encryptor_ctx(nullptr), m_decryptor_ctx(nullptr) + { + Init(); + } + + OpenSSLCipher::~OpenSSLCipher() + { + Cleanup(); + if (m_encryptor_ctx) + { + EVP_CIPHER_CTX_free(m_encryptor_ctx); + m_encryptor_ctx = nullptr; + } + if (m_decryptor_ctx) + { + EVP_CIPHER_CTX_free(m_decryptor_ctx); + m_decryptor_ctx = nullptr; + } + } + + void OpenSSLCipher::Init() + { + if (m_failure) + { + return; + } + + if (!m_encryptor_ctx) + { + // EVP_CIPHER_CTX_init() will be called inside EVP_CIPHER_CTX_new(). + m_encryptor_ctx = EVP_CIPHER_CTX_new(); + assert(m_encryptor_ctx != nullptr); + } + else + { // _init is the same as _reset after openssl 1.1 + EVP_CIPHER_CTX_init(m_encryptor_ctx); + } + if (!m_decryptor_ctx) + { + // EVP_CIPHER_CTX_init() will be called inside EVP_CIPHER_CTX_new(). + m_decryptor_ctx = EVP_CIPHER_CTX_new(); + assert(m_decryptor_ctx != nullptr); + } + else + { // _init is the same as _reset after openssl 1.1 + EVP_CIPHER_CTX_init(m_decryptor_ctx); + } + m_emptyPlaintext = false; + } + + CryptoBuffer OpenSSLCipher::EncryptBuffer(const CryptoBuffer& unEncryptedData) + { + if (m_failure) + { + AWS_LOGSTREAM_FATAL(OPENSSL_LOG_TAG, "Cipher not properly initialized for encryption. Aborting"); + return CryptoBuffer(); + } + + int lengthWritten = static_cast<int>(unEncryptedData.GetLength() + (GetBlockSizeBytes() - 1)); + CryptoBuffer encryptedText(static_cast<size_t>( lengthWritten + (GetBlockSizeBytes() - 1))); + + if (!EVP_EncryptUpdate(m_encryptor_ctx, encryptedText.GetUnderlyingData(), &lengthWritten, + unEncryptedData.GetUnderlyingData(), + static_cast<int>(unEncryptedData.GetLength()))) + { + m_failure = true; + LogErrors(); + return CryptoBuffer(); + } + + if (static_cast<size_t>(lengthWritten) < encryptedText.GetLength()) + { + return CryptoBuffer(encryptedText.GetUnderlyingData(), static_cast<size_t>(lengthWritten)); + } + return encryptedText; + } + + CryptoBuffer OpenSSLCipher::FinalizeEncryption() + { + if (m_failure) + { + AWS_LOGSTREAM_FATAL(OPENSSL_LOG_TAG, "Cipher not properly initialized for encryption finalization. Aborting"); + return CryptoBuffer(); + } + + CryptoBuffer finalBlock(GetBlockSizeBytes()); + int writtenSize = 0; + if (!EVP_EncryptFinal_ex(m_encryptor_ctx, finalBlock.GetUnderlyingData(), &writtenSize)) + { + m_failure = true; + LogErrors(); + return CryptoBuffer(); + } + return CryptoBuffer(finalBlock.GetUnderlyingData(), static_cast<size_t>(writtenSize)); + } + + CryptoBuffer OpenSSLCipher::DecryptBuffer(const CryptoBuffer& encryptedData) + { + if (m_failure) + { + AWS_LOGSTREAM_FATAL(OPENSSL_LOG_TAG, "Cipher not properly initialized for decryption. Aborting"); + return CryptoBuffer(); + } + + int lengthWritten = static_cast<int>(encryptedData.GetLength() + (GetBlockSizeBytes() - 1)); + CryptoBuffer decryptedText(static_cast<size_t>(lengthWritten)); + + if (!EVP_DecryptUpdate(m_decryptor_ctx, decryptedText.GetUnderlyingData(), &lengthWritten, + encryptedData.GetUnderlyingData(), + static_cast<int>(encryptedData.GetLength()))) + { + m_failure = true; + LogErrors(); + return CryptoBuffer(); + } + + if (lengthWritten == 0) + { + m_emptyPlaintext = true; + } + if (static_cast<size_t>(lengthWritten) < decryptedText.GetLength()) + { + return CryptoBuffer(decryptedText.GetUnderlyingData(), static_cast<size_t>(lengthWritten)); + } + return decryptedText; + } + + CryptoBuffer OpenSSLCipher::FinalizeDecryption() + { + if (m_failure) + { + AWS_LOGSTREAM_FATAL(OPENSSL_LOG_TAG, "Cipher not properly initialized for decryption finalization. Aborting"); + return CryptoBuffer(); + } + + CryptoBuffer finalBlock(GetBlockSizeBytes()); + int writtenSize = static_cast<int>(finalBlock.GetLength()); + int ret = EVP_DecryptFinal_ex(m_decryptor_ctx, finalBlock.GetUnderlyingData(), &writtenSize); +#if OPENSSL_VERSION_NUMBER > 0x1010104fL //1.1.1d + if (ret <= 0) +#else + if (ret <= 0 && !m_emptyPlaintext) // see details why making exception for empty string at: https://github.com/aws/aws-sdk-cpp/issues/1413 +#endif + { + m_failure = true; + LogErrors(); + return CryptoBuffer(); + } + return CryptoBuffer(finalBlock.GetUnderlyingData(), static_cast<size_t>(writtenSize)); + } + + void OpenSSLCipher::Reset() + { + Cleanup(); + Init(); + } + + void OpenSSLCipher::Cleanup() + { + m_failure = false; + if (m_encryptor_ctx) EVP_CIPHER_CTX_cleanup(m_encryptor_ctx); + if (m_decryptor_ctx) EVP_CIPHER_CTX_cleanup(m_decryptor_ctx); + } + + bool OpenSSLCipher::CheckKeyAndIVLength(size_t expectedKeyLength, size_t expectedIVLength) + { + if (!m_failure && ((m_key.GetLength() != expectedKeyLength) || m_initializationVector.GetLength() != expectedIVLength)) + { + AWS_LOGSTREAM_ERROR(OPENSSL_LOG_TAG, "Expected Key size is: " << expectedKeyLength << " and expected IV size is: " << expectedIVLength); + m_failure = true; + } + return !m_failure; + } + + size_t AES_CBC_Cipher_OpenSSL::BlockSizeBytes = 16; + size_t AES_CBC_Cipher_OpenSSL::KeyLengthBits = 256; + static const char* CBC_LOG_TAG = "AES_CBC_Cipher_OpenSSL"; + + AES_CBC_Cipher_OpenSSL::AES_CBC_Cipher_OpenSSL(const CryptoBuffer& key) : OpenSSLCipher(key, BlockSizeBytes) + { + InitCipher(); + } + + AES_CBC_Cipher_OpenSSL::AES_CBC_Cipher_OpenSSL(CryptoBuffer&& key, CryptoBuffer&& initializationVector) : + OpenSSLCipher(std::move(key), std::move(initializationVector)) + { + InitCipher(); + } + + AES_CBC_Cipher_OpenSSL::AES_CBC_Cipher_OpenSSL(const CryptoBuffer& key, + const CryptoBuffer& initializationVector) : + OpenSSLCipher(key, initializationVector) + { + InitCipher(); + } + + void AES_CBC_Cipher_OpenSSL::InitCipher() + { + if (m_failure || !CheckKeyAndIVLength(KeyLengthBits/8, BlockSizeBytes)) + { + return; + } + + if (!EVP_EncryptInit_ex(m_encryptor_ctx, EVP_aes_256_cbc(), nullptr, m_key.GetUnderlyingData(), + m_initializationVector.GetUnderlyingData()) || + !EVP_DecryptInit_ex(m_decryptor_ctx, EVP_aes_256_cbc(), nullptr, m_key.GetUnderlyingData(), + m_initializationVector.GetUnderlyingData())) + { + m_failure = true; + LogErrors(CBC_LOG_TAG); + } + } + + size_t AES_CBC_Cipher_OpenSSL::GetBlockSizeBytes() const + { + return BlockSizeBytes; + } + + size_t AES_CBC_Cipher_OpenSSL::GetKeyLengthBits() const + { + return KeyLengthBits; + } + + void AES_CBC_Cipher_OpenSSL::Reset() + { + OpenSSLCipher::Reset(); + InitCipher(); + } + + size_t AES_CTR_Cipher_OpenSSL::BlockSizeBytes = 16; + size_t AES_CTR_Cipher_OpenSSL::KeyLengthBits = 256; + static const char* CTR_LOG_TAG = "AES_CTR_Cipher_OpenSSL"; + + AES_CTR_Cipher_OpenSSL::AES_CTR_Cipher_OpenSSL(const CryptoBuffer& key) : OpenSSLCipher(key, BlockSizeBytes, + true) + { + InitCipher(); + } + + AES_CTR_Cipher_OpenSSL::AES_CTR_Cipher_OpenSSL(CryptoBuffer&& key, CryptoBuffer&& initializationVector) : + OpenSSLCipher(std::move(key), std::move(initializationVector)) + { + InitCipher(); + } + + AES_CTR_Cipher_OpenSSL::AES_CTR_Cipher_OpenSSL(const CryptoBuffer& key, + const CryptoBuffer& initializationVector) : + OpenSSLCipher(key, initializationVector) + { + InitCipher(); + } + + void AES_CTR_Cipher_OpenSSL::InitCipher() + { + if (m_failure || !CheckKeyAndIVLength(KeyLengthBits/8, BlockSizeBytes)) + { + return; + } + + if (!(EVP_EncryptInit_ex(m_encryptor_ctx, EVP_aes_256_ctr(), nullptr, m_key.GetUnderlyingData(), + m_initializationVector.GetUnderlyingData()) + && EVP_CIPHER_CTX_set_padding(m_encryptor_ctx, 0)) || + !(EVP_DecryptInit_ex(m_decryptor_ctx, EVP_aes_256_ctr(), nullptr, m_key.GetUnderlyingData(), + m_initializationVector.GetUnderlyingData()) + && EVP_CIPHER_CTX_set_padding(m_decryptor_ctx, 0))) + { + m_failure = true; + LogErrors(CTR_LOG_TAG); + } + } + + size_t AES_CTR_Cipher_OpenSSL::GetBlockSizeBytes() const + { + return BlockSizeBytes; + } + + size_t AES_CTR_Cipher_OpenSSL::GetKeyLengthBits() const + { + return KeyLengthBits; + } + + void AES_CTR_Cipher_OpenSSL::Reset() + { + OpenSSLCipher::Reset(); + InitCipher(); + } + + size_t AES_GCM_Cipher_OpenSSL::BlockSizeBytes = 16; + size_t AES_GCM_Cipher_OpenSSL::KeyLengthBits = 256; + size_t AES_GCM_Cipher_OpenSSL::IVLengthBytes = 12; + size_t AES_GCM_Cipher_OpenSSL::TagLengthBytes = 16; + + static const char* GCM_LOG_TAG = "AES_GCM_Cipher_OpenSSL"; + + AES_GCM_Cipher_OpenSSL::AES_GCM_Cipher_OpenSSL(const CryptoBuffer& key) + : OpenSSLCipher(key, IVLengthBytes) + { + InitCipher(); + } + + AES_GCM_Cipher_OpenSSL::AES_GCM_Cipher_OpenSSL(const CryptoBuffer& key, const CryptoBuffer* aad) + : OpenSSLCipher(key, IVLengthBytes), m_aad(*aad) + { + InitCipher(); + } + + AES_GCM_Cipher_OpenSSL::AES_GCM_Cipher_OpenSSL(CryptoBuffer&& key, CryptoBuffer&& initializationVector, + CryptoBuffer&& tag, CryptoBuffer&& aad) : + OpenSSLCipher(std::move(key), std::move(initializationVector), std::move(tag)), m_aad(std::move(aad)) + { + InitCipher(); + } + + AES_GCM_Cipher_OpenSSL::AES_GCM_Cipher_OpenSSL(const CryptoBuffer& key, const CryptoBuffer& initializationVector, + const CryptoBuffer& tag, const CryptoBuffer& aad) : + OpenSSLCipher(key, initializationVector, tag), m_aad(std::move(aad)) + { + InitCipher(); + } + + CryptoBuffer AES_GCM_Cipher_OpenSSL::FinalizeEncryption() + { + if (m_failure) + { + AWS_LOGSTREAM_FATAL(GCM_LOG_TAG, "Cipher not properly initialized for encryption finalization. Aborting"); + return CryptoBuffer(); + } + + int writtenSize = 0; + CryptoBuffer finalBlock(GetBlockSizeBytes()); + EVP_EncryptFinal_ex(m_encryptor_ctx, finalBlock.GetUnderlyingData(), &writtenSize); + + m_tag = CryptoBuffer(TagLengthBytes); + if (!EVP_CIPHER_CTX_ctrl(m_encryptor_ctx, EVP_CTRL_GCM_GET_TAG, static_cast<int>(m_tag.GetLength()), + m_tag.GetUnderlyingData())) + { + m_failure = true; + LogErrors(GCM_LOG_TAG); + } + + return CryptoBuffer(); + } + + void AES_GCM_Cipher_OpenSSL::InitCipher() + { + if (m_failure || !CheckKeyAndIVLength(KeyLengthBits/8, IVLengthBytes)) + { + return; + } + + if (!(EVP_EncryptInit_ex(m_encryptor_ctx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) && + EVP_EncryptInit_ex(m_encryptor_ctx, nullptr, nullptr, m_key.GetUnderlyingData(), + m_initializationVector.GetUnderlyingData()) && + EVP_CIPHER_CTX_set_padding(m_encryptor_ctx, 0)) || + !(EVP_DecryptInit_ex(m_decryptor_ctx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) && + EVP_DecryptInit_ex(m_decryptor_ctx, nullptr, nullptr, m_key.GetUnderlyingData(), + m_initializationVector.GetUnderlyingData()) && + EVP_CIPHER_CTX_set_padding(m_decryptor_ctx, 0))) + { + m_failure = true; + LogErrors(GCM_LOG_TAG); + return; + } + + if (m_aad.GetLength() > 0) + { + int outLen = 0; + if(!EVP_EncryptUpdate(m_encryptor_ctx, nullptr, &outLen, m_aad.GetUnderlyingData(), m_aad.GetLength()) + || !EVP_DecryptUpdate(m_decryptor_ctx, nullptr, &outLen, m_aad.GetUnderlyingData(), m_aad.GetLength())) + { + m_failure = true; + LogErrors(GCM_LOG_TAG); + return; + } + } + + //tag should always be set in GCM decrypt mode + if (m_tag.GetLength() > 0) + { + if (m_tag.GetLength() < TagLengthBytes) + { + AWS_LOGSTREAM_ERROR(GCM_LOG_TAG, "Illegal attempt to decrypt an AES GCM payload without a valid tag set: tag length=" << m_tag.GetLength()); + m_failure = true; + return; + } + + if (!EVP_CIPHER_CTX_ctrl(m_decryptor_ctx, EVP_CTRL_GCM_SET_TAG, static_cast<int>(m_tag.GetLength()), m_tag.GetUnderlyingData())) + { + m_failure = true; + LogErrors(GCM_LOG_TAG); + } + } + } + + size_t AES_GCM_Cipher_OpenSSL::GetBlockSizeBytes() const + { + return BlockSizeBytes; + } + + size_t AES_GCM_Cipher_OpenSSL::GetKeyLengthBits() const + { + return KeyLengthBits; + } + + size_t AES_GCM_Cipher_OpenSSL::GetTagLengthBytes() const + { + return TagLengthBytes; + } + + void AES_GCM_Cipher_OpenSSL::Reset() + { + OpenSSLCipher::Reset(); + InitCipher(); + } + + size_t AES_KeyWrap_Cipher_OpenSSL::KeyLengthBits = 256; + size_t AES_KeyWrap_Cipher_OpenSSL::BlockSizeBytes = 8; + static const unsigned char INTEGRITY_VALUE = 0xA6; + static const size_t MIN_CEK_LENGTH_BYTES = 128 / 8; + + static const char* KEY_WRAP_TAG = "AES_KeyWrap_Cipher_OpenSSL"; + + AES_KeyWrap_Cipher_OpenSSL::AES_KeyWrap_Cipher_OpenSSL(const CryptoBuffer& key) : OpenSSLCipher(key, 0) + { + InitCipher(); + } + + CryptoBuffer AES_KeyWrap_Cipher_OpenSSL::EncryptBuffer(const CryptoBuffer& plainText) + { + if (!m_failure) + { + m_workingKeyBuffer = CryptoBuffer({&m_workingKeyBuffer, (CryptoBuffer*) &plainText}); + } + return CryptoBuffer(); + } + + CryptoBuffer AES_KeyWrap_Cipher_OpenSSL::FinalizeEncryption() + { + if (m_failure) + { + AWS_LOGSTREAM_FATAL(KEY_WRAP_TAG, "Cipher not properly initialized for encryption finalization. Aborting"); + return CryptoBuffer(); + } + + if (m_workingKeyBuffer.GetLength() < MIN_CEK_LENGTH_BYTES) + { + AWS_LOGSTREAM_ERROR(KEY_WRAP_TAG, "Incorrect input length of " << m_workingKeyBuffer.GetLength()); + m_failure = true; + return CryptoBuffer(); + } + + //the following is an in place implementation of + //RFC 3394 using the alternate in-place implementation. + //we use one in-place buffer instead of the copy at the end. + //the one letter variable names are meant to directly reflect the variables in the RFC + CryptoBuffer cipherText(m_workingKeyBuffer.GetLength() + BlockSizeBytes); + + //put the integrity check register in the first 8 bytes of the final buffer. + memset(cipherText.GetUnderlyingData(), INTEGRITY_VALUE, BlockSizeBytes); + unsigned char* a = cipherText.GetUnderlyingData(); + + //put the register buffer after the integrity check register + memcpy(cipherText.GetUnderlyingData() + BlockSizeBytes, m_workingKeyBuffer.GetUnderlyingData(), + m_workingKeyBuffer.GetLength()); + unsigned char* r = cipherText.GetUnderlyingData() + BlockSizeBytes; + + int n = static_cast<int>(m_workingKeyBuffer.GetLength() / BlockSizeBytes); + + //temporary encryption buffer + CryptoBuffer b(BlockSizeBytes * 2); + int outLen = static_cast<int>(b.GetLength()); + + //concatenation buffer + CryptoBuffer tempInput(BlockSizeBytes * 2); + + for (int j = 0; j <= 5; ++j) + { + for (int i = 1; i <= n; ++i) + { + //concat A and R[i], A should be most significant and then R[i] should be least significant. + memcpy(tempInput.GetUnderlyingData(), a, BlockSizeBytes); + memcpy(tempInput.GetUnderlyingData() + BlockSizeBytes, r, BlockSizeBytes); + + //encrypt the concatenated A and R[I] and store it in B + if (!EVP_EncryptUpdate(m_encryptor_ctx, b.GetUnderlyingData(), &outLen, + tempInput.GetUnderlyingData(), static_cast<int>(tempInput.GetLength()))) + { + LogErrors(KEY_WRAP_TAG); + m_failure = true; + return CryptoBuffer(); + } + + unsigned char t = static_cast<unsigned char>((n * j) + i); + //put the 64 MSB ^ T into A + memcpy(a, b.GetUnderlyingData(), BlockSizeBytes); + a[7] ^= t; + //put the 64 LSB into R[i] + memcpy(r, b.GetUnderlyingData() + BlockSizeBytes, BlockSizeBytes); + //increment i -> R[i] + r += BlockSizeBytes; + } + //reset R + r = cipherText.GetUnderlyingData() + BlockSizeBytes; + } + + return cipherText; + } + + CryptoBuffer AES_KeyWrap_Cipher_OpenSSL::DecryptBuffer(const CryptoBuffer& cipherText) + { + if (!m_failure) + { + m_workingKeyBuffer = CryptoBuffer({&m_workingKeyBuffer, (CryptoBuffer*)&cipherText}); + } + return CryptoBuffer(); + } + + CryptoBuffer AES_KeyWrap_Cipher_OpenSSL::FinalizeDecryption() + { + if (m_failure) + { + AWS_LOGSTREAM_FATAL(KEY_WRAP_TAG, "Cipher not properly initialized for decryption finalization. Aborting"); + return CryptoBuffer(); + } + + if (m_workingKeyBuffer.GetLength() < MIN_CEK_LENGTH_BYTES + BlockSizeBytes) + { + AWS_LOGSTREAM_ERROR(KEY_WRAP_TAG, "Incorrect input length of " << m_workingKeyBuffer.GetLength()); + m_failure = true; + return CryptoBuffer(); + } + + //the following is an in place implementation of + //RFC 3394 using the alternate in-place implementation. + //we use one in-place buffer instead of the copy at the end. + //the one letter variable names are meant to directly reflect the variables in the RFC + CryptoBuffer plainText(m_workingKeyBuffer.GetLength() - BlockSizeBytes); + memcpy(plainText.GetUnderlyingData(), m_workingKeyBuffer.GetUnderlyingData() + BlockSizeBytes, plainText.GetLength()); + + //integrity register should be the first 8 bytes of the cipher text + unsigned char* a = m_workingKeyBuffer.GetUnderlyingData(); + + //in-place register is the plaintext. For decryption, start at the last array position (8 bytes before the end); + unsigned char* r = plainText.GetUnderlyingData() + plainText.GetLength() - BlockSizeBytes; + + int n = static_cast<int>(plainText.GetLength() / BlockSizeBytes); + + //temporary encryption buffer + CryptoBuffer b(BlockSizeBytes * 10); + int outLen = static_cast<int>(b.GetLength()); + + //concatenation buffer + CryptoBuffer tempInput(BlockSizeBytes * 2); + + for(int j = 5; j >= 0; --j) + { + for(int i = n; i >= 1; --i) + { + //concat + //A ^ t + memcpy(tempInput.GetUnderlyingData(), a, BlockSizeBytes); + unsigned char t = static_cast<unsigned char>((n * j) + i); + tempInput[7] ^= t; + //R[i] + memcpy(tempInput.GetUnderlyingData() + BlockSizeBytes, r, BlockSizeBytes); + + //Decrypt the concatenated buffer + if(!EVP_DecryptUpdate(m_decryptor_ctx, b.GetUnderlyingData(), &outLen, + tempInput.GetUnderlyingData(), static_cast<int>(tempInput.GetLength()))) + { + m_failure = true; + LogErrors(KEY_WRAP_TAG); + return CryptoBuffer(); + } + + //set A to MSB 64 bits of decrypted result + memcpy(a, b.GetUnderlyingData(), BlockSizeBytes); + //set R[i] to LSB 64 bits of decrypted result + memcpy(r, b.GetUnderlyingData() + BlockSizeBytes, BlockSizeBytes); + //decrement i -> R[i] + r -= BlockSizeBytes; + } + + r = plainText.GetUnderlyingData() + plainText.GetLength() - BlockSizeBytes; + } + + //here we perform the integrity check to make sure A == 0xA6A6A6A6A6A6A6A6 + for(size_t i = 0; i < BlockSizeBytes; ++i) + { + if(a[i] != INTEGRITY_VALUE) + { + m_failure = true; + AWS_LOGSTREAM_ERROR(KEY_WRAP_TAG, "Integrity check failed for key wrap decryption."); + return CryptoBuffer(); + } + } + + return plainText; + } + + void AES_KeyWrap_Cipher_OpenSSL::InitCipher() + { + if (m_failure || !CheckKeyAndIVLength(KeyLengthBits/8, 0)) + { + return; + } + + if (!(EVP_EncryptInit_ex(m_encryptor_ctx, EVP_aes_256_ecb(), nullptr, m_key.GetUnderlyingData(), nullptr) && + EVP_CIPHER_CTX_set_padding(m_encryptor_ctx, 0)) || + !(EVP_DecryptInit_ex(m_decryptor_ctx, EVP_aes_256_ecb(), nullptr, m_key.GetUnderlyingData(), nullptr) && + EVP_CIPHER_CTX_set_padding(m_decryptor_ctx, 0))) + { + m_failure = true; + LogErrors(KEY_WRAP_TAG); + } + } + + void AES_KeyWrap_Cipher_OpenSSL::Reset() + { + m_workingKeyBuffer = CryptoBuffer(); + OpenSSLCipher::Reset(); + InitCipher(); + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventDecoderStream.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventDecoderStream.cpp new file mode 100644 index 0000000000..5ecd2d0444 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventDecoderStream.cpp @@ -0,0 +1,22 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/event/EventDecoderStream.h> +#include <iostream> + +namespace Aws +{ + namespace Utils + { + namespace Event + { + EventDecoderStream::EventDecoderStream(EventStreamDecoder& decoder, size_t bufferSize) : + Aws::IOStream(&m_eventStreamBuf), + m_eventStreamBuf(decoder, bufferSize) + + { + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventEncoderStream.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventEncoderStream.cpp new file mode 100644 index 0000000000..f8640f0e8c --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventEncoderStream.cpp @@ -0,0 +1,28 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/event/EventEncoderStream.h> +#include <iostream> + +namespace Aws +{ + namespace Utils + { + namespace Event + { + EventEncoderStream::EventEncoderStream(size_t bufferSize) : + Aws::IOStream(&m_streambuf), + m_streambuf(bufferSize) + { + } + + EventEncoderStream& EventEncoderStream::WriteEvent(const Aws::Utils::Event::Message& msg) + { + auto bits = m_encoder.EncodeAndSign(msg); + write(reinterpret_cast<char*>(bits.data()), bits.size()); + return *this; + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventHeader.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventHeader.cpp new file mode 100644 index 0000000000..c3c989bedb --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventHeader.cpp @@ -0,0 +1,107 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/event/EventHeader.h> +#include <aws/core/utils/HashingUtils.h> + +namespace Aws +{ + namespace Utils + { + namespace Event + { + static const int HASH_BOOL_TRUE = HashingUtils::HashString("BOOL_TRUE"); + static const int HASH_BOOL_FALSE = HashingUtils::HashString("BOOL_FALSE"); + static const int HASH_BYTE = HashingUtils::HashString("BYTE"); + static const int HASH_INT16 = HashingUtils::HashString("INT16"); + static const int HASH_INT32 = HashingUtils::HashString("INT32"); + static const int HASH_INT64 = HashingUtils::HashString("INT64"); + static const int HASH_BYTE_BUF = HashingUtils::HashString("BYTE_BUFFER"); + static const int HASH_STRING = HashingUtils::HashString("STRING"); + static const int HASH_TIMESTAMP = HashingUtils::HashString("TIMESTAMP"); + static const int HASH_UUID = HashingUtils::HashString("UUID"); + + EventHeaderValue::EventHeaderType EventHeaderValue::GetEventHeaderTypeForName(const Aws::String& name) + { + int hashCode = Aws::Utils::HashingUtils::HashString(name.c_str()); + if (hashCode == HASH_BOOL_TRUE) + { + return EventHeaderType::BOOL_TRUE; + } + else if (hashCode == HASH_BOOL_FALSE) + { + return EventHeaderType::BOOL_FALSE; + } + else if (hashCode == HASH_BYTE) + { + return EventHeaderType::BYTE; + } + else if (hashCode == HASH_INT16) + { + return EventHeaderType::INT16; + } + else if (hashCode == HASH_INT32) + { + return EventHeaderType::INT32; + } + else if (hashCode == HASH_INT64) + { + return EventHeaderType::INT64; + } + else if (hashCode == HASH_BYTE_BUF) + { + return EventHeaderType::BYTE_BUF; + } + else if (hashCode == HASH_STRING) + { + return EventHeaderType::STRING; + } + else if (hashCode == HASH_TIMESTAMP) + { + return EventHeaderType::TIMESTAMP; + } + else if (hashCode == HASH_UUID) + { + return EventHeaderType::UUID; + } + else + { + return EventHeaderType::UNKNOWN; + } + } + + Aws::String EventHeaderValue::GetNameForEventHeaderType(EventHeaderType value) + { + switch (value) + { + case EventHeaderType::BOOL_TRUE: + return "BOOL_TRUE"; + case EventHeaderType::BOOL_FALSE: + return "BOOL_FALSE"; + case EventHeaderType::BYTE: + return "BYTE"; + case EventHeaderType::INT16: + return "INT16"; + case EventHeaderType::INT32: + return "INT32"; + case EventHeaderType::INT64: + return "INT64"; + case EventHeaderType::BYTE_BUF: + return "BYTE_BUF"; + case EventHeaderType::STRING: + return "STRING"; + case EventHeaderType::TIMESTAMP: + return "TIMESTAMP"; + case EventHeaderType::UUID: + return "UUID"; + default: + return "UNKNOWN"; + } + } + + } + } +} + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventMessage.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventMessage.cpp new file mode 100644 index 0000000000..de8b904775 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventMessage.cpp @@ -0,0 +1,132 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/event/EventMessage.h> +#include <aws/core/utils/HashingUtils.h> +#include <algorithm> +#include <iterator> + +namespace Aws +{ + namespace Utils + { + namespace Event + { + const char EVENT_TYPE_HEADER[] = ":event-type"; + const char CONTENT_TYPE_HEADER[] = ":content-type"; + const char MESSAGE_TYPE_HEADER[] = ":message-type"; + const char ERROR_CODE_HEADER[] = ":error-code"; + const char ERROR_MESSAGE_HEADER[] = ":error-message"; + const char EXCEPTION_TYPE_HEADER[] = ":exception-type"; + + static const int EVENT_HASH = HashingUtils::HashString("event"); + static const int ERROR_HASH = HashingUtils::HashString("error"); + static const int EXCEPTION_HASH = HashingUtils::HashString("exception"); + + static const int CONTENT_TYPE_APPLICATION_OCTET_STREAM_HASH = HashingUtils::HashString("application/octet-stream"); + static const int CONTENT_TYPE_APPLICATION_JSON_HASH = HashingUtils::HashString("application/json"); + static const int CONTENT_TYPE_TEXT_PLAIN_HASH = HashingUtils::HashString("text/plain"); + + Message::MessageType Message::GetMessageTypeForName(const Aws::String& name) + { + int hashCode = Aws::Utils::HashingUtils::HashString(name.c_str()); + if (hashCode == EVENT_HASH) + { + return MessageType::EVENT; + } + else if (hashCode == ERROR_HASH) + { + return MessageType::REQUEST_LEVEL_ERROR; + } + else if (hashCode == EXCEPTION_HASH) + { + return MessageType::REQUEST_LEVEL_EXCEPTION; + } + else + { + return MessageType::UNKNOWN; + } + } + + Aws::String Message::GetNameForMessageType(MessageType value) + { + switch (value) + { + case MessageType::EVENT: + return "event"; + case MessageType::REQUEST_LEVEL_ERROR: + return "error"; + case MessageType::REQUEST_LEVEL_EXCEPTION: + return "exception"; + default: + return "unknown"; + } + } + + Message::ContentType Message::GetContentTypeForName(const Aws::String& name) + { + int hashCode = Aws::Utils::HashingUtils::HashString(name.c_str()); + if (hashCode == CONTENT_TYPE_APPLICATION_OCTET_STREAM_HASH) + { + return ContentType::APPLICATION_OCTET_STREAM; + } + else if (hashCode == CONTENT_TYPE_APPLICATION_JSON_HASH) + { + return ContentType::APPLICATION_JSON; + } + else if (hashCode == CONTENT_TYPE_TEXT_PLAIN_HASH) + { + return ContentType::TEXT_PLAIN; + } + else + { + return ContentType::UNKNOWN; + } + } + + Aws::String Message::GetNameForContentType(ContentType value) + { + switch (value) + { + case ContentType::APPLICATION_OCTET_STREAM: + return "application/octet-stream"; + case ContentType::APPLICATION_JSON: + return "application/json"; + case ContentType::TEXT_PLAIN: + return "text/plain"; + default: + return "unknown"; + } + } + + void Message::Reset() + { + m_totalLength = 0; + m_headersLength = 0; + m_payloadLength = 0; + + m_eventHeaders.clear(); + m_eventPayload.clear(); + } + + void Message::WriteEventPayload(const unsigned char* data, size_t length) + { + std::copy(data, data + length, std::back_inserter(m_eventPayload)); + } + + void Message::WriteEventPayload(const Aws::Vector<unsigned char>& bits) + { + std::copy(bits.cbegin(), bits.cend(), std::back_inserter(m_eventPayload)); + } + + void Message::WriteEventPayload(const Aws::String& bits) + { + std::copy(bits.cbegin(), bits.cend(), std::back_inserter(m_eventPayload)); + } + + } // namespace Event + } // namespace Utils +} // namespace Aws + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamBuf.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamBuf.cpp new file mode 100644 index 0000000000..6a1766bb9f --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamBuf.cpp @@ -0,0 +1,147 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/event/EventStreamBuf.h> +#include <cassert> + +namespace Aws +{ + namespace Utils + { + namespace Event + { + const size_t DEFAULT_BUF_SIZE = 1024; + + EventStreamBuf::EventStreamBuf(EventStreamDecoder& decoder, size_t bufferLength) : + m_byteBuffer(bufferLength), + m_bufferLength(bufferLength), + m_decoder(decoder) + { + assert(decoder); + char* begin = reinterpret_cast<char*>(m_byteBuffer.GetUnderlyingData()); + char* end = begin + bufferLength - 1; + + setp(begin, end); + setg(begin, begin, begin); + } + + EventStreamBuf::~EventStreamBuf() + { + if (m_decoder) + { + writeToDecoder(); + } + } + + void EventStreamBuf::writeToDecoder() + { + if (pptr() > pbase()) + { + size_t length = static_cast<size_t>(pptr() - pbase()); + m_decoder.Pump(m_byteBuffer, length); + + if (!m_decoder) + { + m_err.write(reinterpret_cast<char*>(m_byteBuffer.GetUnderlyingData()), length); + } + else + { + pbump(-static_cast<int>(length)); + } + } + } + + std::streampos EventStreamBuf::seekoff(std::streamoff off, std::ios_base::seekdir dir, std::ios_base::openmode which) + { + if (dir == std::ios_base::beg) + { + return seekpos(off, which); + } + else if (dir == std::ios_base::end) + { + return seekpos(m_bufferLength - 1 - off, which); + } + else if (dir == std::ios_base::cur) + { + if (which == std::ios_base::in) + { + return seekpos((gptr() - (char*)m_byteBuffer.GetUnderlyingData()) + off, which); + } + if (which == std::ios_base::out) + { + return seekpos((pptr() - (char*)m_byteBuffer.GetUnderlyingData()) + off, which); + } + } + + return std::streamoff(-1); + } + + std::streampos EventStreamBuf::seekpos(std::streampos pos, std::ios_base::openmode which) + { + assert(static_cast<size_t>(pos) <= m_bufferLength); + if (static_cast<size_t>(pos) > m_bufferLength) + { + return std::streampos(std::streamoff(-1)); + } + + if (which == std::ios_base::in) + { + m_err.seekg(pos); + return m_err.tellg(); + } + + if (which == std::ios_base::out) + { + return pos; + } + + return std::streampos(std::streamoff(-1)); + } + + int EventStreamBuf::underflow() + { + if (!m_err || m_err.eof() || m_decoder) + { + return std::char_traits<char>::eof(); + } + + m_err.flush(); + m_err.read(reinterpret_cast<char*>(m_byteBuffer.GetUnderlyingData()), m_byteBuffer.GetLength()); + + char* begin = reinterpret_cast<char*>(m_byteBuffer.GetUnderlyingData()); + setg(begin, begin, begin + m_err.gcount()); + return std::char_traits<char>::to_int_type(*gptr()); + } + + int EventStreamBuf::overflow(int ch) + { + auto eof = std::char_traits<char>::eof(); + + if (m_decoder) + { + if (ch != eof) + { + *pptr() = (char)ch; + pbump(1); + } + + writeToDecoder(); + return ch; + } + + return eof; + } + + int EventStreamBuf::sync() + { + if (m_decoder) + { + writeToDecoder(); + } + + return 0; + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamDecoder.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamDecoder.cpp new file mode 100644 index 0000000000..f70a6c88f6 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamDecoder.cpp @@ -0,0 +1,170 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/common/common.h> +#include <aws/core/utils/event/EventHeader.h> +#include <aws/core/utils/event/EventMessage.h> +#include <aws/core/utils/event/EventStreamDecoder.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/utils/UnreferencedParam.h> +#include <aws/core/utils/memory/AWSMemory.h> + +namespace Aws +{ + namespace Utils + { + namespace Event + { + static const char EVENT_STREAM_DECODER_CLASS_TAG[] = "Aws::Utils::Event::EventStreamDecoder"; + + EventStreamDecoder::EventStreamDecoder(EventStreamHandler* handler) : m_eventStreamHandler(handler) + { + aws_event_stream_streaming_decoder_init(&m_decoder, + get_aws_allocator(), + onPayloadSegment, + onPreludeReceived, + onHeaderReceived, + onError, + (void*)handler); + } + + EventStreamDecoder::~EventStreamDecoder() + { + aws_event_stream_streaming_decoder_clean_up(&m_decoder); + } + + void EventStreamDecoder::Pump(const ByteBuffer& data) + { + Pump(data, data.GetLength()); + } + + void EventStreamDecoder::Pump(const ByteBuffer& data, size_t length) + { + aws_byte_buf dataBuf = aws_byte_buf_from_array(static_cast<uint8_t*>(data.GetUnderlyingData()), length); + aws_event_stream_streaming_decoder_pump(&m_decoder, &dataBuf); + } + + void EventStreamDecoder::Reset() + { + m_eventStreamHandler->Reset(); + } + + void EventStreamDecoder::ResetEventStreamHandler(EventStreamHandler* handler) + { + aws_event_stream_streaming_decoder_init(&m_decoder, get_aws_allocator(), + onPayloadSegment, + onPreludeReceived, + onHeaderReceived, + onError, + reinterpret_cast<void *>(handler)); + } + + void EventStreamDecoder::onPayloadSegment( + aws_event_stream_streaming_decoder* decoder, + aws_byte_buf* payload, + int8_t isFinalSegment, + void* context) + { + AWS_UNREFERENCED_PARAM(decoder); + auto handler = static_cast<EventStreamHandler*>(context); + assert(handler); + if (!handler) + { + AWS_LOGSTREAM_ERROR(EVENT_STREAM_DECODER_CLASS_TAG, "Payload received, but decoder encountered internal errors before." + "ErrorCode: " << EventStreamErrorsMapper::GetNameForError(handler->GetInternalError()) << ", " + "ErrorMessage: " << handler->GetEventPayloadAsString()); + return; + } + handler->WriteMessageEventPayload(static_cast<unsigned char*>(payload->buffer), payload->len); + + // Complete payload received + if (isFinalSegment == 1) + { + assert(handler->IsMessageCompleted()); + handler->OnEvent(); + handler->Reset(); + } + } + + void EventStreamDecoder::onPreludeReceived( + aws_event_stream_streaming_decoder* decoder, + aws_event_stream_message_prelude* prelude, + void* context) + { + AWS_UNREFERENCED_PARAM(decoder); + auto handler = static_cast<EventStreamHandler*>(context); + handler->Reset(); + + //Encounter internal error in prelude received. + //This error will be handled by OnError callback function later. + if (prelude->total_len < prelude->headers_len + 16) + { + return; + } + handler->SetMessageMetadata(prelude->total_len, prelude->headers_len, + prelude->total_len - prelude->headers_len - 4/*total byte-length*/ - 4/*headers byte-length*/ - 4/*prelude crc*/ - 4/*message crc*/); + AWS_LOGSTREAM_TRACE(EVENT_STREAM_DECODER_CLASS_TAG, "Message received, the expected length of the message is: " << prelude->total_len << + " bytes, and the expected length of the header is: " << prelude->headers_len << " bytes"); + + //Handle empty message + //if (handler->m_message.GetHeadersLength() == 0 && handler->m_message.GetPayloadLength() == 0) + if (handler->IsMessageCompleted()) + { + handler->OnEvent(); + handler->Reset(); + } + } + + void EventStreamDecoder::onHeaderReceived( + aws_event_stream_streaming_decoder* decoder, + aws_event_stream_message_prelude* prelude, + aws_event_stream_header_value_pair* header, + void* context) + { + AWS_UNREFERENCED_PARAM(decoder); + AWS_UNREFERENCED_PARAM(prelude); + auto handler = static_cast<EventStreamHandler*>(context); + assert(handler); + if (!handler) + { + AWS_LOGSTREAM_ERROR(EVENT_STREAM_DECODER_CLASS_TAG, "Payload received, but decoder encountered internal errors before." + "ErrorCode: " << EventStreamErrorsMapper::GetNameForError(handler->GetInternalError()) << ", " + "ErrorMessage: " << handler->GetEventPayloadAsString()); + return; + } + + // The length of a header = 1 byte (to represent the length of header name) + length of header name + 1 byte (to represent header type) + // + 2 bytes (to represent length of header value) + length of header value + handler->InsertMessageEventHeader(Aws::String(header->header_name, header->header_name_len), + 1 + header->header_name_len + 1 + 2 + header->header_value_len, EventHeaderValue(header)); + + // Handle messages only have headers, but without payload. + //if (handler->m_message.GetHeadersLength() == handler->m_headersBytesReceived() && handler->m_message.GetPayloadLength() == 0) + if (handler->IsMessageCompleted()) + { + handler->OnEvent(); + handler->Reset(); + } + } + + void EventStreamDecoder::onError( + aws_event_stream_streaming_decoder* decoder, + aws_event_stream_message_prelude* prelude, + int error_code, + const char* message, + void* context) + { + AWS_UNREFERENCED_PARAM(decoder); + AWS_UNREFERENCED_PARAM(prelude); + auto handler = static_cast<EventStreamHandler*>(context); + handler->SetFailure(); + handler->SetInternalError(error_code); + handler->WriteMessageEventPayload(reinterpret_cast<const unsigned char*>(message), strlen(message)); + handler->OnEvent(); + } + } // namespace Event + } // namespace Utils +} // namespace Aws + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamEncoder.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamEncoder.cpp new file mode 100644 index 0000000000..ef7104e839 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamEncoder.cpp @@ -0,0 +1,162 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/event/EventHeader.h> +#include <aws/core/utils/event/EventMessage.h> +#include <aws/core/utils/event/EventStreamEncoder.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <aws/core/auth/AWSAuthSigner.h> +#include <aws/common/byte_order.h> +#include <aws/core/utils/memory/AWSMemory.h> + +#include <cassert> + +namespace Aws +{ + namespace Utils + { + namespace Event + { + static const char TAG[] = "EventStreamEncoder"; + + static void EncodeHeaders(const Aws::Utils::Event::Message& msg, aws_array_list* headers) + { + aws_array_list_init_dynamic(headers, get_aws_allocator(), msg.GetEventHeaders().size(), sizeof(aws_event_stream_header_value_pair)); + for (auto&& header : msg.GetEventHeaders()) + { + const uint8_t headerKeyLen = static_cast<uint8_t>(header.first.length()); + switch(header.second.GetType()) + { + case EventHeaderValue::EventHeaderType::BOOL_TRUE: + case EventHeaderValue::EventHeaderType::BOOL_FALSE: + aws_event_stream_add_bool_header(headers, header.first.c_str(), headerKeyLen, header.second.GetEventHeaderValueAsBoolean()); + break; + case EventHeaderValue::EventHeaderType::BYTE: + aws_event_stream_add_bool_header(headers, header.first.c_str(), headerKeyLen, header.second.GetEventHeaderValueAsByte()); + break; + case EventHeaderValue::EventHeaderType::INT16: + aws_event_stream_add_int16_header(headers, header.first.c_str(), headerKeyLen, header.second.GetEventHeaderValueAsInt16()); + break; + case EventHeaderValue::EventHeaderType::INT32: + aws_event_stream_add_int32_header(headers, header.first.c_str(), headerKeyLen, header.second.GetEventHeaderValueAsInt32()); + break; + case EventHeaderValue::EventHeaderType::INT64: + aws_event_stream_add_int64_header(headers, header.first.c_str(), headerKeyLen, header.second.GetEventHeaderValueAsInt64()); + break; + case EventHeaderValue::EventHeaderType::BYTE_BUF: + { + const auto& bytes = header.second.GetEventHeaderValueAsBytebuf(); + aws_event_stream_add_bytebuf_header(headers, header.first.c_str(), headerKeyLen, bytes.GetUnderlyingData(), static_cast<uint16_t>(bytes.GetLength()), 1 /*copy*/); + } + break; + case EventHeaderValue::EventHeaderType::STRING: + { + const auto& bytes = header.second.GetUnderlyingBuffer(); + aws_event_stream_add_string_header(headers, header.first.c_str(), headerKeyLen, reinterpret_cast<char*>(bytes.GetUnderlyingData()), static_cast<uint16_t>(bytes.GetLength()), 0 /*copy*/); + } + break; + case EventHeaderValue::EventHeaderType::TIMESTAMP: + aws_event_stream_add_timestamp_header(headers, header.first.c_str(), headerKeyLen, header.second.GetEventHeaderValueAsTimestamp()); + break; + case EventHeaderValue::EventHeaderType::UUID: + { + ByteBuffer uuidBytes = header.second.GetEventHeaderValueAsUuid(); + aws_event_stream_add_uuid_header(headers, header.first.c_str(), headerKeyLen, uuidBytes.GetUnderlyingData()); + } + break; + default: + AWS_LOG_ERROR(TAG, "Encountered unknown type of header."); + break; + } + } + } + + EventStreamEncoder::EventStreamEncoder(Client::AWSAuthSigner* signer) : m_signer(signer) + { + } + + + Aws::Vector<unsigned char> EventStreamEncoder::EncodeAndSign(const Aws::Utils::Event::Message& msg) + { + aws_event_stream_message encoded = Encode(msg); + aws_event_stream_message signedMessage = Sign(&encoded); + + const auto signedMessageLength = signedMessage.message_buffer ? aws_event_stream_message_total_length(&signedMessage) : 0; + + Aws::Vector<unsigned char> outputBits(signedMessage.message_buffer, signedMessage.message_buffer + signedMessageLength); + aws_event_stream_message_clean_up(&encoded); + aws_event_stream_message_clean_up(&signedMessage); + return outputBits; + } + + aws_event_stream_message EventStreamEncoder::Encode(const Aws::Utils::Event::Message& msg) + { + aws_array_list headers; + EncodeHeaders(msg, &headers); + + aws_byte_buf payload; + payload.len = msg.GetEventPayload().size(); + // this const_cast is OK because aws_byte_buf will only be "read from" by the following functions. + payload.buffer = const_cast<uint8_t*>(msg.GetEventPayload().data()); + payload.capacity = 0; + payload.allocator = nullptr; + + aws_event_stream_message encoded; + if(aws_event_stream_message_init(&encoded, get_aws_allocator(), &headers, &payload) == AWS_OP_ERR) + { + AWS_LOGSTREAM_ERROR(TAG, "Error creating event-stream message from payload."); + aws_event_stream_headers_list_cleanup(&headers); + // GCC 4.9.4 issues a warning with -Wextra if we simply do + // return {}; + aws_event_stream_message empty{nullptr, nullptr, 0}; + return empty; + } + aws_event_stream_headers_list_cleanup(&headers); + return encoded; + } + + aws_event_stream_message EventStreamEncoder::Sign(aws_event_stream_message* msg) + { + const auto msglen = msg->message_buffer ? aws_event_stream_message_total_length(msg) : 0; + Event::Message signedMessage; + signedMessage.WriteEventPayload(msg->message_buffer, msglen); + + assert(m_signer); + if (!m_signer->SignEventMessage(signedMessage, m_signatureSeed)) + { + AWS_LOGSTREAM_ERROR(TAG, "Failed to sign event message frame."); + // GCC 4.9.4 issues a warning with -Wextra if we simply do + // return {}; + aws_event_stream_message empty{nullptr, nullptr, 0}; + return empty; + } + + aws_array_list headers; + EncodeHeaders(signedMessage, &headers); + + aws_byte_buf payload; + payload.len = signedMessage.GetEventPayload().size(); + payload.buffer = signedMessage.GetEventPayload().data(); + payload.capacity = 0; + payload.allocator = nullptr; + + aws_event_stream_message signedmsg; + if(aws_event_stream_message_init(&signedmsg, get_aws_allocator(), &headers, &payload)) + { + AWS_LOGSTREAM_ERROR(TAG, "Error creating event-stream message from payload."); + aws_event_stream_headers_list_cleanup(&headers); + // GCC 4.9.4 issues a warning with -Wextra if we simply do + // return {}; + aws_event_stream_message empty{nullptr, nullptr, 0}; + return empty; + } + aws_event_stream_headers_list_cleanup(&headers); + return signedmsg; + } + + } // namespace Event + } // namespace Utils +} // namespace Aws + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamErrors.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamErrors.cpp new file mode 100644 index 0000000000..836d0b47c5 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/event/EventStreamErrors.cpp @@ -0,0 +1,66 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/client/AWSError.h> +#include <aws/core/utils/HashingUtils.h> +#include <aws/core/utils/event/EventStreamErrors.h> + + using namespace Aws::Client; +// using namespace Aws::S3; +// using namespace Aws::Utils; + +namespace Aws +{ + namespace Utils + { + namespace Event + { + namespace EventStreamErrorsMapper + { + /* + static const int EVENT_STREAM_NO_ERROR_HASH = HashingUtils::HashString("EventStreamNoError"); + static const int EVENT_STREAM_BUFFER_LENGTH_MISMATCH_HASH = HashingUtils::HashString("EventStreamBufferLengthMismatch"); + static const int EVENT_STREAM_INSUFFICIENT_BUFFER_LEN_HASH = HashingUtils::HashString("EventStreamInsufficientBufferLen"); + static const int EVENT_STREAM_MESSAGE_FIELD_SIZE_EXCEEDED_HASH = HashingUtils::HashString("EventStreamMessageFieldSizeExceeded"); + static const int EVENT_STREAM_PRELUDE_CHECKSUM_FAILURE_HASH = HashingUtils::HashString("EventStreamPreludeChecksumFailure"); + static const int EVENT_STREAM_MESSAGE_CHECKSUM_FAILURE_HASH = HashingUtils::HashString("EventStreamMessageChecksumFailure"); + static const int EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN_HASH = HashingUtils::HashString("EventStreamMessageInvalidHeadersLen"); + static const int EVENT_STREAM_MESSAGE_UNKNOWN_HEADER_TYPE_HASH = HashingUtils::HashString("EventStreamMessageUnknownHeaderType"); + static const int EVENT_STREAM_MESSAGE_PARSER_ILLEGAL_STATE_HASH = HashingUtils::HashString("EventStreamMessageParserIllegalState"); + */ + const char* GetNameForError(EventStreamErrors error) + { + switch (error) + { + case EventStreamErrors::EVENT_STREAM_NO_ERROR: + return "EventStreamNoError"; + case EventStreamErrors::EVENT_STREAM_BUFFER_LENGTH_MISMATCH: + return "EventStreamBufferLengthMismatch"; + case EventStreamErrors::EVENT_STREAM_INSUFFICIENT_BUFFER_LEN: + return "EventStreamInsufficientBufferLen"; + case EventStreamErrors::EVENT_STREAM_MESSAGE_FIELD_SIZE_EXCEEDED: + return "EventStreamMessageFieldSizeExceeded"; + case EventStreamErrors::EVENT_STREAM_PRELUDE_CHECKSUM_FAILURE: + return "EventStreamPreludeChecksumFailure"; + case EventStreamErrors::EVENT_STREAM_MESSAGE_CHECKSUM_FAILURE: + return "EventStreamMessageChecksumFailure"; + case EventStreamErrors::EVENT_STREAM_MESSAGE_INVALID_HEADERS_LEN: + return "EventStreamMessageInvalidHeadersLen"; + case EventStreamErrors::EVENT_STREAM_MESSAGE_UNKNOWN_HEADER_TYPE: + return "EventStreamMessageUnknownHeaderType"; + case EventStreamErrors::EVENT_STREAM_MESSAGE_PARSER_ILLEGAL_STATE: + return "EventStreamMessageParserIllegalState"; + default: + return "EventStreamUnknownError"; + } + } + + AWSError<CoreErrors> GetAwsErrorForEventStreamError(EventStreamErrors error) + { + return AWSError<CoreErrors>(CoreErrors::UNKNOWN, GetNameForError(error), "", false); + } + } // namespace EventStreamErrorsMapper + } // namespace Event + } // namespace Utils +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/json/JsonSerializer.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/json/JsonSerializer.cpp new file mode 100644 index 0000000000..9b785d1995 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/json/JsonSerializer.cpp @@ -0,0 +1,665 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/json/JsonSerializer.h> + +#include <iterator> +#include <algorithm> +#include <aws/core/utils/memory/stl/AWSStringStream.h> +#include <aws/core/utils/StringUtils.h> + +using namespace Aws::Utils; +using namespace Aws::Utils::Json; + +JsonValue::JsonValue() : m_wasParseSuccessful(true) +{ + m_value = nullptr; +} + +JsonValue::JsonValue(cJSON* value) : + m_value(cJSON_Duplicate(value, true /* recurse */)), + m_wasParseSuccessful(true) +{ +} + +JsonValue::JsonValue(const Aws::String& value) : m_wasParseSuccessful(true) +{ + const char* return_parse_end; + m_value = cJSON_ParseWithOpts(value.c_str(), &return_parse_end, 1/*require_null_terminated*/); + + if (!m_value || cJSON_IsInvalid(m_value)) + { + m_wasParseSuccessful = false; + m_errorMessage = "Failed to parse JSON at: "; + m_errorMessage += return_parse_end; + } +} + +JsonValue::JsonValue(Aws::IStream& istream) : m_wasParseSuccessful(true) +{ + Aws::StringStream memoryStream; + std::copy(std::istreambuf_iterator<char>(istream), std::istreambuf_iterator<char>(), std::ostreambuf_iterator<char>(memoryStream)); + const char* return_parse_end; + const auto input = memoryStream.str(); + m_value = cJSON_ParseWithOpts(input.c_str(), &return_parse_end, 1/*require_null_terminated*/); + + if (!m_value || cJSON_IsInvalid(m_value)) + { + m_wasParseSuccessful = false; + m_errorMessage = "Failed to parse JSON. Invalid input at: "; + m_errorMessage += return_parse_end; + } +} + +JsonValue::JsonValue(const JsonValue& value) : + m_value(cJSON_Duplicate(value.m_value, true/*recurse*/)), + m_wasParseSuccessful(value.m_wasParseSuccessful), + m_errorMessage(value.m_errorMessage) +{ +} + +JsonValue::JsonValue(JsonValue&& value) : + m_value(value.m_value), + m_wasParseSuccessful(value.m_wasParseSuccessful), + m_errorMessage(std::move(value.m_errorMessage)) +{ + value.m_value = nullptr; +} + +void JsonValue::Destroy() +{ + cJSON_Delete(m_value); +} + +JsonValue::~JsonValue() +{ + Destroy(); +} + +JsonValue& JsonValue::operator=(const JsonValue& other) +{ + if (this == &other) + { + return *this; + } + + Destroy(); + m_value = cJSON_Duplicate(other.m_value, true /*recurse*/); + m_wasParseSuccessful = other.m_wasParseSuccessful; + m_errorMessage = other.m_errorMessage; + return *this; +} + +JsonValue& JsonValue::operator=(JsonValue&& other) +{ + if (this == &other) + { + return *this; + } + + using std::swap; + swap(m_value, other.m_value); + swap(m_errorMessage, other.m_errorMessage); + m_wasParseSuccessful = other.m_wasParseSuccessful; + return *this; +} + +static void AddOrReplace(cJSON* root, const char* key, cJSON* value) +{ + const auto existing = cJSON_GetObjectItemCaseSensitive(root, key); + if (existing) + { + cJSON_ReplaceItemInObjectCaseSensitive(root, key, value); + } + else + { + cJSON_AddItemToObject(root, key, value); + } +} + +JsonValue& JsonValue::WithString(const char* key, const Aws::String& value) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + const auto val = cJSON_CreateString(value.c_str()); + AddOrReplace(m_value, key, val); + return *this; +} + +JsonValue& JsonValue::WithString(const Aws::String& key, const Aws::String& value) +{ + return WithString(key.c_str(), value); +} + +JsonValue& JsonValue::AsString(const Aws::String& value) +{ + Destroy(); + m_value = cJSON_CreateString(value.c_str()); + return *this; +} + +JsonValue& JsonValue::WithBool(const char* key, bool value) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + const auto val = cJSON_CreateBool(value); + AddOrReplace(m_value, key, val); + return *this; +} + +JsonValue& JsonValue::WithBool(const Aws::String& key, bool value) +{ + return WithBool(key.c_str(), value); +} + +JsonValue& JsonValue::AsBool(bool value) +{ + Destroy(); + m_value = cJSON_CreateBool(value); + return *this; +} + +JsonValue& JsonValue::WithInteger(const char* key, int value) +{ + return WithDouble(key, static_cast<double>(value)); +} + +JsonValue& JsonValue::WithInteger(const Aws::String& key, int value) +{ + return WithDouble(key.c_str(), static_cast<double>(value)); +} + +JsonValue& JsonValue::AsInteger(int value) +{ + Destroy(); + m_value = cJSON_CreateNumber(static_cast<double>(value)); + return *this; +} + +JsonValue& JsonValue::WithInt64(const char* key, long long value) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + const auto val = cJSON_CreateInt64(value); + AddOrReplace(m_value, key, val); + return *this; +} + +JsonValue& JsonValue::WithInt64(const Aws::String& key, long long value) +{ + return WithInt64(key.c_str(), value); +} + +JsonValue& JsonValue::AsInt64(long long value) +{ + Destroy(); + m_value = cJSON_CreateInt64(value); + return *this; +} + +JsonValue& JsonValue::WithDouble(const char* key, double value) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + const auto val = cJSON_CreateNumber(value); + AddOrReplace(m_value, key, val); + return *this; +} + +JsonValue& JsonValue::WithDouble(const Aws::String& key, double value) +{ + return WithDouble(key.c_str(), value); +} + +JsonValue& JsonValue::AsDouble(double value) +{ + Destroy(); + m_value = cJSON_CreateNumber(value); + return *this; +} + +JsonValue& JsonValue::WithArray(const char* key, const Array<Aws::String>& array) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + auto arrayValue = cJSON_CreateArray(); + for (unsigned i = 0; i < array.GetLength(); ++i) + { + cJSON_AddItemToArray(arrayValue, cJSON_CreateString(array[i].c_str())); + } + + AddOrReplace(m_value, key, arrayValue); + return *this; +} + +JsonValue& JsonValue::WithArray(const Aws::String& key, const Array<Aws::String>& array) +{ + return WithArray(key.c_str(), array); +} + +JsonValue& JsonValue::WithArray(const Aws::String& key, const Array<JsonValue>& array) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + auto arrayValue = cJSON_CreateArray(); + for (unsigned i = 0; i < array.GetLength(); ++i) + { + cJSON_AddItemToArray(arrayValue, cJSON_Duplicate(array[i].m_value, true /*recurse*/)); + } + + AddOrReplace(m_value, key.c_str(), arrayValue); + return *this; +} + +JsonValue& JsonValue::WithArray(const Aws::String& key, Array<JsonValue>&& array) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + auto arrayValue = cJSON_CreateArray(); + for (unsigned i = 0; i < array.GetLength(); ++i) + { + cJSON_AddItemToArray(arrayValue, array[i].m_value); + array[i].m_value = nullptr; + } + + AddOrReplace(m_value, key.c_str(), arrayValue); + return *this; +} + +JsonValue& JsonValue::AsArray(const Array<JsonValue>& array) +{ + auto arrayValue = cJSON_CreateArray(); + for (unsigned i = 0; i < array.GetLength(); ++i) + { + cJSON_AddItemToArray(arrayValue, cJSON_Duplicate(array[i].m_value, true /*recurse*/)); + } + + Destroy(); + m_value = arrayValue; + return *this; +} + +JsonValue& JsonValue::AsArray(Array<JsonValue>&& array) +{ + auto arrayValue = cJSON_CreateArray(); + for (unsigned i = 0; i < array.GetLength(); ++i) + { + cJSON_AddItemToArray(arrayValue, array[i].m_value); + array[i].m_value = nullptr; + } + + Destroy(); + m_value = arrayValue; + return *this; +} + +JsonValue& JsonValue::WithObject(const char* key, const JsonValue& value) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + const auto copy = value.m_value == nullptr ? cJSON_CreateObject() : cJSON_Duplicate(value.m_value, true /*recurse*/); + AddOrReplace(m_value, key, copy); + return *this; +} + +JsonValue& JsonValue::WithObject(const Aws::String& key, const JsonValue& value) +{ + return WithObject(key.c_str(), value); +} + +JsonValue& JsonValue::WithObject(const char* key, JsonValue&& value) +{ + if (!m_value) + { + m_value = cJSON_CreateObject(); + } + + AddOrReplace(m_value, key, value.m_value == nullptr ? cJSON_CreateObject() : value.m_value); + value.m_value = nullptr; + return *this; +} + +JsonValue& JsonValue::WithObject(const Aws::String& key, JsonValue&& value) +{ + return WithObject(key.c_str(), std::move(value)); +} + +JsonValue& JsonValue::AsObject(const JsonValue& value) +{ + *this = value; + return *this; +} + +JsonValue& JsonValue::AsObject(JsonValue && value) +{ + *this = std::move(value); + return *this; +} + +bool JsonValue::operator==(const JsonValue& other) const +{ + return cJSON_Compare(m_value, other.m_value, true /*case-sensitive*/) != 0; +} + +bool JsonValue::operator!=(const JsonValue& other) const +{ + return !(*this == other); +} + +JsonView JsonValue::View() const +{ + return *this; +} + +JsonView::JsonView() : m_value(nullptr) +{ +} + +JsonView::JsonView(const JsonValue& val) : m_value(val.m_value) +{ +} + +JsonView::JsonView(cJSON* val) : m_value(val) +{ +} + +JsonView& JsonView::operator=(const JsonValue& v) +{ + m_value = v.m_value; + return *this; +} + +JsonView& JsonView::operator=(cJSON* val) +{ + m_value = val; + return *this; +} + +Aws::String JsonView::GetString(const Aws::String& key) const +{ + assert(m_value); + auto item = cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()); + auto str = cJSON_GetStringValue(item); + return str ? str : ""; +} + +Aws::String JsonView::AsString() const +{ + const char* str = cJSON_GetStringValue(m_value); + if (str == nullptr) + { + return {}; + } + return str; +} + +bool JsonView::GetBool(const Aws::String& key) const +{ + assert(m_value); + auto item = cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()); + assert(item); + return item->valueint != 0; +} + +bool JsonView::AsBool() const +{ + assert(cJSON_IsBool(m_value)); + return cJSON_IsTrue(m_value) != 0; +} + +int JsonView::GetInteger(const Aws::String& key) const +{ + assert(m_value); + auto item = cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()); + assert(item); + return item->valueint; +} + +int JsonView::AsInteger() const +{ + assert(cJSON_IsNumber(m_value)); // can be double or value larger than int_max, but at least not UB + return m_value->valueint; +} + +int64_t JsonView::GetInt64(const Aws::String& key) const +{ + assert(m_value); + auto item = cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()); + assert(item); + if (item->valuestring) + { + return Aws::Utils::StringUtils::ConvertToInt64(item->valuestring); + } + else + { + return static_cast<int64_t>(item->valuedouble); + } +} + +int64_t JsonView::AsInt64() const +{ + assert(cJSON_IsNumber(m_value)); + if (m_value->valuestring) + { + return Aws::Utils::StringUtils::ConvertToInt64(m_value->valuestring); + } + else + { + return static_cast<int64_t>(m_value->valuedouble); + } +} + +double JsonView::GetDouble(const Aws::String& key) const +{ + assert(m_value); + auto item = cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()); + assert(item); + return item->valuedouble; +} + +double JsonView::AsDouble() const +{ + assert(cJSON_IsNumber(m_value)); + return m_value->valuedouble; +} + +JsonView JsonView::GetObject(const Aws::String& key) const +{ + assert(m_value); + auto item = cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()); + return item; +} + +JsonView JsonView::AsObject() const +{ + assert(cJSON_IsObject(m_value) || cJSON_IsNull(m_value)); + return m_value; +} + +Array<JsonView> JsonView::GetArray(const Aws::String& key) const +{ + assert(m_value); + auto array = cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()); + assert(cJSON_IsArray(array)); + Array<JsonView> returnArray(cJSON_GetArraySize(array)); + + auto element = array->child; + for (unsigned i = 0; element && i < returnArray.GetLength(); ++i, element = element->next) + { + returnArray[i] = element; + } + + return returnArray; +} + +Array<JsonView> JsonView::AsArray() const +{ + assert(cJSON_IsArray(m_value)); + Array<JsonView> returnArray(cJSON_GetArraySize(m_value)); + + auto element = m_value->child; + + for (unsigned i = 0; element && i < returnArray.GetLength(); ++i, element = element->next) + { + returnArray[i] = element; + } + + return returnArray; +} + +Aws::Map<Aws::String, JsonView> JsonView::GetAllObjects() const +{ + Aws::Map<Aws::String, JsonView> valueMap; + if (!m_value) + { + return valueMap; + } + + for (auto iter = m_value->child; iter; iter = iter->next) + { + valueMap.emplace(std::make_pair(Aws::String(iter->string), JsonView(iter))); + } + + return valueMap; +} + +bool JsonView::ValueExists(const Aws::String& key) const +{ + if (!cJSON_IsObject(m_value)) + { + return false; + } + + auto item = cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()); + return !(item == nullptr || cJSON_IsNull(item)); +} + +bool JsonView::KeyExists(const Aws::String& key) const +{ + if (!cJSON_IsObject(m_value)) + { + return false; + } + + return cJSON_GetObjectItemCaseSensitive(m_value, key.c_str()) != nullptr;; +} + +bool JsonView::IsObject() const +{ + return cJSON_IsObject(m_value) != 0; +} + +bool JsonView::IsBool() const +{ + return cJSON_IsBool(m_value) != 0; +} + +bool JsonView::IsString() const +{ + return cJSON_IsString(m_value) != 0; +} + +bool JsonView::IsIntegerType() const +{ + if (!cJSON_IsNumber(m_value)) + { + return false; + } + + if (m_value->valuestring) + { + Aws::String valueString = m_value->valuestring; + return std::all_of(valueString.begin(), valueString.end(), [](unsigned char c){ return ::isdigit(c) || c == '+' || c == '-'; }); + } + return m_value->valuedouble == static_cast<long long>(m_value->valuedouble); +} + +bool JsonView::IsFloatingPointType() const +{ + if (!cJSON_IsNumber(m_value)) + { + return false; + } + + if (m_value->valuestring) + { + Aws::String valueString = m_value->valuestring; + return std::any_of(valueString.begin(), valueString.end(), [](unsigned char c){ return !::isdigit(c) && c != '+' && c != '-'; }); + } + return m_value->valuedouble != static_cast<long long>(m_value->valuedouble); +} + +bool JsonView::IsListType() const +{ + return cJSON_IsArray(m_value) != 0; +} + +bool JsonView::IsNull() const +{ + return cJSON_IsNull(m_value) != 0; +} + +Aws::String JsonView::WriteCompact(bool treatAsObject) const +{ + if (!m_value) + { + if (treatAsObject) + { + return "{}"; + } + return {}; + } + + auto temp = cJSON_PrintUnformatted(m_value); + Aws::String out(temp); + cJSON_free(temp); + return out; +} + +Aws::String JsonView::WriteReadable(bool treatAsObject) const +{ + if (!m_value) + { + if (treatAsObject) + { + return "{\n}\n"; + } + return {}; + } + + auto temp = cJSON_Print(m_value); + Aws::String out(temp); + cJSON_free(temp); + return out; +} + +JsonValue JsonView::Materialize() const +{ + return m_value; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/AWSLogging.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/AWSLogging.cpp new file mode 100644 index 0000000000..fc1b9fcc2e --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/AWSLogging.cpp @@ -0,0 +1,51 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/logging/AWSLogging.h> +#include <aws/core/utils/logging/LogSystemInterface.h> +#include <aws/core/utils/memory/stl/AWSStack.h> + +#include <memory> + +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; + +static std::shared_ptr<LogSystemInterface> AWSLogSystem(nullptr); +static std::shared_ptr<LogSystemInterface> OldLogger(nullptr); + +namespace Aws +{ +namespace Utils +{ +namespace Logging { + +void InitializeAWSLogging(const std::shared_ptr<LogSystemInterface> &logSystem) { + AWSLogSystem = logSystem; +} + +void ShutdownAWSLogging(void) { + InitializeAWSLogging(nullptr); +} + +LogSystemInterface *GetLogSystem() { + return AWSLogSystem.get(); +} + +void PushLogger(const std::shared_ptr<LogSystemInterface> &logSystem) +{ + OldLogger = AWSLogSystem; + AWSLogSystem = logSystem; +} + +void PopLogger() +{ + AWSLogSystem = OldLogger; + OldLogger = nullptr; +} + +} // namespace Logging +} // namespace Utils +} // namespace Aws
\ No newline at end of file diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/ConsoleLogSystem.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/ConsoleLogSystem.cpp new file mode 100644 index 0000000000..dec7cef82f --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/ConsoleLogSystem.cpp @@ -0,0 +1,22 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/logging/ConsoleLogSystem.h> + +#include <iostream> + +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; + +void ConsoleLogSystem::ProcessFormattedStatement(Aws::String&& statement) +{ + std::cout << statement; +} + +void ConsoleLogSystem::Flush() +{ + std::cout.flush(); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/DefaultLogSystem.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/DefaultLogSystem.cpp new file mode 100644 index 0000000000..7286bb6378 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/DefaultLogSystem.cpp @@ -0,0 +1,117 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/logging/DefaultLogSystem.h> + +#include <aws/core/utils/DateTime.h> +#include <aws/core/utils/memory/stl/AWSVector.h> + +#include <fstream> + +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; + +static const char* AllocationTag = "DefaultLogSystem"; +static const int BUFFERED_MSG_COUNT = 100; + +static std::shared_ptr<Aws::OFStream> MakeDefaultLogFile(const Aws::String& filenamePrefix) +{ + Aws::String newFileName = filenamePrefix + DateTime::CalculateGmtTimestampAsString("%Y-%m-%d-%H") + ".log"; + return Aws::MakeShared<Aws::OFStream>(AllocationTag, newFileName.c_str(), Aws::OFStream::out | Aws::OFStream::app); +} + +static void LogThread(DefaultLogSystem::LogSynchronizationData* syncData, const std::shared_ptr<Aws::OStream>& logFile, const Aws::String& filenamePrefix, bool rollLog) +{ + // localtime requires access to env. variables to get Timezone, which is not thread-safe + int32_t lastRolledHour = DateTime::Now().GetHour(false /*localtime*/); + std::shared_ptr<Aws::OStream> log = logFile; + + for(;;) + { + std::unique_lock<std::mutex> locker(syncData->m_logQueueMutex); + syncData->m_queueSignal.wait(locker, [&](){ return syncData->m_stopLogging == true || syncData->m_queuedLogMessages.size() > 0; } ); + + if (syncData->m_stopLogging && syncData->m_queuedLogMessages.size() == 0) + { + break; + } + + Aws::Vector<Aws::String> messages(std::move(syncData->m_queuedLogMessages)); + syncData->m_queuedLogMessages.reserve(BUFFERED_MSG_COUNT); + + locker.unlock(); + + if (messages.size() > 0) + { + if (rollLog) + { + // localtime requires access to env. variables to get Timezone, which is not thread-safe + int32_t currentHour = DateTime::Now().GetHour(false /*localtime*/); + if (currentHour != lastRolledHour) + { + log = MakeDefaultLogFile(filenamePrefix); + lastRolledHour = currentHour; + } + } + + for (const auto& msg : messages) + { + (*log) << msg; + } + + log->flush(); + } + } +} + +DefaultLogSystem::DefaultLogSystem(LogLevel logLevel, const std::shared_ptr<Aws::OStream>& logFile) : + Base(logLevel), + m_syncData(), + m_loggingThread() +{ + m_loggingThread = std::thread(LogThread, &m_syncData, logFile, "", false); +} + +DefaultLogSystem::DefaultLogSystem(LogLevel logLevel, const Aws::String& filenamePrefix) : + Base(logLevel), + m_syncData(), + m_loggingThread() +{ + m_loggingThread = std::thread(LogThread, &m_syncData, MakeDefaultLogFile(filenamePrefix), filenamePrefix, true); +} + +DefaultLogSystem::~DefaultLogSystem() +{ + { + std::lock_guard<std::mutex> locker(m_syncData.m_logQueueMutex); + m_syncData.m_stopLogging = true; + } + + m_syncData.m_queueSignal.notify_one(); + + m_loggingThread.join(); +} + +void DefaultLogSystem::ProcessFormattedStatement(Aws::String&& statement) +{ + std::unique_lock<std::mutex> locker(m_syncData.m_logQueueMutex); + m_syncData.m_queuedLogMessages.emplace_back(std::move(statement)); + if(m_syncData.m_queuedLogMessages.size() >= BUFFERED_MSG_COUNT) + { + locker.unlock(); + m_syncData.m_queueSignal.notify_one(); + } + else + { + locker.unlock(); + } +} + +void DefaultLogSystem::Flush() +{ + m_syncData.m_queueSignal.notify_one(); +} + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/FormattedLogSystem.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/FormattedLogSystem.cpp new file mode 100644 index 0000000000..41c4d7e09c --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/FormattedLogSystem.cpp @@ -0,0 +1,99 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + + +#include <aws/core/utils/logging/FormattedLogSystem.h> + +#include <aws/core/utils/DateTime.h> +#include <aws/core/utils/Array.h> + +#include <fstream> +#include <cstdarg> +#include <stdio.h> +#include <thread> + +using namespace Aws::Utils; +using namespace Aws::Utils::Logging; + +static Aws::String CreateLogPrefixLine(LogLevel logLevel, const char* tag) +{ + Aws::StringStream ss; + + switch(logLevel) + { + case LogLevel::Error: + ss << "[ERROR] "; + break; + + case LogLevel::Fatal: + ss << "[FATAL] "; + break; + + case LogLevel::Warn: + ss << "[WARN] "; + break; + + case LogLevel::Info: + ss << "[INFO] "; + break; + + case LogLevel::Debug: + ss << "[DEBUG] "; + break; + + case LogLevel::Trace: + ss << "[TRACE] "; + break; + + default: + ss << "[UNKOWN] "; + break; + } + + ss << DateTime::Now().CalculateGmtTimeWithMsPrecision() << " " << tag << " [" << std::this_thread::get_id() << "] "; + + return ss.str(); +} + +FormattedLogSystem::FormattedLogSystem(LogLevel logLevel) : + m_logLevel(logLevel) +{ +} + +void FormattedLogSystem::Log(LogLevel logLevel, const char* tag, const char* formatStr, ...) +{ + Aws::StringStream ss; + ss << CreateLogPrefixLine(logLevel, tag); + + std::va_list args; + va_start(args, formatStr); + + va_list tmp_args; //unfortunately you cannot consume a va_list twice + va_copy(tmp_args, args); //so we have to copy it + #ifdef WIN32 + const int requiredLength = _vscprintf(formatStr, tmp_args) + 1; + #else + const int requiredLength = vsnprintf(nullptr, 0, formatStr, tmp_args) + 1; + #endif + va_end(tmp_args); + + Array<char> outputBuff(requiredLength); + #ifdef WIN32 + vsnprintf_s(outputBuff.GetUnderlyingData(), requiredLength, _TRUNCATE, formatStr, args); + #else + vsnprintf(outputBuff.GetUnderlyingData(), requiredLength, formatStr, args); + #endif // WIN32 + + ss << outputBuff.GetUnderlyingData() << std::endl; + + ProcessFormattedStatement(ss.str()); + + va_end(args); +} + +void FormattedLogSystem::LogStream(LogLevel logLevel, const char* tag, const Aws::OStringStream &message_stream) +{ + ProcessFormattedStatement(CreateLogPrefixLine(logLevel, tag) + message_stream.str() + "\n"); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/LogLevel.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/LogLevel.cpp new file mode 100644 index 0000000000..9ff1bf3126 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/logging/LogLevel.cpp @@ -0,0 +1,45 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/logging/LogLevel.h> + +#include <aws/core/utils/memory/stl/AWSMap.h> +#include <aws/core/utils/memory/stl/AWSString.h> +#include <cassert> + +using namespace Aws::Utils::Logging; + +namespace Aws +{ +namespace Utils +{ +namespace Logging +{ + +Aws::String GetLogLevelName(LogLevel logLevel) +{ + switch (logLevel) + { + case LogLevel::Fatal: + return "FATAL"; + case LogLevel::Error: + return "ERROR"; + case LogLevel::Warn: + return "WARN"; + case LogLevel::Info: + return "INFO"; + case LogLevel::Debug: + return "DEBUG"; + case LogLevel::Trace: + return "TRACE"; + default: + assert(0); + return ""; + } +} + +} // namespace Logging +} // namespace Utils +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/memory/AWSMemory.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/memory/AWSMemory.cpp new file mode 100644 index 0000000000..96d339d385 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/memory/AWSMemory.cpp @@ -0,0 +1,134 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/memory/AWSMemory.h> + +#include <aws/core/utils/memory/MemorySystemInterface.h> +#include <aws/common/common.h> + +#include <atomic> + +using namespace Aws::Utils; +using namespace Aws::Utils::Memory; + +#ifdef USE_AWS_MEMORY_MANAGEMENT + static MemorySystemInterface* AWSMemorySystem(nullptr); +#endif // USE_AWS_MEMORY_MANAGEMENT + +namespace Aws +{ +namespace Utils +{ +namespace Memory +{ + +void InitializeAWSMemorySystem(MemorySystemInterface& memorySystem) +{ + #ifdef USE_AWS_MEMORY_MANAGEMENT + if(AWSMemorySystem != nullptr) + { + AWSMemorySystem->End(); + } + + AWSMemorySystem = &memorySystem; + AWSMemorySystem->Begin(); + #else + AWS_UNREFERENCED_PARAM(memorySystem); + #endif // USE_AWS_MEMORY_MANAGEMENT +} + +void ShutdownAWSMemorySystem(void) +{ + #ifdef USE_AWS_MEMORY_MANAGEMENT + if(AWSMemorySystem != nullptr) + { + AWSMemorySystem->End(); + } + AWSMemorySystem = nullptr; + #endif // USE_AWS_MEMORY_MANAGEMENT +} + +MemorySystemInterface* GetMemorySystem() +{ + #ifdef USE_AWS_MEMORY_MANAGEMENT + return AWSMemorySystem; + #else + return nullptr; + #endif // USE_AWS_MEMORY_MANAGEMENT +} + +} // namespace Memory +} // namespace Utils + +void* Malloc(const char* allocationTag, size_t allocationSize) +{ + Aws::Utils::Memory::MemorySystemInterface* memorySystem = Aws::Utils::Memory::GetMemorySystem(); + + void* rawMemory = nullptr; + if(memorySystem != nullptr) + { + rawMemory = memorySystem->AllocateMemory(allocationSize, 1, allocationTag); + } + else + { + rawMemory = malloc(allocationSize); + } + + return rawMemory; +} + + +void Free(void* memoryPtr) +{ + if(memoryPtr == nullptr) + { + return; + } + + Aws::Utils::Memory::MemorySystemInterface* memorySystem = Aws::Utils::Memory::GetMemorySystem(); + if(memorySystem != nullptr) + { + memorySystem->FreeMemory(memoryPtr); + } + else + { + free(memoryPtr); + } +} + +static void* MemAcquire(aws_allocator* allocator, size_t size) +{ + (void)allocator; // unused; + return Aws::Malloc("MemAcquire", size); +} + +static void MemRelease(aws_allocator* allocator, void* ptr) +{ + (void)allocator; // unused; + return Aws::Free(ptr); +} + +static aws_allocator create_aws_allocator() +{ +#if (__GNUC__ == 4) && !defined(__clang__) + AWS_SUPPRESS_WARNING("-Wmissing-field-initializers", aws_allocator wrapper{};); +#else + aws_allocator wrapper{}; +#endif + wrapper.mem_acquire = MemAcquire; + wrapper.mem_release = MemRelease; + wrapper.mem_realloc = nullptr; + return wrapper; +} + +aws_allocator* get_aws_allocator() +{ + static aws_allocator wrapper = create_aws_allocator(); + return &wrapper; +} + +} // namespace Aws + + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/memory/stl/SimpleStringStream.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/memory/stl/SimpleStringStream.cpp new file mode 100644 index 0000000000..4662749872 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/memory/stl/SimpleStringStream.cpp @@ -0,0 +1,66 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/memory/stl/SimpleStringStream.h> + +namespace Aws +{ + +SimpleStringStream::SimpleStringStream() : + base(&m_streamBuffer), + m_streamBuffer() +{ +} + +SimpleStringStream::SimpleStringStream(const Aws::String& value) : + base(&m_streamBuffer), + m_streamBuffer(value) +{ +} + +void SimpleStringStream::str(const Aws::String& value) +{ + m_streamBuffer.str(value); +} + +// + +SimpleIStringStream::SimpleIStringStream() : + base(&m_streamBuffer), + m_streamBuffer() +{ +} + +SimpleIStringStream::SimpleIStringStream(const Aws::String& value) : + base(&m_streamBuffer), + m_streamBuffer(value) +{ +} + +void SimpleIStringStream::str(const Aws::String& value) +{ + m_streamBuffer.str(value); +} + +// + +SimpleOStringStream::SimpleOStringStream() : + base(&m_streamBuffer), + m_streamBuffer() +{ +} + +SimpleOStringStream::SimpleOStringStream(const Aws::String& value) : + base(&m_streamBuffer), + m_streamBuffer(value) +{ +} + +void SimpleOStringStream::str(const Aws::String& value) +{ + m_streamBuffer.str(value); +} + +} // namespace Aws diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/ConcurrentStreamBuf.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/ConcurrentStreamBuf.cpp new file mode 100644 index 0000000000..3f59dbe96d --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/ConcurrentStreamBuf.cpp @@ -0,0 +1,126 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include <aws/core/utils/stream/ConcurrentStreamBuf.h> +#include <aws/core/utils/logging/LogMacros.h> +#include <cstdint> +#include <cassert> + +namespace Aws +{ + namespace Utils + { + namespace Stream + { + const char TAG[] = "ConcurrentStreamBuf"; + ConcurrentStreamBuf::ConcurrentStreamBuf(size_t bufferLength) : + m_putArea(bufferLength), // we access [0] of the put area below so we must initialize it. + m_eof(false) + { + m_getArea.reserve(bufferLength); + m_backbuf.reserve(bufferLength); + + char* pbegin = reinterpret_cast<char*>(&m_putArea[0]); + setp(pbegin, pbegin + bufferLength); + } + + void ConcurrentStreamBuf::SetEof() + { + { + std::unique_lock<std::mutex> lock(m_lock); + m_eof = true; + } + m_signal.notify_all(); + } + + void ConcurrentStreamBuf::FlushPutArea() + { + const size_t bitslen = pptr() - pbase(); + if (bitslen) + { + // scope the lock + { + std::unique_lock<std::mutex> lock(m_lock); + m_signal.wait(lock, [this, bitslen]{ return m_eof || bitslen <= (m_backbuf.capacity() - m_backbuf.size()); }); + if (m_eof) + { + return; + } + std::copy(pbase(), pptr(), std::back_inserter(m_backbuf)); + } + m_signal.notify_one(); + char* pbegin = reinterpret_cast<char*>(&m_putArea[0]); + setp(pbegin, pbegin + m_putArea.size()); + } + } + + std::streampos ConcurrentStreamBuf::seekoff(std::streamoff, std::ios_base::seekdir, std::ios_base::openmode) + { + return std::streamoff(-1); // Seeking is not supported. + } + + std::streampos ConcurrentStreamBuf::seekpos(std::streampos, std::ios_base::openmode) + { + return std::streamoff(-1); // Seeking is not supported. + } + + int ConcurrentStreamBuf::underflow() + { + { + std::unique_lock<std::mutex> lock(m_lock); + m_signal.wait(lock, [this]{ return m_backbuf.empty() == false || m_eof; }); + + if (m_eof && m_backbuf.empty()) + { + return std::char_traits<char>::eof(); + } + + m_getArea.clear(); // keep the get-area from growing unbounded. + std::copy(m_backbuf.begin(), m_backbuf.end(), std::back_inserter(m_getArea)); + m_backbuf.clear(); + } + m_signal.notify_one(); + char* gbegin = reinterpret_cast<char*>(&m_getArea[0]); + setg(gbegin, gbegin, gbegin + m_getArea.size()); + return std::char_traits<char>::to_int_type(*gptr()); + } + + std::streamsize ConcurrentStreamBuf::showmanyc() + { + std::unique_lock<std::mutex> lock(m_lock); + AWS_LOGSTREAM_TRACE(TAG, "stream how many character? " << m_backbuf.size()); + return m_backbuf.size(); + } + + int ConcurrentStreamBuf::overflow(int ch) + { + const auto eof = std::char_traits<char>::eof(); + + if (ch == eof) + { + FlushPutArea(); + return eof; + } + + FlushPutArea(); + { + std::unique_lock<std::mutex> lock(m_lock); + if (m_eof) + { + return eof; + } + *pptr() = static_cast<char>(ch); + pbump(1); + return ch; + } + } + + int ConcurrentStreamBuf::sync() + { + FlushPutArea(); + return 0; + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/PreallocatedStreamBuf.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/PreallocatedStreamBuf.cpp new file mode 100644 index 0000000000..f656fc8613 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/PreallocatedStreamBuf.cpp @@ -0,0 +1,75 @@ + +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/stream/PreallocatedStreamBuf.h> +#include <cassert> + +namespace Aws +{ + namespace Utils + { + namespace Stream + { + PreallocatedStreamBuf::PreallocatedStreamBuf(unsigned char* buffer, uint64_t lengthToRead) : + m_underlyingBuffer(buffer), m_lengthToRead(lengthToRead) + { + char* end = reinterpret_cast<char*>(m_underlyingBuffer + m_lengthToRead); + char* begin = reinterpret_cast<char*>(m_underlyingBuffer); + setp(begin, end); + setg(begin, begin, end); + } + + PreallocatedStreamBuf::pos_type PreallocatedStreamBuf::seekoff(off_type off, std::ios_base::seekdir dir, std::ios_base::openmode which) + { + if (dir == std::ios_base::beg) + { + return seekpos(off, which); + } + else if (dir == std::ios_base::end) + { + return seekpos(m_lengthToRead - off, which); + } + else if (dir == std::ios_base::cur) + { + if(which == std::ios_base::in) + { + return seekpos((gptr() - reinterpret_cast<char*>(m_underlyingBuffer)) + off, which); + } + else + { + return seekpos((pptr() - reinterpret_cast<char*>(m_underlyingBuffer)) + off, which); + } + } + + return off_type(-1); + } + + PreallocatedStreamBuf::pos_type PreallocatedStreamBuf::seekpos(pos_type pos, std::ios_base::openmode which) + { + assert(static_cast<size_t>(pos) <= m_lengthToRead); + if (static_cast<size_t>(pos) > m_lengthToRead) + { + return pos_type(off_type(-1)); + } + + char* end = reinterpret_cast<char*>(m_underlyingBuffer + m_lengthToRead); + char* begin = reinterpret_cast<char*>(m_underlyingBuffer); + + if (which == std::ios_base::in) + { + setg(begin, begin + static_cast<size_t>(pos), end); + } + + if (which == std::ios_base::out) + { + setp(begin + static_cast<size_t>(pos), end); + } + + return pos; + } + } + } +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/ResponseStream.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/ResponseStream.cpp new file mode 100644 index 0000000000..6d1f90ed12 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/ResponseStream.cpp @@ -0,0 +1,91 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/stream/ResponseStream.h> +#include <aws/core/utils/memory/stl/AWSStringStream.h> + +#if defined(_GLIBCXX_FULLY_DYNAMIC_STRING) && _GLIBCXX_FULLY_DYNAMIC_STRING == 0 && defined(__ANDROID__) +#include <aws/core/utils/stream/SimpleStreamBuf.h> +using DefaultStreamBufType = Aws::Utils::Stream::SimpleStreamBuf; +#else +using DefaultStreamBufType = Aws::StringBuf; +#endif + +using namespace Aws::Utils::Stream; + +ResponseStream::ResponseStream(void) : + m_underlyingStream(nullptr) +{ +} + +ResponseStream::ResponseStream(Aws::IOStream* underlyingStreamToManage) : + m_underlyingStream(underlyingStreamToManage) +{ +} + +ResponseStream::ResponseStream(const Aws::IOStreamFactory& factory) : + m_underlyingStream(factory()) +{ +} + +ResponseStream::ResponseStream(ResponseStream&& toMove) : m_underlyingStream(toMove.m_underlyingStream) +{ + toMove.m_underlyingStream = nullptr; +} + +ResponseStream& ResponseStream::operator=(ResponseStream&& toMove) +{ + if(m_underlyingStream == toMove.m_underlyingStream) + { + return *this; + } + + ReleaseStream(); + m_underlyingStream = toMove.m_underlyingStream; + toMove.m_underlyingStream = nullptr; + + return *this; +} + +ResponseStream::~ResponseStream() +{ + ReleaseStream(); +} + +void ResponseStream::ReleaseStream() +{ + if (m_underlyingStream) + { + m_underlyingStream->flush(); + Aws::Delete(m_underlyingStream); + } + + m_underlyingStream = nullptr; +} + +static const char *DEFAULT_STREAM_TAG = "DefaultUnderlyingStream"; + +DefaultUnderlyingStream::DefaultUnderlyingStream() : + Base( Aws::New< DefaultStreamBufType >( DEFAULT_STREAM_TAG ) ) +{} + +DefaultUnderlyingStream::DefaultUnderlyingStream(Aws::UniquePtr<std::streambuf> buf) : + Base(buf.release()) +{} + +DefaultUnderlyingStream::~DefaultUnderlyingStream() +{ + if( rdbuf() ) + { + Aws::Delete( rdbuf() ); + } +} + +static const char* RESPONSE_STREAM_FACTORY_TAG = "ResponseStreamFactory"; + +Aws::IOStream* Aws::Utils::Stream::DefaultResponseStreamFactoryMethod() +{ + return Aws::New<Aws::Utils::Stream::DefaultUnderlyingStream>(RESPONSE_STREAM_FACTORY_TAG); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/SimpleStreamBuf.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/SimpleStreamBuf.cpp new file mode 100644 index 0000000000..6e42994744 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/stream/SimpleStreamBuf.cpp @@ -0,0 +1,239 @@ + +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/stream/SimpleStreamBuf.h> + +#include <algorithm> +#include <cassert> +#include <cstring> + +namespace Aws +{ +namespace Utils +{ +namespace Stream +{ + +static const uint32_t DEFAULT_BUFFER_SIZE = 100; +static const char* SIMPLE_STREAMBUF_ALLOCATION_TAG = "SimpleStreamBufTag"; + +SimpleStreamBuf::SimpleStreamBuf() : + m_buffer(nullptr), + m_bufferSize(0) +{ + m_buffer = Aws::NewArray<char>(DEFAULT_BUFFER_SIZE, SIMPLE_STREAMBUF_ALLOCATION_TAG); + m_bufferSize = DEFAULT_BUFFER_SIZE; + + char* begin = m_buffer; + char* end = begin + m_bufferSize; + + setp(begin, end); + setg(begin, begin, begin); +} + +SimpleStreamBuf::SimpleStreamBuf(const Aws::String& value) : + m_buffer(nullptr), + m_bufferSize(0) +{ + size_t baseSize = (std::max)(value.size(), static_cast<std::size_t>(DEFAULT_BUFFER_SIZE)); + + m_buffer = Aws::NewArray<char>(baseSize, SIMPLE_STREAMBUF_ALLOCATION_TAG); + m_bufferSize = baseSize; + + std::memcpy(m_buffer, value.c_str(), value.size()); + + char* begin = m_buffer; + char* end = begin + m_bufferSize; + + setp(begin + value.size(), end); + setg(begin, begin, begin); +} + +SimpleStreamBuf::~SimpleStreamBuf() +{ + if(m_buffer) + { + Aws::DeleteArray<char>(m_buffer); + m_buffer = nullptr; + } + + m_bufferSize = 0; +} + +std::streampos SimpleStreamBuf::seekoff(std::streamoff off, std::ios_base::seekdir dir, std::ios_base::openmode which) +{ + if (dir == std::ios_base::beg) + { + return seekpos(off, which); + } + else if (dir == std::ios_base::end) + { + return seekpos((pptr() - m_buffer) - off, which); + } + else if (dir == std::ios_base::cur) + { + if(which == std::ios_base::in) + { + return seekpos((gptr() - m_buffer) + off, which); + } + else + { + return seekpos((pptr() - m_buffer) + off, which); + } + } + + return off_type(-1); +} + +std::streampos SimpleStreamBuf::seekpos(std::streampos pos, std::ios_base::openmode which) +{ + size_t maxSeek = pptr() - m_buffer; + assert(static_cast<size_t>(pos) <= maxSeek); + if (static_cast<size_t>(pos) > maxSeek) + { + return pos_type(off_type(-1)); + } + + if (which == std::ios_base::in) + { + setg(m_buffer, m_buffer + static_cast<size_t>(pos), pptr()); + } + + if (which == std::ios_base::out) + { + setp(m_buffer + static_cast<size_t>(pos), epptr()); + } + + return pos; +} + +bool SimpleStreamBuf::GrowBuffer() +{ + size_t currentSize = m_bufferSize; + size_t newSize = currentSize * 2; + + char* newBuffer = Aws::NewArray<char>(newSize, SIMPLE_STREAMBUF_ALLOCATION_TAG); + if(newBuffer == nullptr) + { + return false; + } + + if(currentSize > 0) + { + std::memcpy(newBuffer, m_buffer, currentSize); + } + + if(m_buffer) + { + Aws::DeleteArray<char>(m_buffer); + } + + m_buffer = newBuffer; + m_bufferSize = newSize; + + return true; +} + +int SimpleStreamBuf::overflow (int c) +{ + auto endOfFile = std::char_traits< char >::eof(); + if(c == endOfFile) + { + return endOfFile; + } + + char* old_begin = m_buffer; + + char *old_pptr = pptr(); + char *old_gptr = gptr(); + char *old_egptr = egptr(); + + size_t currentWritePosition = m_bufferSize; + + if(!GrowBuffer()) + { + return endOfFile; + } + + char* new_begin = m_buffer; + char* new_end = new_begin + m_bufferSize; + + setp(new_begin + (old_pptr - old_begin) + 1, new_end); + setg(new_begin, new_begin + (old_gptr - old_begin), new_begin + (old_egptr - old_begin)); + + auto val = std::char_traits< char >::to_char_type(c); + *(new_begin + currentWritePosition) = val; + + return c; +} + +std::streamsize SimpleStreamBuf::xsputn(const char* s, std::streamsize n) +{ + std::streamsize writeCount = 0; + while(writeCount < n) + { + char* current_pptr = pptr(); + char* current_epptr = epptr(); + + if (current_pptr < current_epptr) + { + std::size_t copySize = (std::min)(static_cast< std::size_t >(n - writeCount), + static_cast< std::size_t >(current_epptr - current_pptr)); + + std::memcpy(current_pptr, s + writeCount, copySize); + writeCount += copySize; + setp(current_pptr + copySize, current_epptr); + setg(m_buffer, gptr(), pptr()); + } + else if (overflow(std::char_traits< char >::to_int_type(*(s + writeCount))) != std::char_traits<char>::eof()) + { + writeCount++; + } + else + { + return writeCount; + } + } + + return writeCount; +} + +Aws::String SimpleStreamBuf::str() const +{ + return Aws::String(m_buffer, pptr()); +} + +int SimpleStreamBuf::underflow() +{ + if(egptr() != pptr()) + { + setg(m_buffer, gptr(), pptr()); + } + + if(gptr() != egptr()) + { + return std::char_traits< char >::to_int_type(*gptr()); + } + else + { + return std::char_traits< char >::eof(); + } +} + +void SimpleStreamBuf::str(const Aws::String& value) +{ + char* begin = m_buffer; + char* end = begin + m_bufferSize; + + setp(begin, end); + setg(begin, begin, begin); + + xsputn(value.c_str(), value.size()); +} + +} +} +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/Executor.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/Executor.cpp new file mode 100644 index 0000000000..4a3c4209c4 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/Executor.cpp @@ -0,0 +1,155 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/threading/Executor.h> +#include <aws/core/utils/threading/ThreadTask.h> +#include <thread> +#include <cassert> + +static const char* POOLED_CLASS_TAG = "PooledThreadExecutor"; + +using namespace Aws::Utils::Threading; + +bool DefaultExecutor::SubmitToThread(std::function<void()>&& fx) +{ + auto main = [fx, this] { + fx(); + Detach(std::this_thread::get_id()); + }; + + State expected; + do + { + expected = State::Free; + if(m_state.compare_exchange_strong(expected, State::Locked)) + { + std::thread t(main); + const auto id = t.get_id(); // copy the id before we std::move the thread + m_threads.emplace(id, std::move(t)); + m_state = State::Free; + return true; + } + } + while(expected != State::Shutdown); + return false; +} + +void DefaultExecutor::Detach(std::thread::id id) +{ + State expected; + do + { + expected = State::Free; + if(m_state.compare_exchange_strong(expected, State::Locked)) + { + auto it = m_threads.find(id); + assert(it != m_threads.end()); + it->second.detach(); + m_threads.erase(it); + m_state = State::Free; + return; + } + } + while(expected != State::Shutdown); +} + +DefaultExecutor::~DefaultExecutor() +{ + auto expected = State::Free; + while(!m_state.compare_exchange_strong(expected, State::Shutdown)) + { + //spin while currently detaching threads finish + assert(expected == State::Locked); + expected = State::Free; + } + + auto it = m_threads.begin(); + while(!m_threads.empty()) + { + it->second.join(); + it = m_threads.erase(it); + } +} + +PooledThreadExecutor::PooledThreadExecutor(size_t poolSize, OverflowPolicy overflowPolicy) : + m_sync(0, poolSize), m_poolSize(poolSize), m_overflowPolicy(overflowPolicy) +{ + for (size_t index = 0; index < m_poolSize; ++index) + { + m_threadTaskHandles.push_back(Aws::New<ThreadTask>(POOLED_CLASS_TAG, *this)); + } +} + +PooledThreadExecutor::~PooledThreadExecutor() +{ + for(auto threadTask : m_threadTaskHandles) + { + threadTask->StopProcessingWork(); + } + + m_sync.ReleaseAll(); + + for (auto threadTask : m_threadTaskHandles) + { + Aws::Delete(threadTask); + } + + while(m_tasks.size() > 0) + { + std::function<void()>* fn = m_tasks.front(); + m_tasks.pop(); + + if(fn) + { + Aws::Delete(fn); + } + } + +} + +bool PooledThreadExecutor::SubmitToThread(std::function<void()>&& fn) +{ + //avoid the need to do copies inside the lock. Instead lets do a pointer push. + std::function<void()>* fnCpy = Aws::New<std::function<void()>>(POOLED_CLASS_TAG, std::forward<std::function<void()>>(fn)); + + { + std::lock_guard<std::mutex> locker(m_queueLock); + + if (m_overflowPolicy == OverflowPolicy::REJECT_IMMEDIATELY && m_tasks.size() >= m_poolSize) + { + Aws::Delete(fnCpy); + return false; + } + + m_tasks.push(fnCpy); + } + + m_sync.Release(); + + return true; +} + +std::function<void()>* PooledThreadExecutor::PopTask() +{ + std::lock_guard<std::mutex> locker(m_queueLock); + + if (m_tasks.size() > 0) + { + std::function<void()>* fn = m_tasks.front(); + if (fn) + { + m_tasks.pop(); + return fn; + } + } + + return nullptr; +} + +bool PooledThreadExecutor::HasTasks() +{ + std::lock_guard<std::mutex> locker(m_queueLock); + return m_tasks.size() > 0; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/ReaderWriterLock.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/ReaderWriterLock.cpp new file mode 100644 index 0000000000..ddb5860563 --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/ReaderWriterLock.cpp @@ -0,0 +1,64 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/threading/ReaderWriterLock.h> +#include <cstdint> +#include <limits> +#include <cassert> + +using namespace Aws::Utils::Threading; + +static const int64_t MaxReaders = (std::numeric_limits<std::int32_t>::max)(); + +ReaderWriterLock::ReaderWriterLock() : + m_readers(0), + m_holdouts(0), + m_readerSem(0, static_cast<size_t>(MaxReaders)), + m_writerSem(0, 1) +{ +} + +void ReaderWriterLock::LockReader() +{ + if (++m_readers < 0) + { + m_readerSem.WaitOne(); + } +} + +void ReaderWriterLock::UnlockReader() +{ + if (--m_readers < 0 && --m_holdouts == 0) + { + m_writerSem.Release(); + } +} + +void ReaderWriterLock::LockWriter() +{ + m_writerLock.lock(); + if(const auto current = m_readers.fetch_sub(MaxReaders)) + { + assert(current > 0); + const auto holdouts = m_holdouts.fetch_add(current) + current; + assert(holdouts >= 0); + if(holdouts > 0) + { + m_writerSem.WaitOne(); + } + } +} + +void ReaderWriterLock::UnlockWriter() +{ + assert(m_holdouts == 0); + const auto current = m_readers.fetch_add(MaxReaders) + MaxReaders; + assert(current >= 0); + for(int64_t r = 0; r < current; r++) + { + m_readerSem.Release(); + } + m_writerLock.unlock(); +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/Semaphore.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/Semaphore.cpp new file mode 100644 index 0000000000..86dabc9acf --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/Semaphore.cpp @@ -0,0 +1,39 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/threading/Semaphore.h> +#include <algorithm> + +using namespace Aws::Utils::Threading; + +Semaphore::Semaphore(size_t initialCount, size_t maxCount) + : m_count(initialCount), m_maxCount(maxCount) +{ +} + +void Semaphore::WaitOne() +{ + std::unique_lock<std::mutex> locker(m_mutex); + if(0 == m_count) + { + m_syncPoint.wait(locker, [this] { return m_count > 0; }); + } + --m_count; +} + +void Semaphore::Release() +{ + std::lock_guard<std::mutex> locker(m_mutex); + m_count = (std::min)(m_maxCount, m_count + 1); + m_syncPoint.notify_one(); +} + +void Semaphore::ReleaseAll() +{ + std::lock_guard<std::mutex> locker(m_mutex); + m_count = m_maxCount; + m_syncPoint.notify_all(); +} + diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/ThreadTask.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/ThreadTask.cpp new file mode 100644 index 0000000000..a899fe045d --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/threading/ThreadTask.cpp @@ -0,0 +1,46 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/threading/ThreadTask.h> +#include <aws/core/utils/threading/Executor.h> + +using namespace Aws::Utils; +using namespace Aws::Utils::Threading; + +ThreadTask::ThreadTask(PooledThreadExecutor& executor) : m_continue(true), m_executor(executor), m_thread(std::bind(&ThreadTask::MainTaskRunner, this)) +{ +} + +ThreadTask::~ThreadTask() +{ + StopProcessingWork(); + m_thread.join(); +} + +void ThreadTask::MainTaskRunner() +{ + while (m_continue) + { + while (m_continue && m_executor.HasTasks()) + { + auto fn = m_executor.PopTask(); + if(fn) + { + (*fn)(); + Aws::Delete(fn); + } + } + + if(m_continue) + { + m_executor.m_sync.WaitOne(); + } + } +} + +void ThreadTask::StopProcessingWork() +{ + m_continue = false; +} diff --git a/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/xml/XmlSerializer.cpp b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/xml/XmlSerializer.cpp new file mode 100644 index 0000000000..c06befaf9b --- /dev/null +++ b/contrib/libs/aws-sdk-cpp/aws-cpp-sdk-core/source/utils/xml/XmlSerializer.cpp @@ -0,0 +1,302 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include <aws/core/utils/xml/XmlSerializer.h> + +#include <aws/core/utils/StringUtils.h> +#include <aws/core/external/tinyxml2/tinyxml2.h> + +#include <utility> +#include <algorithm> +#include <iostream> + +using namespace Aws::Utils::Xml; +using namespace Aws::Utils; + +Aws::String Aws::Utils::Xml::DecodeEscapedXmlText(const Aws::String& textToDecode) +{ + Aws::String decodedString = textToDecode; + StringUtils::Replace(decodedString, """, "\""); + StringUtils::Replace(decodedString, "'", "'"); + StringUtils::Replace(decodedString, "<", "<"); + StringUtils::Replace(decodedString, ">", ">"); + StringUtils::Replace(decodedString, "&", "&"); + + return decodedString; +} + +XmlNode::XmlNode(const XmlNode& other) : m_node(other.m_node), m_doc(other.m_doc) +{ +} + +XmlNode& XmlNode::operator=(const XmlNode& other) +{ + if(this == &other) + { + return *this; + } + + m_node = other.m_node; + m_doc = other.m_doc; + + return *this; +} + +const Aws::String XmlNode::GetName() const +{ + return m_node->Value(); +} + +void XmlNode::SetName(const Aws::String& name) +{ + m_node->SetValue(name.c_str(), false); +} + +const Aws::String XmlNode::GetAttributeValue(const Aws::String& name) const +{ + auto pointer = m_node->ToElement()->Attribute(name.c_str(), nullptr); + return pointer ? pointer : ""; +} + +void XmlNode::SetAttributeValue(const Aws::String& name, const Aws::String& value) +{ + m_node->ToElement()->SetAttribute(name.c_str(), value.c_str()); +} + +bool XmlNode::HasNextNode() const +{ + return m_node->NextSibling() != nullptr; +} + +XmlNode XmlNode::NextNode() const +{ + return XmlNode(m_node->NextSiblingElement(), *m_doc); +} + +XmlNode XmlNode::NextNode(const char* name) const +{ + return XmlNode(m_node->NextSiblingElement(name), *m_doc); +} + +XmlNode XmlNode::NextNode(const Aws::String& name) const +{ + return NextNode(name.c_str()); +} + +XmlNode XmlNode::FirstChild() const +{ + return XmlNode(m_node->FirstChildElement(), *m_doc); +} + +XmlNode XmlNode::FirstChild(const char* name) const +{ + return XmlNode(m_node->FirstChildElement(name), *m_doc); +} + +XmlNode XmlNode::FirstChild(const Aws::String& name) const +{ + return FirstChild(name.c_str()); +} + +bool XmlNode::HasChildren() const +{ + return !m_node->NoChildren(); +} + +XmlNode XmlNode::Parent() const +{ + return XmlNode(m_node->Parent()->ToElement(), *m_doc); +} + +Aws::String XmlNode::GetText() const +{ + if (m_node != nullptr) + { + Aws::External::tinyxml2::XMLPrinter printer; + Aws::External::tinyxml2::XMLNode* node = m_node->FirstChild(); + while (node != nullptr) + { + node->Accept(&printer); + node = node->NextSibling(); + } + + return printer.CStr(); + } + + return {}; +} + +void XmlNode::SetText(const Aws::String& textValue) +{ + if (m_node != nullptr) + { + Aws::External::tinyxml2::XMLText* text = m_doc->m_doc->NewText(textValue.c_str()); + m_node->InsertEndChild(text); + } +} + +XmlNode XmlNode::CreateChildElement(const Aws::String& name) +{ + Aws::External::tinyxml2::XMLElement* element = m_doc->m_doc->NewElement(name.c_str()); + return XmlNode(m_node->InsertEndChild(element), *m_doc); +} + +XmlNode XmlNode::CreateSiblingElement(const Aws::String& name) +{ + Aws::External::tinyxml2::XMLElement* element = m_doc->m_doc->NewElement(name.c_str()); + return XmlNode(m_node->Parent()->InsertEndChild(element), *m_doc); +} + +bool XmlNode::IsNull() +{ + return m_node == nullptr; +} + +static const char* XML_SERIALIZER_ALLOCATION_TAG = "XmlDocument"; + +XmlDocument::XmlDocument() +{ + m_doc = nullptr; +} + +XmlDocument::XmlDocument(const XmlDocument& doc) +{ + if (doc.m_doc == nullptr) + { + m_doc = nullptr; + } + else + { + InitDoc(); + doc.m_doc->DeepCopy(m_doc); + } +} + +XmlDocument::XmlDocument(XmlDocument&& doc) : m_doc{ doc.m_doc } // take the innards +{ + doc.m_doc = nullptr; // leave nothing behind +} + +XmlDocument& XmlDocument::operator=(const XmlDocument& other) +{ + if (this == &other) + { + return *this; + } + + if (other.m_doc == nullptr) + { + if (m_doc != nullptr) + { + m_doc->Clear(); + m_doc = nullptr; + } + } + else + { + if (m_doc == nullptr) + { + InitDoc(); + } + else + { + m_doc->Clear(); + } + other.m_doc->DeepCopy(m_doc); + } + + return *this; +} + +XmlDocument& XmlDocument::operator=(XmlDocument&& other) +{ + if (this == &other) + { + return *this; + } + + std::swap(m_doc, other.m_doc); + return *this; +} + +XmlDocument::~XmlDocument() +{ + if (m_doc) + { + Aws::Delete(m_doc); + } +} + +void XmlDocument::InitDoc() +{ + m_doc = Aws::New<Aws::External::tinyxml2::XMLDocument>(XML_SERIALIZER_ALLOCATION_TAG, true, Aws::External::tinyxml2::Whitespace::PRESERVE_WHITESPACE); +} + +XmlNode XmlDocument::GetRootElement() const +{ + if (m_doc) + { + return XmlNode(m_doc->FirstChildElement(), *this); + } + else + { + return XmlNode(nullptr, *this); + } + +} + +bool XmlDocument::WasParseSuccessful() const +{ + if (m_doc) + { + return !m_doc->Error(); + } + else + { + return true; + } + +} + +Aws::String XmlDocument::GetErrorMessage() const +{ + return !WasParseSuccessful() ? m_doc->ErrorName() : ""; +} + +Aws::String XmlDocument::ConvertToString() const +{ + if (!m_doc) return ""; + + Aws::External::tinyxml2::XMLPrinter printer; + printer.PushHeader(false, true); + m_doc->Accept(&printer); + + return printer.CStr(); +} + +XmlDocument XmlDocument::CreateFromXmlStream(Aws::IOStream& xmlStream) +{ + Aws::String xmlString((Aws::IStreamBufIterator(xmlStream)), Aws::IStreamBufIterator()); + return CreateFromXmlString(xmlString); +} + +XmlDocument XmlDocument::CreateFromXmlString(const Aws::String& xmlText) +{ + XmlDocument xmlDocument; + xmlDocument.InitDoc(); + xmlDocument.m_doc->Parse(xmlText.c_str(), xmlText.size()); + return xmlDocument; +} + +XmlDocument XmlDocument::CreateWithRootNode(const Aws::String& rootNodeName) +{ + XmlDocument xmlDocument; + xmlDocument.InitDoc(); + Aws::External::tinyxml2::XMLElement* rootNode = xmlDocument.m_doc->NewElement(rootNodeName.c_str()); + xmlDocument.m_doc->LinkEndChild(rootNode); + + return xmlDocument; +} + |