summaryrefslogtreecommitdiffstats
path: root/library/cpp/actors/interconnect/interconnect_stream.cpp
diff options
context:
space:
mode:
authorAlexander Rutkovsky <[email protected]>2022-02-10 16:47:40 +0300
committerDaniil Cherednik <[email protected]>2022-02-10 16:47:40 +0300
commit667a4ee7da2e004784b9c3cfab824a81e96f4d66 (patch)
treec0748b5dcbade83af788c0abfa89c0383d6b779c /library/cpp/actors/interconnect/interconnect_stream.cpp
parentf3646f91e0de459836a7800b9ce3e8dc57a2ab3a (diff)
Restoring authorship annotation for Alexander Rutkovsky <[email protected]>. Commit 2 of 2.
Diffstat (limited to 'library/cpp/actors/interconnect/interconnect_stream.cpp')
-rw-r--r--library/cpp/actors/interconnect/interconnect_stream.cpp922
1 files changed, 461 insertions, 461 deletions
diff --git a/library/cpp/actors/interconnect/interconnect_stream.cpp b/library/cpp/actors/interconnect/interconnect_stream.cpp
index d1cfff2c4a6..158ebc9e1d5 100644
--- a/library/cpp/actors/interconnect/interconnect_stream.cpp
+++ b/library/cpp/actors/interconnect/interconnect_stream.cpp
@@ -1,44 +1,44 @@
-#include "interconnect_stream.h"
-#include "logging.h"
+#include "interconnect_stream.h"
+#include "logging.h"
#include <library/cpp/openssl/init/init.h>
-#include <util/network/socket.h>
-#include <openssl/ssl.h>
-#include <openssl/err.h>
-#include <openssl/pem.h>
-
-#if defined(_win_)
-#include <util/system/file.h>
-#define SOCK_NONBLOCK 0
-#elif defined(_darwin_)
-#define SOCK_NONBLOCK 0
-#else
-#include <sys/un.h>
-#include <sys/stat.h>
-#endif //_win_
-
-#if !defined(_win_)
-#include <sys/ioctl.h>
-#endif
-
-#include <cerrno>
-
-namespace NInterconnect {
+#include <util/network/socket.h>
+#include <openssl/ssl.h>
+#include <openssl/err.h>
+#include <openssl/pem.h>
+
+#if defined(_win_)
+#include <util/system/file.h>
+#define SOCK_NONBLOCK 0
+#elif defined(_darwin_)
+#define SOCK_NONBLOCK 0
+#else
+#include <sys/un.h>
+#include <sys/stat.h>
+#endif //_win_
+
+#if !defined(_win_)
+#include <sys/ioctl.h>
+#endif
+
+#include <cerrno>
+
+namespace NInterconnect {
namespace {
inline int
LastSocketError() {
-#if defined(_win_)
+#if defined(_win_)
return WSAGetLastError();
-#else
+#else
return errno;
-#endif
+#endif
}
}
-
+
TSocket::TSocket(SOCKET fd)
: Descriptor(fd)
{
}
-
+
TSocket::~TSocket() {
if (Descriptor == INVALID_SOCKET) {
return;
@@ -57,30 +57,30 @@ namespace NInterconnect {
default:
Y_FAIL("It's something unexpected");
}
- }
-
+ }
+
int TSocket::GetDescriptor() {
return Descriptor;
- }
-
+ }
+
int
TSocket::Bind(const TAddress& addr) const {
const auto ret = ::bind(Descriptor, addr.SockAddr(), addr.Size());
if (ret < 0)
return -LastSocketError();
-
+
return 0;
}
-
+
int
TSocket::Shutdown(int how) const {
const auto ret = ::shutdown(Descriptor, how);
if (ret < 0)
return -LastSocketError();
-
+
return 0;
}
-
+
int TSocket::GetConnectStatus() const {
int err = 0;
socklen_t len = sizeof(err);
@@ -88,190 +88,190 @@ namespace NInterconnect {
err = LastSocketError();
}
return err;
- }
-
+ }
+
/////////////////////////////////////////////////////////////////
-
- TIntrusivePtr<TStreamSocket> TStreamSocket::Make(int domain) {
- const SOCKET res = ::socket(domain, SOCK_STREAM | SOCK_NONBLOCK, 0);
- if (res == -1) {
- const int err = LastSocketError();
- Y_VERIFY(err != EMFILE && err != ENFILE);
- }
- return MakeIntrusive<TStreamSocket>(res);
+
+ TIntrusivePtr<TStreamSocket> TStreamSocket::Make(int domain) {
+ const SOCKET res = ::socket(domain, SOCK_STREAM | SOCK_NONBLOCK, 0);
+ if (res == -1) {
+ const int err = LastSocketError();
+ Y_VERIFY(err != EMFILE && err != ENFILE);
+ }
+ return MakeIntrusive<TStreamSocket>(res);
}
-
+
TStreamSocket::TStreamSocket(SOCKET fd)
: TSocket(fd)
{
}
-
+
ssize_t
- TStreamSocket::Send(const void* msg, size_t len, TString* /*err*/) const {
- const auto ret = ::send(Descriptor, static_cast<const char*>(msg), int(len), 0);
+ TStreamSocket::Send(const void* msg, size_t len, TString* /*err*/) const {
+ const auto ret = ::send(Descriptor, static_cast<const char*>(msg), int(len), 0);
if (ret < 0)
return -LastSocketError();
-
+
return ret;
}
-
+
ssize_t
- TStreamSocket::Recv(void* buf, size_t len, TString* /*err*/) const {
- const auto ret = ::recv(Descriptor, static_cast<char*>(buf), int(len), 0);
+ TStreamSocket::Recv(void* buf, size_t len, TString* /*err*/) const {
+ const auto ret = ::recv(Descriptor, static_cast<char*>(buf), int(len), 0);
if (ret < 0)
return -LastSocketError();
-
+
return ret;
}
-
+
ssize_t
TStreamSocket::WriteV(const struct iovec* iov, int iovcnt) const {
-#ifndef _win_
+#ifndef _win_
const auto ret = ::writev(Descriptor, iov, iovcnt);
if (ret < 0)
return -LastSocketError();
return ret;
-#else
- Y_FAIL("WriteV() unsupported on Windows");
-#endif
+#else
+ Y_FAIL("WriteV() unsupported on Windows");
+#endif
}
-
+
ssize_t
TStreamSocket::ReadV(const struct iovec* iov, int iovcnt) const {
-#ifndef _win_
+#ifndef _win_
const auto ret = ::readv(Descriptor, iov, iovcnt);
if (ret < 0)
return -LastSocketError();
return ret;
-#else
- Y_FAIL("ReadV() unsupported on Windows");
-#endif
+#else
+ Y_FAIL("ReadV() unsupported on Windows");
+#endif
}
-
+
ssize_t TStreamSocket::GetUnsentQueueSize() const {
int num = -1;
-#ifndef _win_ // we have no means to determine output queue size on Windows
+#ifndef _win_ // we have no means to determine output queue size on Windows
if (ioctl(Descriptor, TIOCOUTQ, &num) == -1) {
num = -1;
}
#endif
return num;
- }
-
+ }
+
int
TStreamSocket::Connect(const TAddress& addr) const {
const auto ret = ::connect(Descriptor, addr.SockAddr(), addr.Size());
if (ret < 0)
return -LastSocketError();
-
+
return ret;
}
-
+
int
TStreamSocket::Connect(const NAddr::IRemoteAddr* addr) const {
const auto ret = ::connect(Descriptor, addr->Addr(), addr->Len());
if (ret < 0)
return -LastSocketError();
-
+
return ret;
}
-
+
int
TStreamSocket::Listen(int backlog) const {
const auto ret = ::listen(Descriptor, backlog);
if (ret < 0)
return -LastSocketError();
-
+
return ret;
}
-
+
int
TStreamSocket::Accept(TAddress& acceptedAddr) const {
socklen_t acceptedSize = sizeof(::sockaddr_in6);
const auto ret = ::accept(Descriptor, acceptedAddr.SockAddr(), &acceptedSize);
if (ret == INVALID_SOCKET)
return -LastSocketError();
-
+
return ret;
}
-
+
void
TStreamSocket::SetSendBufferSize(i32 len) const {
(void)SetSockOpt(Descriptor, SOL_SOCKET, SO_SNDBUF, len);
}
-
- ui32 TStreamSocket::GetSendBufferSize() const {
- ui32 res = 0;
- CheckedGetSockOpt(Descriptor, SOL_SOCKET, SO_SNDBUF, res, "SO_SNDBUF");
- return res;
- }
-
+
+ ui32 TStreamSocket::GetSendBufferSize() const {
+ ui32 res = 0;
+ CheckedGetSockOpt(Descriptor, SOL_SOCKET, SO_SNDBUF, res, "SO_SNDBUF");
+ return res;
+ }
+
//////////////////////////////////////////////////////
-
- TDatagramSocket::TPtr TDatagramSocket::Make(int domain) {
- const SOCKET res = ::socket(domain, SOCK_DGRAM, 0);
- if (res == -1) {
- const int err = LastSocketError();
- Y_VERIFY(err != EMFILE && err != ENFILE);
- }
- return std::make_shared<TDatagramSocket>(res);
+
+ TDatagramSocket::TPtr TDatagramSocket::Make(int domain) {
+ const SOCKET res = ::socket(domain, SOCK_DGRAM, 0);
+ if (res == -1) {
+ const int err = LastSocketError();
+ Y_VERIFY(err != EMFILE && err != ENFILE);
+ }
+ return std::make_shared<TDatagramSocket>(res);
}
-
+
TDatagramSocket::TDatagramSocket(SOCKET fd)
: TSocket(fd)
{
}
-
+
ssize_t
TDatagramSocket::SendTo(const void* msg, size_t len, const TAddress& toAddr) const {
const auto ret = ::sendto(Descriptor, static_cast<const char*>(msg), int(len), 0, toAddr.SockAddr(), toAddr.Size());
if (ret < 0)
return -LastSocketError();
-
+
return ret;
}
-
+
ssize_t
TDatagramSocket::RecvFrom(void* buf, size_t len, TAddress& fromAddr) const {
socklen_t fromSize = sizeof(::sockaddr_in6);
const auto ret = ::recvfrom(Descriptor, static_cast<char*>(buf), int(len), 0, fromAddr.SockAddr(), &fromSize);
if (ret < 0)
return -LastSocketError();
-
+
return ret;
}
-
-
- // deleter for SSL objects
- struct TDeleter {
- void operator ()(BIO *bio) const {
- BIO_free(bio);
- }
-
- void operator ()(X509 *x509) const {
- X509_free(x509);
- }
-
- void operator ()(RSA *rsa) const {
- RSA_free(rsa);
- }
-
- void operator ()(SSL_CTX *ctx) const {
- SSL_CTX_free(ctx);
- }
- };
-
- class TSecureSocketContext::TImpl {
- std::unique_ptr<SSL_CTX, TDeleter> Ctx;
-
- public:
- TImpl(const TString& certificate, const TString& privateKey, const TString& caFilePath,
- const TString& ciphers) {
+
+
+ // deleter for SSL objects
+ struct TDeleter {
+ void operator ()(BIO *bio) const {
+ BIO_free(bio);
+ }
+
+ void operator ()(X509 *x509) const {
+ X509_free(x509);
+ }
+
+ void operator ()(RSA *rsa) const {
+ RSA_free(rsa);
+ }
+
+ void operator ()(SSL_CTX *ctx) const {
+ SSL_CTX_free(ctx);
+ }
+ };
+
+ class TSecureSocketContext::TImpl {
+ std::unique_ptr<SSL_CTX, TDeleter> Ctx;
+
+ public:
+ TImpl(const TString& certificate, const TString& privateKey, const TString& caFilePath,
+ const TString& ciphers) {
int ret;
- InitOpenSSL();
+ InitOpenSSL();
#if OPENSSL_VERSION_NUMBER < 0x10100000L
- Ctx.reset(SSL_CTX_new(TLSv1_2_method()));
- Y_VERIFY(Ctx, "SSL_CTX_new() failed");
+ Ctx.reset(SSL_CTX_new(TLSv1_2_method()));
+ Y_VERIFY(Ctx, "SSL_CTX_new() failed");
#else
Ctx.reset(SSL_CTX_new(TLS_method()));
Y_VERIFY(Ctx, "SSL_CTX_new() failed");
@@ -280,19 +280,19 @@ namespace NInterconnect {
ret = SSL_CTX_set_max_proto_version(Ctx.get(), TLS1_2_VERSION);
Y_VERIFY(ret == 1, "failed to set max proto version");
#endif
- SSL_CTX_set_verify(Ctx.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, &Verify);
- SSL_CTX_set_mode(*this, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
-
- // apply certificates in SSL context
- if (certificate) {
- std::unique_ptr<BIO, TDeleter> bio(BIO_new_mem_buf(certificate.data(), certificate.size()));
- Y_VERIFY(bio);
+ SSL_CTX_set_verify(Ctx.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, &Verify);
+ SSL_CTX_set_mode(*this, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
+
+ // apply certificates in SSL context
+ if (certificate) {
+ std::unique_ptr<BIO, TDeleter> bio(BIO_new_mem_buf(certificate.data(), certificate.size()));
+ Y_VERIFY(bio);
// first certificate in the chain is expected to be a leaf
- std::unique_ptr<X509, TDeleter> cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
- Y_VERIFY(cert, "failed to parse certificate");
- ret = SSL_CTX_use_certificate(Ctx.get(), cert.get());
- Y_VERIFY(ret == 1);
+ std::unique_ptr<X509, TDeleter> cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
+ Y_VERIFY(cert, "failed to parse certificate");
+ ret = SSL_CTX_use_certificate(Ctx.get(), cert.get());
+ Y_VERIFY(ret == 1);
// loading additional certificates in the chain, if any
while(true) {
@@ -304,325 +304,325 @@ namespace NInterconnect {
Y_VERIFY(ret == 1);
// we must not free memory if certificate was added successfully by SSL_CTX_add0_chain_cert
}
- }
- if (privateKey) {
- std::unique_ptr<BIO, TDeleter> bio(BIO_new_mem_buf(privateKey.data(), privateKey.size()));
- Y_VERIFY(bio);
- std::unique_ptr<RSA, TDeleter> pkey(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr));
- Y_VERIFY(pkey);
- ret = SSL_CTX_use_RSAPrivateKey(Ctx.get(), pkey.get());
- Y_VERIFY(ret == 1);
- }
- if (caFilePath) {
- ret = SSL_CTX_load_verify_locations(Ctx.get(), caFilePath.data(), nullptr);
- Y_VERIFY(ret == 1);
- }
-
- int success = SSL_CTX_set_cipher_list(Ctx.get(), ciphers ? ciphers.data() : "AES128-GCM-SHA256");
- Y_VERIFY(success, "failed to set cipher list");
- }
-
- operator SSL_CTX*() const {
- return Ctx.get();
- }
-
- static int GetExIndex() {
- static int index = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
- return index;
- }
-
- private:
- static int Verify(int preverify, X509_STORE_CTX *ctx) {
- if (!preverify) {
- X509 *badCert = X509_STORE_CTX_get_current_cert(ctx);
- int err = X509_STORE_CTX_get_error(ctx);
- int depth = X509_STORE_CTX_get_error_depth(ctx);
- SSL *ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
- TString *errp = static_cast<TString*>(SSL_get_ex_data(ssl, GetExIndex()));
- char buffer[1024];
- X509_NAME_oneline(X509_get_subject_name(badCert), buffer, sizeof(buffer));
- TStringBuilder s;
- s << "Error during certificate validation"
- << " error# " << X509_verify_cert_error_string(err)
- << " depth# " << depth
- << " cert# " << buffer;
- if (err == X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT) {
- X509_NAME_oneline(X509_get_issuer_name(badCert), buffer, sizeof(buffer));
- s << " issuer# " << buffer;
- }
- *errp = s;
- }
- return preverify;
- }
- };
-
- TSecureSocketContext::TSecureSocketContext(const TString& certificate, const TString& privateKey,
- const TString& caFilePath, const TString& ciphers)
- : Impl(new TImpl(certificate, privateKey, caFilePath, ciphers))
- {}
-
- TSecureSocketContext::~TSecureSocketContext()
- {}
-
- class TSecureSocket::TImpl {
- SSL *Ssl;
- TString ErrorDescription;
- bool WantRead_ = false;
- bool WantWrite_ = false;
-
- public:
- TImpl(SSL_CTX *ctx, int fd)
- : Ssl(SSL_new(ctx))
- {
- Y_VERIFY(Ssl, "SSL_new() failed");
- SSL_set_fd(Ssl, fd);
- SSL_set_ex_data(Ssl, TSecureSocketContext::TImpl::GetExIndex(), &ErrorDescription);
- }
-
- ~TImpl() {
- SSL_free(Ssl);
- }
-
- TString GetErrorStack() {
- if (ErrorDescription) {
- return ErrorDescription;
- }
- std::unique_ptr<BIO, int(*)(BIO*)> mem(BIO_new(BIO_s_mem()), BIO_free);
- ERR_print_errors(mem.get());
- char *p = nullptr;
- auto len = BIO_get_mem_data(mem.get(), &p);
- return TString(p, len);
- }
-
- EStatus ConvertResult(int res, TString& err) {
- switch (res) {
- case SSL_ERROR_NONE:
- return EStatus::SUCCESS;
-
- case SSL_ERROR_WANT_READ:
- return EStatus::WANT_READ;
-
- case SSL_ERROR_WANT_WRITE:
- return EStatus::WANT_WRITE;
-
- case SSL_ERROR_SYSCALL:
- err = TStringBuilder() << "syscall error: " << strerror(LastSocketError()) << ": " << GetErrorStack();
- break;
-
- case SSL_ERROR_ZERO_RETURN:
- err = "TLS negotiation failed";
- break;
-
- case SSL_ERROR_SSL:
- err = "SSL error: " + GetErrorStack();
- break;
-
- default:
- err = "unknown OpenSSL error";
- break;
- }
- return EStatus::ERROR;
- }
-
- enum EConnectState {
- CONNECT,
- SHUTDOWN,
- READ,
- } ConnectState = EConnectState::CONNECT;
-
- EStatus Establish(bool server, bool authOnly, TString& err) {
- switch (ConnectState) {
- case EConnectState::CONNECT: {
- auto callback = server ? SSL_accept : SSL_connect;
- const EStatus status = ConvertResult(SSL_get_error(Ssl, callback(Ssl)), err);
- if (status != EStatus::SUCCESS || !authOnly) {
- return status;
- }
- ConnectState = EConnectState::SHUTDOWN;
- [[fallthrough]];
- }
-
- case EConnectState::SHUTDOWN: {
- const int res = SSL_shutdown(Ssl);
- if (res == 1) {
- return EStatus::SUCCESS;
- } else if (res != 0) {
- return ConvertResult(SSL_get_error(Ssl, res), err);
- }
- ConnectState = EConnectState::READ;
- [[fallthrough]];
- }
-
- case EConnectState::READ: {
- char data[256];
- size_t numRead = 0;
- const int res = SSL_get_error(Ssl, SSL_read_ex(Ssl, data, sizeof(data), &numRead));
- if (res == SSL_ERROR_ZERO_RETURN) {
- return EStatus::SUCCESS;
- } else if (res != SSL_ERROR_NONE) {
- return ConvertResult(res, err);
- } else if (numRead) {
- err = "non-zero return from SSL_read_ex: " + ToString(numRead);
- return EStatus::ERROR;
- } else {
- return EStatus::SUCCESS;
- }
- }
- }
- Y_FAIL();
- }
-
- std::optional<std::pair<const void*, size_t>> BlockedSend;
-
- ssize_t Send(const void* msg, size_t len, TString *err) {
- Y_VERIFY(!BlockedSend || *BlockedSend == std::make_pair(msg, len));
- const ssize_t res = Operate(msg, len, &SSL_write_ex, err);
- if (res == -EAGAIN) {
- BlockedSend.emplace(msg, len);
- } else {
- BlockedSend.reset();
- }
- return res;
- }
-
- std::optional<std::pair<void*, size_t>> BlockedReceive;
-
- ssize_t Recv(void* msg, size_t len, TString *err) {
- Y_VERIFY(!BlockedReceive || *BlockedReceive == std::make_pair(msg, len));
- const ssize_t res = Operate(msg, len, &SSL_read_ex, err);
- if (res == -EAGAIN) {
- BlockedReceive.emplace(msg, len);
- } else {
- BlockedReceive.reset();
- }
- return res;
- }
-
- TString GetCipherName() const {
- return SSL_get_cipher_name(Ssl);
- }
-
- int GetCipherBits() const {
- return SSL_get_cipher_bits(Ssl, nullptr);
- }
-
- TString GetProtocolName() const {
- return SSL_get_cipher_version(Ssl);
- }
-
- TString GetPeerCommonName() const {
- TString res;
- if (X509 *cert = SSL_get_peer_certificate(Ssl)) {
- char buffer[256];
- memset(buffer, 0, sizeof(buffer));
- if (X509_NAME *name = X509_get_subject_name(cert)) {
- X509_NAME_get_text_by_NID(name, NID_commonName, buffer, sizeof(buffer));
- }
- X509_free(cert);
- res = TString(buffer, strnlen(buffer, sizeof(buffer)));
- }
- return res;
- }
-
- bool WantRead() const {
- return WantRead_;
- }
-
- bool WantWrite() const {
- return WantWrite_;
- }
-
- private:
- template<typename TBuffer, typename TOp>
- ssize_t Operate(TBuffer* buffer, size_t len, TOp&& op, TString *err) {
- WantRead_ = WantWrite_ = false;
- size_t processed = 0;
- int ret = op(Ssl, buffer, len, &processed);
- if (ret == 1) {
- return processed;
- }
- switch (const int status = SSL_get_error(Ssl, ret)) {
- case SSL_ERROR_ZERO_RETURN:
- return 0;
-
- case SSL_ERROR_WANT_READ:
- WantRead_ = true;
- return -EAGAIN;
-
- case SSL_ERROR_WANT_WRITE:
- WantWrite_ = true;
- return -EAGAIN;
-
- case SSL_ERROR_SYSCALL:
- return -LastSocketError();
-
- case SSL_ERROR_SSL:
- if (err) {
- *err = GetErrorStack();
- }
- return -EPROTO;
-
- default:
- Y_FAIL("unexpected SSL_get_error() status# %d", status);
- }
- }
- };
-
- TSecureSocket::TSecureSocket(TStreamSocket& socket, TSecureSocketContext::TPtr context)
- : TStreamSocket(socket.ReleaseDescriptor())
- , Context(std::move(context))
- , Impl(new TImpl(*Context->Impl, Descriptor))
- {}
-
- TSecureSocket::~TSecureSocket()
- {}
-
- TSecureSocket::EStatus TSecureSocket::Establish(bool server, bool authOnly, TString& err) const {
- return Impl->Establish(server, authOnly, err);
- }
-
- TIntrusivePtr<TStreamSocket> TSecureSocket::Detach() {
- return MakeIntrusive<TStreamSocket>(ReleaseDescriptor());
- }
-
- ssize_t TSecureSocket::Send(const void* msg, size_t len, TString *err) const {
- return Impl->Send(msg, len, err);
- }
-
- ssize_t TSecureSocket::Recv(void* msg, size_t len, TString *err) const {
- return Impl->Recv(msg, len, err);
- }
-
- ssize_t TSecureSocket::WriteV(const struct iovec* /*iov*/, int /*iovcnt*/) const {
- Y_FAIL("unsupported on SSL sockets");
- }
-
- ssize_t TSecureSocket::ReadV(const struct iovec* /*iov*/, int /*iovcnt*/) const {
- Y_FAIL("unsupported on SSL sockets");
- }
-
- TString TSecureSocket::GetCipherName() const {
- return Impl->GetCipherName();
- }
-
- int TSecureSocket::GetCipherBits() const {
- return Impl->GetCipherBits();
- }
-
- TString TSecureSocket::GetProtocolName() const {
- return Impl->GetProtocolName();
- }
-
- TString TSecureSocket::GetPeerCommonName() const {
- return Impl->GetPeerCommonName();
- }
-
- bool TSecureSocket::WantRead() const {
- return Impl->WantRead();
- }
-
- bool TSecureSocket::WantWrite() const {
- return Impl->WantWrite();
- }
-
-}
+ }
+ if (privateKey) {
+ std::unique_ptr<BIO, TDeleter> bio(BIO_new_mem_buf(privateKey.data(), privateKey.size()));
+ Y_VERIFY(bio);
+ std::unique_ptr<RSA, TDeleter> pkey(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr));
+ Y_VERIFY(pkey);
+ ret = SSL_CTX_use_RSAPrivateKey(Ctx.get(), pkey.get());
+ Y_VERIFY(ret == 1);
+ }
+ if (caFilePath) {
+ ret = SSL_CTX_load_verify_locations(Ctx.get(), caFilePath.data(), nullptr);
+ Y_VERIFY(ret == 1);
+ }
+
+ int success = SSL_CTX_set_cipher_list(Ctx.get(), ciphers ? ciphers.data() : "AES128-GCM-SHA256");
+ Y_VERIFY(success, "failed to set cipher list");
+ }
+
+ operator SSL_CTX*() const {
+ return Ctx.get();
+ }
+
+ static int GetExIndex() {
+ static int index = SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
+ return index;
+ }
+
+ private:
+ static int Verify(int preverify, X509_STORE_CTX *ctx) {
+ if (!preverify) {
+ X509 *badCert = X509_STORE_CTX_get_current_cert(ctx);
+ int err = X509_STORE_CTX_get_error(ctx);
+ int depth = X509_STORE_CTX_get_error_depth(ctx);
+ SSL *ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
+ TString *errp = static_cast<TString*>(SSL_get_ex_data(ssl, GetExIndex()));
+ char buffer[1024];
+ X509_NAME_oneline(X509_get_subject_name(badCert), buffer, sizeof(buffer));
+ TStringBuilder s;
+ s << "Error during certificate validation"
+ << " error# " << X509_verify_cert_error_string(err)
+ << " depth# " << depth
+ << " cert# " << buffer;
+ if (err == X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT) {
+ X509_NAME_oneline(X509_get_issuer_name(badCert), buffer, sizeof(buffer));
+ s << " issuer# " << buffer;
+ }
+ *errp = s;
+ }
+ return preverify;
+ }
+ };
+
+ TSecureSocketContext::TSecureSocketContext(const TString& certificate, const TString& privateKey,
+ const TString& caFilePath, const TString& ciphers)
+ : Impl(new TImpl(certificate, privateKey, caFilePath, ciphers))
+ {}
+
+ TSecureSocketContext::~TSecureSocketContext()
+ {}
+
+ class TSecureSocket::TImpl {
+ SSL *Ssl;
+ TString ErrorDescription;
+ bool WantRead_ = false;
+ bool WantWrite_ = false;
+
+ public:
+ TImpl(SSL_CTX *ctx, int fd)
+ : Ssl(SSL_new(ctx))
+ {
+ Y_VERIFY(Ssl, "SSL_new() failed");
+ SSL_set_fd(Ssl, fd);
+ SSL_set_ex_data(Ssl, TSecureSocketContext::TImpl::GetExIndex(), &ErrorDescription);
+ }
+
+ ~TImpl() {
+ SSL_free(Ssl);
+ }
+
+ TString GetErrorStack() {
+ if (ErrorDescription) {
+ return ErrorDescription;
+ }
+ std::unique_ptr<BIO, int(*)(BIO*)> mem(BIO_new(BIO_s_mem()), BIO_free);
+ ERR_print_errors(mem.get());
+ char *p = nullptr;
+ auto len = BIO_get_mem_data(mem.get(), &p);
+ return TString(p, len);
+ }
+
+ EStatus ConvertResult(int res, TString& err) {
+ switch (res) {
+ case SSL_ERROR_NONE:
+ return EStatus::SUCCESS;
+
+ case SSL_ERROR_WANT_READ:
+ return EStatus::WANT_READ;
+
+ case SSL_ERROR_WANT_WRITE:
+ return EStatus::WANT_WRITE;
+
+ case SSL_ERROR_SYSCALL:
+ err = TStringBuilder() << "syscall error: " << strerror(LastSocketError()) << ": " << GetErrorStack();
+ break;
+
+ case SSL_ERROR_ZERO_RETURN:
+ err = "TLS negotiation failed";
+ break;
+
+ case SSL_ERROR_SSL:
+ err = "SSL error: " + GetErrorStack();
+ break;
+
+ default:
+ err = "unknown OpenSSL error";
+ break;
+ }
+ return EStatus::ERROR;
+ }
+
+ enum EConnectState {
+ CONNECT,
+ SHUTDOWN,
+ READ,
+ } ConnectState = EConnectState::CONNECT;
+
+ EStatus Establish(bool server, bool authOnly, TString& err) {
+ switch (ConnectState) {
+ case EConnectState::CONNECT: {
+ auto callback = server ? SSL_accept : SSL_connect;
+ const EStatus status = ConvertResult(SSL_get_error(Ssl, callback(Ssl)), err);
+ if (status != EStatus::SUCCESS || !authOnly) {
+ return status;
+ }
+ ConnectState = EConnectState::SHUTDOWN;
+ [[fallthrough]];
+ }
+
+ case EConnectState::SHUTDOWN: {
+ const int res = SSL_shutdown(Ssl);
+ if (res == 1) {
+ return EStatus::SUCCESS;
+ } else if (res != 0) {
+ return ConvertResult(SSL_get_error(Ssl, res), err);
+ }
+ ConnectState = EConnectState::READ;
+ [[fallthrough]];
+ }
+
+ case EConnectState::READ: {
+ char data[256];
+ size_t numRead = 0;
+ const int res = SSL_get_error(Ssl, SSL_read_ex(Ssl, data, sizeof(data), &numRead));
+ if (res == SSL_ERROR_ZERO_RETURN) {
+ return EStatus::SUCCESS;
+ } else if (res != SSL_ERROR_NONE) {
+ return ConvertResult(res, err);
+ } else if (numRead) {
+ err = "non-zero return from SSL_read_ex: " + ToString(numRead);
+ return EStatus::ERROR;
+ } else {
+ return EStatus::SUCCESS;
+ }
+ }
+ }
+ Y_FAIL();
+ }
+
+ std::optional<std::pair<const void*, size_t>> BlockedSend;
+
+ ssize_t Send(const void* msg, size_t len, TString *err) {
+ Y_VERIFY(!BlockedSend || *BlockedSend == std::make_pair(msg, len));
+ const ssize_t res = Operate(msg, len, &SSL_write_ex, err);
+ if (res == -EAGAIN) {
+ BlockedSend.emplace(msg, len);
+ } else {
+ BlockedSend.reset();
+ }
+ return res;
+ }
+
+ std::optional<std::pair<void*, size_t>> BlockedReceive;
+
+ ssize_t Recv(void* msg, size_t len, TString *err) {
+ Y_VERIFY(!BlockedReceive || *BlockedReceive == std::make_pair(msg, len));
+ const ssize_t res = Operate(msg, len, &SSL_read_ex, err);
+ if (res == -EAGAIN) {
+ BlockedReceive.emplace(msg, len);
+ } else {
+ BlockedReceive.reset();
+ }
+ return res;
+ }
+
+ TString GetCipherName() const {
+ return SSL_get_cipher_name(Ssl);
+ }
+
+ int GetCipherBits() const {
+ return SSL_get_cipher_bits(Ssl, nullptr);
+ }
+
+ TString GetProtocolName() const {
+ return SSL_get_cipher_version(Ssl);
+ }
+
+ TString GetPeerCommonName() const {
+ TString res;
+ if (X509 *cert = SSL_get_peer_certificate(Ssl)) {
+ char buffer[256];
+ memset(buffer, 0, sizeof(buffer));
+ if (X509_NAME *name = X509_get_subject_name(cert)) {
+ X509_NAME_get_text_by_NID(name, NID_commonName, buffer, sizeof(buffer));
+ }
+ X509_free(cert);
+ res = TString(buffer, strnlen(buffer, sizeof(buffer)));
+ }
+ return res;
+ }
+
+ bool WantRead() const {
+ return WantRead_;
+ }
+
+ bool WantWrite() const {
+ return WantWrite_;
+ }
+
+ private:
+ template<typename TBuffer, typename TOp>
+ ssize_t Operate(TBuffer* buffer, size_t len, TOp&& op, TString *err) {
+ WantRead_ = WantWrite_ = false;
+ size_t processed = 0;
+ int ret = op(Ssl, buffer, len, &processed);
+ if (ret == 1) {
+ return processed;
+ }
+ switch (const int status = SSL_get_error(Ssl, ret)) {
+ case SSL_ERROR_ZERO_RETURN:
+ return 0;
+
+ case SSL_ERROR_WANT_READ:
+ WantRead_ = true;
+ return -EAGAIN;
+
+ case SSL_ERROR_WANT_WRITE:
+ WantWrite_ = true;
+ return -EAGAIN;
+
+ case SSL_ERROR_SYSCALL:
+ return -LastSocketError();
+
+ case SSL_ERROR_SSL:
+ if (err) {
+ *err = GetErrorStack();
+ }
+ return -EPROTO;
+
+ default:
+ Y_FAIL("unexpected SSL_get_error() status# %d", status);
+ }
+ }
+ };
+
+ TSecureSocket::TSecureSocket(TStreamSocket& socket, TSecureSocketContext::TPtr context)
+ : TStreamSocket(socket.ReleaseDescriptor())
+ , Context(std::move(context))
+ , Impl(new TImpl(*Context->Impl, Descriptor))
+ {}
+
+ TSecureSocket::~TSecureSocket()
+ {}
+
+ TSecureSocket::EStatus TSecureSocket::Establish(bool server, bool authOnly, TString& err) const {
+ return Impl->Establish(server, authOnly, err);
+ }
+
+ TIntrusivePtr<TStreamSocket> TSecureSocket::Detach() {
+ return MakeIntrusive<TStreamSocket>(ReleaseDescriptor());
+ }
+
+ ssize_t TSecureSocket::Send(const void* msg, size_t len, TString *err) const {
+ return Impl->Send(msg, len, err);
+ }
+
+ ssize_t TSecureSocket::Recv(void* msg, size_t len, TString *err) const {
+ return Impl->Recv(msg, len, err);
+ }
+
+ ssize_t TSecureSocket::WriteV(const struct iovec* /*iov*/, int /*iovcnt*/) const {
+ Y_FAIL("unsupported on SSL sockets");
+ }
+
+ ssize_t TSecureSocket::ReadV(const struct iovec* /*iov*/, int /*iovcnt*/) const {
+ Y_FAIL("unsupported on SSL sockets");
+ }
+
+ TString TSecureSocket::GetCipherName() const {
+ return Impl->GetCipherName();
+ }
+
+ int TSecureSocket::GetCipherBits() const {
+ return Impl->GetCipherBits();
+ }
+
+ TString TSecureSocket::GetProtocolName() const {
+ return Impl->GetProtocolName();
+ }
+
+ TString TSecureSocket::GetPeerCommonName() const {
+ return Impl->GetPeerCommonName();
+ }
+
+ bool TSecureSocket::WantRead() const {
+ return Impl->WantRead();
+ }
+
+ bool TSecureSocket::WantWrite() const {
+ return Impl->WantWrite();
+ }
+
+}