diff options
| author | Alexander Rutkovsky <[email protected]> | 2022-02-10 16:47:40 +0300 |
|---|---|---|
| committer | Daniil Cherednik <[email protected]> | 2022-02-10 16:47:40 +0300 |
| commit | 667a4ee7da2e004784b9c3cfab824a81e96f4d66 (patch) | |
| tree | c0748b5dcbade83af788c0abfa89c0383d6b779c /library/cpp/actors/interconnect/interconnect_stream.cpp | |
| parent | f3646f91e0de459836a7800b9ce3e8dc57a2ab3a (diff) | |
Restoring authorship annotation for Alexander Rutkovsky <[email protected]>. Commit 2 of 2.
Diffstat (limited to 'library/cpp/actors/interconnect/interconnect_stream.cpp')
| -rw-r--r-- | library/cpp/actors/interconnect/interconnect_stream.cpp | 922 |
1 files changed, 461 insertions, 461 deletions
diff --git a/library/cpp/actors/interconnect/interconnect_stream.cpp b/library/cpp/actors/interconnect/interconnect_stream.cpp index d1cfff2c4a6..158ebc9e1d5 100644 --- a/library/cpp/actors/interconnect/interconnect_stream.cpp +++ b/library/cpp/actors/interconnect/interconnect_stream.cpp @@ -1,44 +1,44 @@ -#include "interconnect_stream.h" -#include "logging.h" +#include "interconnect_stream.h" +#include "logging.h" #include <library/cpp/openssl/init/init.h> -#include <util/network/socket.h> -#include <openssl/ssl.h> -#include <openssl/err.h> -#include <openssl/pem.h> - -#if defined(_win_) -#include <util/system/file.h> -#define SOCK_NONBLOCK 0 -#elif defined(_darwin_) -#define SOCK_NONBLOCK 0 -#else -#include <sys/un.h> -#include <sys/stat.h> -#endif //_win_ - -#if !defined(_win_) -#include <sys/ioctl.h> -#endif - -#include <cerrno> - -namespace NInterconnect { +#include <util/network/socket.h> +#include <openssl/ssl.h> +#include <openssl/err.h> +#include <openssl/pem.h> + +#if defined(_win_) +#include <util/system/file.h> +#define SOCK_NONBLOCK 0 +#elif defined(_darwin_) +#define SOCK_NONBLOCK 0 +#else +#include <sys/un.h> +#include <sys/stat.h> +#endif //_win_ + +#if !defined(_win_) +#include <sys/ioctl.h> +#endif + +#include <cerrno> + +namespace NInterconnect { namespace { inline int LastSocketError() { -#if defined(_win_) +#if defined(_win_) return WSAGetLastError(); -#else +#else return errno; -#endif +#endif } } - + TSocket::TSocket(SOCKET fd) : Descriptor(fd) { } - + TSocket::~TSocket() { if (Descriptor == INVALID_SOCKET) { return; @@ -57,30 +57,30 @@ namespace NInterconnect { default: Y_FAIL("It's something unexpected"); } - } - + } + int TSocket::GetDescriptor() { return Descriptor; - } - + } + int TSocket::Bind(const TAddress& addr) const { const auto ret = ::bind(Descriptor, addr.SockAddr(), addr.Size()); if (ret < 0) return -LastSocketError(); - + return 0; } - + int TSocket::Shutdown(int how) const { const auto ret = ::shutdown(Descriptor, how); if (ret < 0) return -LastSocketError(); - + return 0; } - + int TSocket::GetConnectStatus() const { int err = 0; socklen_t len = sizeof(err); @@ -88,190 +88,190 @@ namespace NInterconnect { err = LastSocketError(); } return err; - } - + } + ///////////////////////////////////////////////////////////////// - - TIntrusivePtr<TStreamSocket> TStreamSocket::Make(int domain) { - const SOCKET res = ::socket(domain, SOCK_STREAM | SOCK_NONBLOCK, 0); - if (res == -1) { - const int err = LastSocketError(); - Y_VERIFY(err != EMFILE && err != ENFILE); - } - return MakeIntrusive<TStreamSocket>(res); + + TIntrusivePtr<TStreamSocket> TStreamSocket::Make(int domain) { + const SOCKET res = ::socket(domain, SOCK_STREAM | SOCK_NONBLOCK, 0); + if (res == -1) { + const int err = LastSocketError(); + Y_VERIFY(err != EMFILE && err != ENFILE); + } + return MakeIntrusive<TStreamSocket>(res); } - + TStreamSocket::TStreamSocket(SOCKET fd) : TSocket(fd) { } - + ssize_t - TStreamSocket::Send(const void* msg, size_t len, TString* /*err*/) const { - const auto ret = ::send(Descriptor, static_cast<const char*>(msg), int(len), 0); + TStreamSocket::Send(const void* msg, size_t len, TString* /*err*/) const { + const auto ret = ::send(Descriptor, static_cast<const char*>(msg), int(len), 0); if (ret < 0) return -LastSocketError(); - + return ret; } - + ssize_t - TStreamSocket::Recv(void* buf, size_t len, TString* /*err*/) const { - const auto ret = ::recv(Descriptor, static_cast<char*>(buf), int(len), 0); + TStreamSocket::Recv(void* buf, size_t len, TString* /*err*/) const { + const auto ret = ::recv(Descriptor, static_cast<char*>(buf), int(len), 0); if (ret < 0) return -LastSocketError(); - + return ret; } - + ssize_t TStreamSocket::WriteV(const struct iovec* iov, int iovcnt) const { -#ifndef _win_ +#ifndef _win_ const auto ret = ::writev(Descriptor, iov, iovcnt); if (ret < 0) return -LastSocketError(); return ret; -#else - Y_FAIL("WriteV() unsupported on Windows"); -#endif +#else + Y_FAIL("WriteV() unsupported on Windows"); +#endif } - + ssize_t TStreamSocket::ReadV(const struct iovec* iov, int iovcnt) const { -#ifndef _win_ +#ifndef _win_ const auto ret = ::readv(Descriptor, iov, iovcnt); if (ret < 0) return -LastSocketError(); return ret; -#else - Y_FAIL("ReadV() unsupported on Windows"); -#endif +#else + Y_FAIL("ReadV() unsupported on Windows"); +#endif } - + ssize_t TStreamSocket::GetUnsentQueueSize() const { int num = -1; -#ifndef _win_ // we have no means to determine output queue size on Windows +#ifndef _win_ // we have no means to determine output queue size on Windows if (ioctl(Descriptor, TIOCOUTQ, &num) == -1) { num = -1; } #endif return num; - } - + } + int TStreamSocket::Connect(const TAddress& addr) const { const auto ret = ::connect(Descriptor, addr.SockAddr(), addr.Size()); if (ret < 0) return -LastSocketError(); - + return ret; } - + int TStreamSocket::Connect(const NAddr::IRemoteAddr* addr) const { const auto ret = ::connect(Descriptor, addr->Addr(), addr->Len()); if (ret < 0) return -LastSocketError(); - + return ret; } - + int TStreamSocket::Listen(int backlog) const { const auto ret = ::listen(Descriptor, backlog); if (ret < 0) return -LastSocketError(); - + return ret; } - + int TStreamSocket::Accept(TAddress& acceptedAddr) const { socklen_t acceptedSize = sizeof(::sockaddr_in6); const auto ret = ::accept(Descriptor, acceptedAddr.SockAddr(), &acceptedSize); if (ret == INVALID_SOCKET) return -LastSocketError(); - + return ret; } - + void TStreamSocket::SetSendBufferSize(i32 len) const { (void)SetSockOpt(Descriptor, SOL_SOCKET, SO_SNDBUF, len); } - - ui32 TStreamSocket::GetSendBufferSize() const { - ui32 res = 0; - CheckedGetSockOpt(Descriptor, SOL_SOCKET, SO_SNDBUF, res, "SO_SNDBUF"); - return res; - } - + + ui32 TStreamSocket::GetSendBufferSize() const { + ui32 res = 0; + CheckedGetSockOpt(Descriptor, SOL_SOCKET, SO_SNDBUF, res, "SO_SNDBUF"); + return res; + } + ////////////////////////////////////////////////////// - - TDatagramSocket::TPtr TDatagramSocket::Make(int domain) { - const SOCKET res = ::socket(domain, SOCK_DGRAM, 0); - if (res == -1) { - const int err = LastSocketError(); - Y_VERIFY(err != EMFILE && err != ENFILE); - } - return std::make_shared<TDatagramSocket>(res); + + TDatagramSocket::TPtr TDatagramSocket::Make(int domain) { + const SOCKET res = ::socket(domain, SOCK_DGRAM, 0); + if (res == -1) { + const int err = LastSocketError(); + Y_VERIFY(err != EMFILE && err != ENFILE); + } + return std::make_shared<TDatagramSocket>(res); } - + TDatagramSocket::TDatagramSocket(SOCKET fd) : TSocket(fd) { } - + ssize_t TDatagramSocket::SendTo(const void* msg, size_t len, const TAddress& toAddr) const { const auto ret = ::sendto(Descriptor, static_cast<const char*>(msg), int(len), 0, toAddr.SockAddr(), toAddr.Size()); if (ret < 0) return -LastSocketError(); - + return ret; } - + ssize_t TDatagramSocket::RecvFrom(void* buf, size_t len, TAddress& fromAddr) const { socklen_t fromSize = sizeof(::sockaddr_in6); const auto ret = ::recvfrom(Descriptor, static_cast<char*>(buf), int(len), 0, fromAddr.SockAddr(), &fromSize); if (ret < 0) return -LastSocketError(); - + return ret; } - - - // deleter for SSL objects - struct TDeleter { - void operator ()(BIO *bio) const { - BIO_free(bio); - } - - void operator ()(X509 *x509) const { - X509_free(x509); - } - - void operator ()(RSA *rsa) const { - RSA_free(rsa); - } - - void operator ()(SSL_CTX *ctx) const { - SSL_CTX_free(ctx); - } - }; - - class TSecureSocketContext::TImpl { - std::unique_ptr<SSL_CTX, TDeleter> Ctx; - - public: - TImpl(const TString& certificate, const TString& privateKey, const TString& caFilePath, - const TString& ciphers) { + + + // deleter for SSL objects + struct TDeleter { + void operator ()(BIO *bio) const { + BIO_free(bio); + } + + void operator ()(X509 *x509) const { + X509_free(x509); + } + + void operator ()(RSA *rsa) const { + RSA_free(rsa); + } + + void operator ()(SSL_CTX *ctx) const { + SSL_CTX_free(ctx); + } + }; + + class TSecureSocketContext::TImpl { + std::unique_ptr<SSL_CTX, TDeleter> Ctx; + + public: + TImpl(const TString& certificate, const TString& privateKey, const TString& caFilePath, + const TString& ciphers) { int ret; - InitOpenSSL(); + InitOpenSSL(); #if OPENSSL_VERSION_NUMBER < 0x10100000L - Ctx.reset(SSL_CTX_new(TLSv1_2_method())); - Y_VERIFY(Ctx, "SSL_CTX_new() failed"); + Ctx.reset(SSL_CTX_new(TLSv1_2_method())); + Y_VERIFY(Ctx, "SSL_CTX_new() failed"); #else Ctx.reset(SSL_CTX_new(TLS_method())); Y_VERIFY(Ctx, "SSL_CTX_new() failed"); @@ -280,19 +280,19 @@ namespace NInterconnect { ret = SSL_CTX_set_max_proto_version(Ctx.get(), TLS1_2_VERSION); Y_VERIFY(ret == 1, "failed to set max proto version"); #endif - SSL_CTX_set_verify(Ctx.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, &Verify); - SSL_CTX_set_mode(*this, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - - // apply certificates in SSL context - if (certificate) { - std::unique_ptr<BIO, TDeleter> bio(BIO_new_mem_buf(certificate.data(), certificate.size())); - Y_VERIFY(bio); + SSL_CTX_set_verify(Ctx.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, &Verify); + SSL_CTX_set_mode(*this, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + + // apply certificates in SSL context + if (certificate) { + std::unique_ptr<BIO, TDeleter> bio(BIO_new_mem_buf(certificate.data(), certificate.size())); + Y_VERIFY(bio); // first certificate in the chain is expected to be a leaf - std::unique_ptr<X509, TDeleter> cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); - Y_VERIFY(cert, "failed to parse certificate"); - ret = SSL_CTX_use_certificate(Ctx.get(), cert.get()); - Y_VERIFY(ret == 1); + std::unique_ptr<X509, TDeleter> cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); + Y_VERIFY(cert, "failed to parse certificate"); + ret = SSL_CTX_use_certificate(Ctx.get(), cert.get()); + Y_VERIFY(ret == 1); // loading additional certificates in the chain, if any while(true) { @@ -304,325 +304,325 @@ namespace NInterconnect { Y_VERIFY(ret == 1); // we must not free memory if certificate was added successfully by SSL_CTX_add0_chain_cert } - } - if (privateKey) { - std::unique_ptr<BIO, TDeleter> bio(BIO_new_mem_buf(privateKey.data(), privateKey.size())); - Y_VERIFY(bio); - std::unique_ptr<RSA, TDeleter> pkey(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr)); - Y_VERIFY(pkey); - ret = SSL_CTX_use_RSAPrivateKey(Ctx.get(), pkey.get()); - Y_VERIFY(ret == 1); - } - if (caFilePath) { - ret = SSL_CTX_load_verify_locations(Ctx.get(), caFilePath.data(), nullptr); - Y_VERIFY(ret == 1); - } - - int success = SSL_CTX_set_cipher_list(Ctx.get(), ciphers ? ciphers.data() : "AES128-GCM-SHA256"); - Y_VERIFY(success, "failed to set cipher list"); - } - - operator SSL_CTX*() const { - return Ctx.get(); - } - - static int GetExIndex() { - static int index = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); - return index; - } - - private: - static int Verify(int preverify, X509_STORE_CTX *ctx) { - if (!preverify) { - X509 *badCert = X509_STORE_CTX_get_current_cert(ctx); - int err = X509_STORE_CTX_get_error(ctx); - int depth = X509_STORE_CTX_get_error_depth(ctx); - SSL *ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); - TString *errp = static_cast<TString*>(SSL_get_ex_data(ssl, GetExIndex())); - char buffer[1024]; - X509_NAME_oneline(X509_get_subject_name(badCert), buffer, sizeof(buffer)); - TStringBuilder s; - s << "Error during certificate validation" - << " error# " << X509_verify_cert_error_string(err) - << " depth# " << depth - << " cert# " << buffer; - if (err == X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT) { - X509_NAME_oneline(X509_get_issuer_name(badCert), buffer, sizeof(buffer)); - s << " issuer# " << buffer; - } - *errp = s; - } - return preverify; - } - }; - - TSecureSocketContext::TSecureSocketContext(const TString& certificate, const TString& privateKey, - const TString& caFilePath, const TString& ciphers) - : Impl(new TImpl(certificate, privateKey, caFilePath, ciphers)) - {} - - TSecureSocketContext::~TSecureSocketContext() - {} - - class TSecureSocket::TImpl { - SSL *Ssl; - TString ErrorDescription; - bool WantRead_ = false; - bool WantWrite_ = false; - - public: - TImpl(SSL_CTX *ctx, int fd) - : Ssl(SSL_new(ctx)) - { - Y_VERIFY(Ssl, "SSL_new() failed"); - SSL_set_fd(Ssl, fd); - SSL_set_ex_data(Ssl, TSecureSocketContext::TImpl::GetExIndex(), &ErrorDescription); - } - - ~TImpl() { - SSL_free(Ssl); - } - - TString GetErrorStack() { - if (ErrorDescription) { - return ErrorDescription; - } - std::unique_ptr<BIO, int(*)(BIO*)> mem(BIO_new(BIO_s_mem()), BIO_free); - ERR_print_errors(mem.get()); - char *p = nullptr; - auto len = BIO_get_mem_data(mem.get(), &p); - return TString(p, len); - } - - EStatus ConvertResult(int res, TString& err) { - switch (res) { - case SSL_ERROR_NONE: - return EStatus::SUCCESS; - - case SSL_ERROR_WANT_READ: - return EStatus::WANT_READ; - - case SSL_ERROR_WANT_WRITE: - return EStatus::WANT_WRITE; - - case SSL_ERROR_SYSCALL: - err = TStringBuilder() << "syscall error: " << strerror(LastSocketError()) << ": " << GetErrorStack(); - break; - - case SSL_ERROR_ZERO_RETURN: - err = "TLS negotiation failed"; - break; - - case SSL_ERROR_SSL: - err = "SSL error: " + GetErrorStack(); - break; - - default: - err = "unknown OpenSSL error"; - break; - } - return EStatus::ERROR; - } - - enum EConnectState { - CONNECT, - SHUTDOWN, - READ, - } ConnectState = EConnectState::CONNECT; - - EStatus Establish(bool server, bool authOnly, TString& err) { - switch (ConnectState) { - case EConnectState::CONNECT: { - auto callback = server ? SSL_accept : SSL_connect; - const EStatus status = ConvertResult(SSL_get_error(Ssl, callback(Ssl)), err); - if (status != EStatus::SUCCESS || !authOnly) { - return status; - } - ConnectState = EConnectState::SHUTDOWN; - [[fallthrough]]; - } - - case EConnectState::SHUTDOWN: { - const int res = SSL_shutdown(Ssl); - if (res == 1) { - return EStatus::SUCCESS; - } else if (res != 0) { - return ConvertResult(SSL_get_error(Ssl, res), err); - } - ConnectState = EConnectState::READ; - [[fallthrough]]; - } - - case EConnectState::READ: { - char data[256]; - size_t numRead = 0; - const int res = SSL_get_error(Ssl, SSL_read_ex(Ssl, data, sizeof(data), &numRead)); - if (res == SSL_ERROR_ZERO_RETURN) { - return EStatus::SUCCESS; - } else if (res != SSL_ERROR_NONE) { - return ConvertResult(res, err); - } else if (numRead) { - err = "non-zero return from SSL_read_ex: " + ToString(numRead); - return EStatus::ERROR; - } else { - return EStatus::SUCCESS; - } - } - } - Y_FAIL(); - } - - std::optional<std::pair<const void*, size_t>> BlockedSend; - - ssize_t Send(const void* msg, size_t len, TString *err) { - Y_VERIFY(!BlockedSend || *BlockedSend == std::make_pair(msg, len)); - const ssize_t res = Operate(msg, len, &SSL_write_ex, err); - if (res == -EAGAIN) { - BlockedSend.emplace(msg, len); - } else { - BlockedSend.reset(); - } - return res; - } - - std::optional<std::pair<void*, size_t>> BlockedReceive; - - ssize_t Recv(void* msg, size_t len, TString *err) { - Y_VERIFY(!BlockedReceive || *BlockedReceive == std::make_pair(msg, len)); - const ssize_t res = Operate(msg, len, &SSL_read_ex, err); - if (res == -EAGAIN) { - BlockedReceive.emplace(msg, len); - } else { - BlockedReceive.reset(); - } - return res; - } - - TString GetCipherName() const { - return SSL_get_cipher_name(Ssl); - } - - int GetCipherBits() const { - return SSL_get_cipher_bits(Ssl, nullptr); - } - - TString GetProtocolName() const { - return SSL_get_cipher_version(Ssl); - } - - TString GetPeerCommonName() const { - TString res; - if (X509 *cert = SSL_get_peer_certificate(Ssl)) { - char buffer[256]; - memset(buffer, 0, sizeof(buffer)); - if (X509_NAME *name = X509_get_subject_name(cert)) { - X509_NAME_get_text_by_NID(name, NID_commonName, buffer, sizeof(buffer)); - } - X509_free(cert); - res = TString(buffer, strnlen(buffer, sizeof(buffer))); - } - return res; - } - - bool WantRead() const { - return WantRead_; - } - - bool WantWrite() const { - return WantWrite_; - } - - private: - template<typename TBuffer, typename TOp> - ssize_t Operate(TBuffer* buffer, size_t len, TOp&& op, TString *err) { - WantRead_ = WantWrite_ = false; - size_t processed = 0; - int ret = op(Ssl, buffer, len, &processed); - if (ret == 1) { - return processed; - } - switch (const int status = SSL_get_error(Ssl, ret)) { - case SSL_ERROR_ZERO_RETURN: - return 0; - - case SSL_ERROR_WANT_READ: - WantRead_ = true; - return -EAGAIN; - - case SSL_ERROR_WANT_WRITE: - WantWrite_ = true; - return -EAGAIN; - - case SSL_ERROR_SYSCALL: - return -LastSocketError(); - - case SSL_ERROR_SSL: - if (err) { - *err = GetErrorStack(); - } - return -EPROTO; - - default: - Y_FAIL("unexpected SSL_get_error() status# %d", status); - } - } - }; - - TSecureSocket::TSecureSocket(TStreamSocket& socket, TSecureSocketContext::TPtr context) - : TStreamSocket(socket.ReleaseDescriptor()) - , Context(std::move(context)) - , Impl(new TImpl(*Context->Impl, Descriptor)) - {} - - TSecureSocket::~TSecureSocket() - {} - - TSecureSocket::EStatus TSecureSocket::Establish(bool server, bool authOnly, TString& err) const { - return Impl->Establish(server, authOnly, err); - } - - TIntrusivePtr<TStreamSocket> TSecureSocket::Detach() { - return MakeIntrusive<TStreamSocket>(ReleaseDescriptor()); - } - - ssize_t TSecureSocket::Send(const void* msg, size_t len, TString *err) const { - return Impl->Send(msg, len, err); - } - - ssize_t TSecureSocket::Recv(void* msg, size_t len, TString *err) const { - return Impl->Recv(msg, len, err); - } - - ssize_t TSecureSocket::WriteV(const struct iovec* /*iov*/, int /*iovcnt*/) const { - Y_FAIL("unsupported on SSL sockets"); - } - - ssize_t TSecureSocket::ReadV(const struct iovec* /*iov*/, int /*iovcnt*/) const { - Y_FAIL("unsupported on SSL sockets"); - } - - TString TSecureSocket::GetCipherName() const { - return Impl->GetCipherName(); - } - - int TSecureSocket::GetCipherBits() const { - return Impl->GetCipherBits(); - } - - TString TSecureSocket::GetProtocolName() const { - return Impl->GetProtocolName(); - } - - TString TSecureSocket::GetPeerCommonName() const { - return Impl->GetPeerCommonName(); - } - - bool TSecureSocket::WantRead() const { - return Impl->WantRead(); - } - - bool TSecureSocket::WantWrite() const { - return Impl->WantWrite(); - } - -} + } + if (privateKey) { + std::unique_ptr<BIO, TDeleter> bio(BIO_new_mem_buf(privateKey.data(), privateKey.size())); + Y_VERIFY(bio); + std::unique_ptr<RSA, TDeleter> pkey(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr)); + Y_VERIFY(pkey); + ret = SSL_CTX_use_RSAPrivateKey(Ctx.get(), pkey.get()); + Y_VERIFY(ret == 1); + } + if (caFilePath) { + ret = SSL_CTX_load_verify_locations(Ctx.get(), caFilePath.data(), nullptr); + Y_VERIFY(ret == 1); + } + + int success = SSL_CTX_set_cipher_list(Ctx.get(), ciphers ? ciphers.data() : "AES128-GCM-SHA256"); + Y_VERIFY(success, "failed to set cipher list"); + } + + operator SSL_CTX*() const { + return Ctx.get(); + } + + static int GetExIndex() { + static int index = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + return index; + } + + private: + static int Verify(int preverify, X509_STORE_CTX *ctx) { + if (!preverify) { + X509 *badCert = X509_STORE_CTX_get_current_cert(ctx); + int err = X509_STORE_CTX_get_error(ctx); + int depth = X509_STORE_CTX_get_error_depth(ctx); + SSL *ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); + TString *errp = static_cast<TString*>(SSL_get_ex_data(ssl, GetExIndex())); + char buffer[1024]; + X509_NAME_oneline(X509_get_subject_name(badCert), buffer, sizeof(buffer)); + TStringBuilder s; + s << "Error during certificate validation" + << " error# " << X509_verify_cert_error_string(err) + << " depth# " << depth + << " cert# " << buffer; + if (err == X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT) { + X509_NAME_oneline(X509_get_issuer_name(badCert), buffer, sizeof(buffer)); + s << " issuer# " << buffer; + } + *errp = s; + } + return preverify; + } + }; + + TSecureSocketContext::TSecureSocketContext(const TString& certificate, const TString& privateKey, + const TString& caFilePath, const TString& ciphers) + : Impl(new TImpl(certificate, privateKey, caFilePath, ciphers)) + {} + + TSecureSocketContext::~TSecureSocketContext() + {} + + class TSecureSocket::TImpl { + SSL *Ssl; + TString ErrorDescription; + bool WantRead_ = false; + bool WantWrite_ = false; + + public: + TImpl(SSL_CTX *ctx, int fd) + : Ssl(SSL_new(ctx)) + { + Y_VERIFY(Ssl, "SSL_new() failed"); + SSL_set_fd(Ssl, fd); + SSL_set_ex_data(Ssl, TSecureSocketContext::TImpl::GetExIndex(), &ErrorDescription); + } + + ~TImpl() { + SSL_free(Ssl); + } + + TString GetErrorStack() { + if (ErrorDescription) { + return ErrorDescription; + } + std::unique_ptr<BIO, int(*)(BIO*)> mem(BIO_new(BIO_s_mem()), BIO_free); + ERR_print_errors(mem.get()); + char *p = nullptr; + auto len = BIO_get_mem_data(mem.get(), &p); + return TString(p, len); + } + + EStatus ConvertResult(int res, TString& err) { + switch (res) { + case SSL_ERROR_NONE: + return EStatus::SUCCESS; + + case SSL_ERROR_WANT_READ: + return EStatus::WANT_READ; + + case SSL_ERROR_WANT_WRITE: + return EStatus::WANT_WRITE; + + case SSL_ERROR_SYSCALL: + err = TStringBuilder() << "syscall error: " << strerror(LastSocketError()) << ": " << GetErrorStack(); + break; + + case SSL_ERROR_ZERO_RETURN: + err = "TLS negotiation failed"; + break; + + case SSL_ERROR_SSL: + err = "SSL error: " + GetErrorStack(); + break; + + default: + err = "unknown OpenSSL error"; + break; + } + return EStatus::ERROR; + } + + enum EConnectState { + CONNECT, + SHUTDOWN, + READ, + } ConnectState = EConnectState::CONNECT; + + EStatus Establish(bool server, bool authOnly, TString& err) { + switch (ConnectState) { + case EConnectState::CONNECT: { + auto callback = server ? SSL_accept : SSL_connect; + const EStatus status = ConvertResult(SSL_get_error(Ssl, callback(Ssl)), err); + if (status != EStatus::SUCCESS || !authOnly) { + return status; + } + ConnectState = EConnectState::SHUTDOWN; + [[fallthrough]]; + } + + case EConnectState::SHUTDOWN: { + const int res = SSL_shutdown(Ssl); + if (res == 1) { + return EStatus::SUCCESS; + } else if (res != 0) { + return ConvertResult(SSL_get_error(Ssl, res), err); + } + ConnectState = EConnectState::READ; + [[fallthrough]]; + } + + case EConnectState::READ: { + char data[256]; + size_t numRead = 0; + const int res = SSL_get_error(Ssl, SSL_read_ex(Ssl, data, sizeof(data), &numRead)); + if (res == SSL_ERROR_ZERO_RETURN) { + return EStatus::SUCCESS; + } else if (res != SSL_ERROR_NONE) { + return ConvertResult(res, err); + } else if (numRead) { + err = "non-zero return from SSL_read_ex: " + ToString(numRead); + return EStatus::ERROR; + } else { + return EStatus::SUCCESS; + } + } + } + Y_FAIL(); + } + + std::optional<std::pair<const void*, size_t>> BlockedSend; + + ssize_t Send(const void* msg, size_t len, TString *err) { + Y_VERIFY(!BlockedSend || *BlockedSend == std::make_pair(msg, len)); + const ssize_t res = Operate(msg, len, &SSL_write_ex, err); + if (res == -EAGAIN) { + BlockedSend.emplace(msg, len); + } else { + BlockedSend.reset(); + } + return res; + } + + std::optional<std::pair<void*, size_t>> BlockedReceive; + + ssize_t Recv(void* msg, size_t len, TString *err) { + Y_VERIFY(!BlockedReceive || *BlockedReceive == std::make_pair(msg, len)); + const ssize_t res = Operate(msg, len, &SSL_read_ex, err); + if (res == -EAGAIN) { + BlockedReceive.emplace(msg, len); + } else { + BlockedReceive.reset(); + } + return res; + } + + TString GetCipherName() const { + return SSL_get_cipher_name(Ssl); + } + + int GetCipherBits() const { + return SSL_get_cipher_bits(Ssl, nullptr); + } + + TString GetProtocolName() const { + return SSL_get_cipher_version(Ssl); + } + + TString GetPeerCommonName() const { + TString res; + if (X509 *cert = SSL_get_peer_certificate(Ssl)) { + char buffer[256]; + memset(buffer, 0, sizeof(buffer)); + if (X509_NAME *name = X509_get_subject_name(cert)) { + X509_NAME_get_text_by_NID(name, NID_commonName, buffer, sizeof(buffer)); + } + X509_free(cert); + res = TString(buffer, strnlen(buffer, sizeof(buffer))); + } + return res; + } + + bool WantRead() const { + return WantRead_; + } + + bool WantWrite() const { + return WantWrite_; + } + + private: + template<typename TBuffer, typename TOp> + ssize_t Operate(TBuffer* buffer, size_t len, TOp&& op, TString *err) { + WantRead_ = WantWrite_ = false; + size_t processed = 0; + int ret = op(Ssl, buffer, len, &processed); + if (ret == 1) { + return processed; + } + switch (const int status = SSL_get_error(Ssl, ret)) { + case SSL_ERROR_ZERO_RETURN: + return 0; + + case SSL_ERROR_WANT_READ: + WantRead_ = true; + return -EAGAIN; + + case SSL_ERROR_WANT_WRITE: + WantWrite_ = true; + return -EAGAIN; + + case SSL_ERROR_SYSCALL: + return -LastSocketError(); + + case SSL_ERROR_SSL: + if (err) { + *err = GetErrorStack(); + } + return -EPROTO; + + default: + Y_FAIL("unexpected SSL_get_error() status# %d", status); + } + } + }; + + TSecureSocket::TSecureSocket(TStreamSocket& socket, TSecureSocketContext::TPtr context) + : TStreamSocket(socket.ReleaseDescriptor()) + , Context(std::move(context)) + , Impl(new TImpl(*Context->Impl, Descriptor)) + {} + + TSecureSocket::~TSecureSocket() + {} + + TSecureSocket::EStatus TSecureSocket::Establish(bool server, bool authOnly, TString& err) const { + return Impl->Establish(server, authOnly, err); + } + + TIntrusivePtr<TStreamSocket> TSecureSocket::Detach() { + return MakeIntrusive<TStreamSocket>(ReleaseDescriptor()); + } + + ssize_t TSecureSocket::Send(const void* msg, size_t len, TString *err) const { + return Impl->Send(msg, len, err); + } + + ssize_t TSecureSocket::Recv(void* msg, size_t len, TString *err) const { + return Impl->Recv(msg, len, err); + } + + ssize_t TSecureSocket::WriteV(const struct iovec* /*iov*/, int /*iovcnt*/) const { + Y_FAIL("unsupported on SSL sockets"); + } + + ssize_t TSecureSocket::ReadV(const struct iovec* /*iov*/, int /*iovcnt*/) const { + Y_FAIL("unsupported on SSL sockets"); + } + + TString TSecureSocket::GetCipherName() const { + return Impl->GetCipherName(); + } + + int TSecureSocket::GetCipherBits() const { + return Impl->GetCipherBits(); + } + + TString TSecureSocket::GetProtocolName() const { + return Impl->GetProtocolName(); + } + + TString TSecureSocket::GetPeerCommonName() const { + return Impl->GetPeerCommonName(); + } + + bool TSecureSocket::WantRead() const { + return Impl->WantRead(); + } + + bool TSecureSocket::WantWrite() const { + return Impl->WantWrite(); + } + +} |
