diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/openssl/io/stream.cpp | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/openssl/io/stream.cpp')
-rw-r--r-- | library/cpp/openssl/io/stream.cpp | 329 |
1 files changed, 329 insertions, 0 deletions
diff --git a/library/cpp/openssl/io/stream.cpp b/library/cpp/openssl/io/stream.cpp new file mode 100644 index 0000000000..0b4be38c0e --- /dev/null +++ b/library/cpp/openssl/io/stream.cpp @@ -0,0 +1,329 @@ +#include "stream.h" + +#include <util/generic/deque.h> +#include <util/generic/singleton.h> +#include <util/generic/yexception.h> + +#include <library/cpp/openssl/init/init.h> +#include <library/cpp/openssl/method/io.h> +#include <library/cpp/resource/resource.h> + +#include <openssl/bio.h> +#include <openssl/ssl.h> +#include <openssl/err.h> +#include <openssl/tls1.h> +#include <openssl/x509v3.h> + +using TOptions = TOpenSslClientIO::TOptions; + +namespace { + struct TSslIO; + + struct TSslInitOnDemand { + inline TSslInitOnDemand() { + InitOpenSSL(); + } + }; + + int GetLastSslError() noexcept { + return ERR_peek_last_error(); + } + + const char* SslErrorText(int error) noexcept { + return ERR_error_string(error, nullptr); + } + + inline TStringBuf SslLastError() noexcept { + return SslErrorText(GetLastSslError()); + } + + struct TSslError: public yexception { + inline TSslError() { + *this << SslLastError(); + } + }; + + struct TSslDestroy { + static inline void Destroy(ssl_ctx_st* ctx) noexcept { + SSL_CTX_free(ctx); + } + + static inline void Destroy(ssl_st* ssl) noexcept { + SSL_free(ssl); + } + + static inline void Destroy(bio_st* bio) noexcept { + BIO_free(bio); + } + + static inline void Destroy(x509_st* x509) noexcept { + X509_free(x509); + } + }; + + template <class T> + using TSslHolderPtr = THolder<T, TSslDestroy>; + + using TSslContextPtr = TSslHolderPtr<ssl_ctx_st>; + using TSslPtr = TSslHolderPtr<ssl_st>; + using TBioPtr = TSslHolderPtr<bio_st>; + using TX509Ptr = TSslHolderPtr<x509_st>; + + inline TSslContextPtr CreateSslCtx(const ssl_method_st* method) { + TSslContextPtr ctx(SSL_CTX_new(method)); + + if (!ctx) { + ythrow TSslError() << "SSL_CTX_new"; + } + + SSL_CTX_set_options(ctx.Get(), SSL_OP_NO_SSLv2); + SSL_CTX_set_options(ctx.Get(), SSL_OP_NO_SSLv3); + SSL_CTX_set_options(ctx.Get(), SSL_OP_MICROSOFT_SESS_ID_BUG); + SSL_CTX_set_options(ctx.Get(), SSL_OP_NETSCAPE_CHALLENGE_BUG); + + return ctx; + } + + struct TStreamIO : public NOpenSSL::TAbstractIO { + inline TStreamIO(IInputStream* in, IOutputStream* out) + : In(in) + , Out(out) + { + } + + int Write(const char* data, size_t dlen, size_t* written) override { + Out->Write(data, dlen); + + *written = dlen; + return 1; + } + + int Read(char* data, size_t dlen, size_t* readbytes) override { + *readbytes = In->Read(data, dlen); + return 1; + } + + int Puts(const char* buf) override { + Y_UNUSED(buf); + return -1; + } + + int Gets(char* buf, int size) override { + Y_UNUSED(buf); + Y_UNUSED(size); + return -1; + } + + void Flush() override { + } + + IInputStream* In; + IOutputStream* Out; + }; + + struct TSslIO: public TSslInitOnDemand, public TOptions { + inline TSslIO(IInputStream* in, IOutputStream* out, const TOptions& opts) + : TOptions(opts) + , Io(in, out) + , Ctx(CreateClientContext()) + , Ssl(ConstructSsl()) + { + Connect(); + } + + inline TSslContextPtr CreateClientContext() { + TSslContextPtr ctx = CreateSslCtx(SSLv23_client_method()); + if (ClientCert_) { + if (!ClientCert_->CertificateFile_ || !ClientCert_->PrivateKeyFile_) { + ythrow yexception() << "both client certificate and private key are required"; + } + if (ClientCert_->PrivateKeyPassword_) { + SSL_CTX_set_default_passwd_cb(ctx.Get(), [](char* buf, int size, int rwflag, void* userData) -> int { + Y_UNUSED(rwflag); + auto io = static_cast<TSslIO*>(userData); + if (!io) { + return -1; + } + if (size < static_cast<int>(io->ClientCert_->PrivateKeyPassword_.size())) { + return -1; + } + return io->ClientCert_->PrivateKeyPassword_.copy(buf, size, 0); + }); + SSL_CTX_set_default_passwd_cb_userdata(ctx.Get(), this); + } + if (1 != SSL_CTX_use_certificate_chain_file(ctx.Get(), ClientCert_->CertificateFile_.c_str())) { + ythrow TSslError() << "SSL_CTX_use_certificate_chain_file"; + } + if (1 != SSL_CTX_use_PrivateKey_file(ctx.Get(), ClientCert_->PrivateKeyFile_.c_str(), SSL_FILETYPE_PEM)) { + ythrow TSslError() << "SSL_CTX_use_PrivateKey_file"; + } + if (1 != SSL_CTX_check_private_key(ctx.Get())) { + ythrow TSslError() << "SSL_CTX_check_private_key (client)"; + } + } + return ctx; + } + + inline TSslPtr ConstructSsl() { + TSslPtr ssl(SSL_new(Ctx.Get())); + + if (!ssl) { + ythrow TSslError() << "SSL_new"; + } + + if (VerifyCert_) { + InitVerification(ssl.Get()); + } + + BIO_up_ref(Io); // SSL_set_bio consumes only one reference if rbio and wbio are the same + SSL_set_bio(ssl.Get(), Io, Io); + + return ssl; + } + + inline void InitVerification(ssl_st* ssl) { + X509_VERIFY_PARAM* param = SSL_get0_param(ssl); + X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); + Y_ENSURE(X509_VERIFY_PARAM_set1_host(param, VerifyCert_->Hostname_.data(), VerifyCert_->Hostname_.size())); + SSL_set_tlsext_host_name(ssl, VerifyCert_->Hostname_.data()); // TLS extenstion: SNI + + SSL_CTX_set_cert_store(Ctx.Get(), GetBuiltinOpenSslX509Store().Release()); + + Y_ENSURE_EX(1 == SSL_CTX_set_default_verify_paths(Ctx.Get()), + TSslError()); + // it is OK to ignore result of SSL_CTX_load_verify_locations(): + // Dir "/etc/ssl/certs/" may be missing + SSL_CTX_load_verify_locations(Ctx.Get(), + "/etc/ssl/certs/ca-certificates.crt", + "/etc/ssl/certs/"); + + SSL_set_verify(ssl, SSL_VERIFY_PEER, nullptr); + } + + inline void Connect() { + if (SSL_connect(Ssl.Get()) != 1) { + ythrow TSslError() << "SSL_connect"; + } + } + + inline void Finish() const { + SSL_shutdown(Ssl.Get()); + } + + inline size_t Read(void* buf, size_t len) { + const int ret = SSL_read(Ssl.Get(), buf, len); + + if (ret < 0) { + ythrow TSslError() << "SSL_read"; + } + + return ret; + } + + inline void Write(const char* buf, size_t len) { + while (len) { + const int ret = SSL_write(Ssl.Get(), buf, len); + + if (ret < 0) { + ythrow TSslError() << "SSL_write"; + } + + buf += (size_t)ret; + len -= (size_t)ret; + } + } + + TStreamIO Io; + TSslContextPtr Ctx; + TSslPtr Ssl; + }; +} + +struct TOpenSslClientIO::TImpl: public TSslIO { + inline TImpl(IInputStream* in, IOutputStream* out, const TOptions& opts) + : TSslIO(in, out, opts) + { + } +}; + +TOpenSslClientIO::TOpenSslClientIO(IInputStream* in, IOutputStream* out) + : Impl_(new TImpl(in, out, TOptions())) +{ +} + +TOpenSslClientIO::TOpenSslClientIO(IInputStream* in, IOutputStream* out, const TOptions& options) + : Impl_(new TImpl(in, out, options)) +{ +} + +TOpenSslClientIO::~TOpenSslClientIO() { + try { + Impl_->Finish(); + } catch (...) { + } +} + +void TOpenSslClientIO::DoWrite(const void* buf, size_t len) { + Impl_->Write((const char*)buf, len); +} + +size_t TOpenSslClientIO::DoRead(void* buf, size_t len) { + return Impl_->Read(buf, len); +} + +namespace NPrivate { + void TSslDestroy::Destroy(x509_store_st* x509) noexcept { + X509_STORE_free(x509); + } +} + +class TBuiltinCerts { +public: + TBuiltinCerts() { + TString c = NResource::Find("/builtin/cacert"); + + TBioPtr cbio(BIO_new_mem_buf(c.data(), c.size())); + Y_ENSURE_EX(cbio, TSslError() << "BIO_new_mem_buf"); + + while (true) { + TX509Ptr cert(PEM_read_bio_X509(cbio.Get(), nullptr, nullptr, nullptr)); + if (!cert) { + break; + } + Certs.push_back(std::move(cert)); + } + + int err = GetLastSslError(); + if (!Certs.empty() && ERR_GET_LIB(err) == ERR_LIB_PEM && ERR_GET_REASON(err) == PEM_R_NO_START_LINE) { + ERR_clear_error(); + } else { + ythrow TSslError() << "can't load provided bundle: " << ERR_reason_error_string(err); + } + + Y_ENSURE_EX(!Certs.empty(), TSslError()); + } + + TOpenSslX509StorePtr GetX509Store() const { + TOpenSslX509StorePtr store(X509_STORE_new()); + + for (const TX509Ptr& c : Certs) { + if (0 == X509_STORE_add_cert(store.Get(), c.Get())) { + int err = GetLastSslError(); + if (ERR_GET_LIB(err) == ERR_LIB_X509 && ERR_GET_REASON(err) == X509_R_CERT_ALREADY_IN_HASH_TABLE) { + ERR_clear_error(); + } else { + ythrow TSslError() << "can't load provided bundle: " << ERR_reason_error_string(err); + } + } + } + + return store; + } + +private: + TDeque<TX509Ptr> Certs; +}; + +TOpenSslX509StorePtr GetBuiltinOpenSslX509Store() { + return Singleton<TBuiltinCerts>()->GetX509Store(); +} |