aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/openssl/io/stream.cpp
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/openssl/io/stream.cpp
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/openssl/io/stream.cpp')
-rw-r--r--library/cpp/openssl/io/stream.cpp329
1 files changed, 329 insertions, 0 deletions
diff --git a/library/cpp/openssl/io/stream.cpp b/library/cpp/openssl/io/stream.cpp
new file mode 100644
index 0000000000..0b4be38c0e
--- /dev/null
+++ b/library/cpp/openssl/io/stream.cpp
@@ -0,0 +1,329 @@
+#include "stream.h"
+
+#include <util/generic/deque.h>
+#include <util/generic/singleton.h>
+#include <util/generic/yexception.h>
+
+#include <library/cpp/openssl/init/init.h>
+#include <library/cpp/openssl/method/io.h>
+#include <library/cpp/resource/resource.h>
+
+#include <openssl/bio.h>
+#include <openssl/ssl.h>
+#include <openssl/err.h>
+#include <openssl/tls1.h>
+#include <openssl/x509v3.h>
+
+using TOptions = TOpenSslClientIO::TOptions;
+
+namespace {
+ struct TSslIO;
+
+ struct TSslInitOnDemand {
+ inline TSslInitOnDemand() {
+ InitOpenSSL();
+ }
+ };
+
+ int GetLastSslError() noexcept {
+ return ERR_peek_last_error();
+ }
+
+ const char* SslErrorText(int error) noexcept {
+ return ERR_error_string(error, nullptr);
+ }
+
+ inline TStringBuf SslLastError() noexcept {
+ return SslErrorText(GetLastSslError());
+ }
+
+ struct TSslError: public yexception {
+ inline TSslError() {
+ *this << SslLastError();
+ }
+ };
+
+ struct TSslDestroy {
+ static inline void Destroy(ssl_ctx_st* ctx) noexcept {
+ SSL_CTX_free(ctx);
+ }
+
+ static inline void Destroy(ssl_st* ssl) noexcept {
+ SSL_free(ssl);
+ }
+
+ static inline void Destroy(bio_st* bio) noexcept {
+ BIO_free(bio);
+ }
+
+ static inline void Destroy(x509_st* x509) noexcept {
+ X509_free(x509);
+ }
+ };
+
+ template <class T>
+ using TSslHolderPtr = THolder<T, TSslDestroy>;
+
+ using TSslContextPtr = TSslHolderPtr<ssl_ctx_st>;
+ using TSslPtr = TSslHolderPtr<ssl_st>;
+ using TBioPtr = TSslHolderPtr<bio_st>;
+ using TX509Ptr = TSslHolderPtr<x509_st>;
+
+ inline TSslContextPtr CreateSslCtx(const ssl_method_st* method) {
+ TSslContextPtr ctx(SSL_CTX_new(method));
+
+ if (!ctx) {
+ ythrow TSslError() << "SSL_CTX_new";
+ }
+
+ SSL_CTX_set_options(ctx.Get(), SSL_OP_NO_SSLv2);
+ SSL_CTX_set_options(ctx.Get(), SSL_OP_NO_SSLv3);
+ SSL_CTX_set_options(ctx.Get(), SSL_OP_MICROSOFT_SESS_ID_BUG);
+ SSL_CTX_set_options(ctx.Get(), SSL_OP_NETSCAPE_CHALLENGE_BUG);
+
+ return ctx;
+ }
+
+ struct TStreamIO : public NOpenSSL::TAbstractIO {
+ inline TStreamIO(IInputStream* in, IOutputStream* out)
+ : In(in)
+ , Out(out)
+ {
+ }
+
+ int Write(const char* data, size_t dlen, size_t* written) override {
+ Out->Write(data, dlen);
+
+ *written = dlen;
+ return 1;
+ }
+
+ int Read(char* data, size_t dlen, size_t* readbytes) override {
+ *readbytes = In->Read(data, dlen);
+ return 1;
+ }
+
+ int Puts(const char* buf) override {
+ Y_UNUSED(buf);
+ return -1;
+ }
+
+ int Gets(char* buf, int size) override {
+ Y_UNUSED(buf);
+ Y_UNUSED(size);
+ return -1;
+ }
+
+ void Flush() override {
+ }
+
+ IInputStream* In;
+ IOutputStream* Out;
+ };
+
+ struct TSslIO: public TSslInitOnDemand, public TOptions {
+ inline TSslIO(IInputStream* in, IOutputStream* out, const TOptions& opts)
+ : TOptions(opts)
+ , Io(in, out)
+ , Ctx(CreateClientContext())
+ , Ssl(ConstructSsl())
+ {
+ Connect();
+ }
+
+ inline TSslContextPtr CreateClientContext() {
+ TSslContextPtr ctx = CreateSslCtx(SSLv23_client_method());
+ if (ClientCert_) {
+ if (!ClientCert_->CertificateFile_ || !ClientCert_->PrivateKeyFile_) {
+ ythrow yexception() << "both client certificate and private key are required";
+ }
+ if (ClientCert_->PrivateKeyPassword_) {
+ SSL_CTX_set_default_passwd_cb(ctx.Get(), [](char* buf, int size, int rwflag, void* userData) -> int {
+ Y_UNUSED(rwflag);
+ auto io = static_cast<TSslIO*>(userData);
+ if (!io) {
+ return -1;
+ }
+ if (size < static_cast<int>(io->ClientCert_->PrivateKeyPassword_.size())) {
+ return -1;
+ }
+ return io->ClientCert_->PrivateKeyPassword_.copy(buf, size, 0);
+ });
+ SSL_CTX_set_default_passwd_cb_userdata(ctx.Get(), this);
+ }
+ if (1 != SSL_CTX_use_certificate_chain_file(ctx.Get(), ClientCert_->CertificateFile_.c_str())) {
+ ythrow TSslError() << "SSL_CTX_use_certificate_chain_file";
+ }
+ if (1 != SSL_CTX_use_PrivateKey_file(ctx.Get(), ClientCert_->PrivateKeyFile_.c_str(), SSL_FILETYPE_PEM)) {
+ ythrow TSslError() << "SSL_CTX_use_PrivateKey_file";
+ }
+ if (1 != SSL_CTX_check_private_key(ctx.Get())) {
+ ythrow TSslError() << "SSL_CTX_check_private_key (client)";
+ }
+ }
+ return ctx;
+ }
+
+ inline TSslPtr ConstructSsl() {
+ TSslPtr ssl(SSL_new(Ctx.Get()));
+
+ if (!ssl) {
+ ythrow TSslError() << "SSL_new";
+ }
+
+ if (VerifyCert_) {
+ InitVerification(ssl.Get());
+ }
+
+ BIO_up_ref(Io); // SSL_set_bio consumes only one reference if rbio and wbio are the same
+ SSL_set_bio(ssl.Get(), Io, Io);
+
+ return ssl;
+ }
+
+ inline void InitVerification(ssl_st* ssl) {
+ X509_VERIFY_PARAM* param = SSL_get0_param(ssl);
+ X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
+ Y_ENSURE(X509_VERIFY_PARAM_set1_host(param, VerifyCert_->Hostname_.data(), VerifyCert_->Hostname_.size()));
+ SSL_set_tlsext_host_name(ssl, VerifyCert_->Hostname_.data()); // TLS extenstion: SNI
+
+ SSL_CTX_set_cert_store(Ctx.Get(), GetBuiltinOpenSslX509Store().Release());
+
+ Y_ENSURE_EX(1 == SSL_CTX_set_default_verify_paths(Ctx.Get()),
+ TSslError());
+ // it is OK to ignore result of SSL_CTX_load_verify_locations():
+ // Dir "/etc/ssl/certs/" may be missing
+ SSL_CTX_load_verify_locations(Ctx.Get(),
+ "/etc/ssl/certs/ca-certificates.crt",
+ "/etc/ssl/certs/");
+
+ SSL_set_verify(ssl, SSL_VERIFY_PEER, nullptr);
+ }
+
+ inline void Connect() {
+ if (SSL_connect(Ssl.Get()) != 1) {
+ ythrow TSslError() << "SSL_connect";
+ }
+ }
+
+ inline void Finish() const {
+ SSL_shutdown(Ssl.Get());
+ }
+
+ inline size_t Read(void* buf, size_t len) {
+ const int ret = SSL_read(Ssl.Get(), buf, len);
+
+ if (ret < 0) {
+ ythrow TSslError() << "SSL_read";
+ }
+
+ return ret;
+ }
+
+ inline void Write(const char* buf, size_t len) {
+ while (len) {
+ const int ret = SSL_write(Ssl.Get(), buf, len);
+
+ if (ret < 0) {
+ ythrow TSslError() << "SSL_write";
+ }
+
+ buf += (size_t)ret;
+ len -= (size_t)ret;
+ }
+ }
+
+ TStreamIO Io;
+ TSslContextPtr Ctx;
+ TSslPtr Ssl;
+ };
+}
+
+struct TOpenSslClientIO::TImpl: public TSslIO {
+ inline TImpl(IInputStream* in, IOutputStream* out, const TOptions& opts)
+ : TSslIO(in, out, opts)
+ {
+ }
+};
+
+TOpenSslClientIO::TOpenSslClientIO(IInputStream* in, IOutputStream* out)
+ : Impl_(new TImpl(in, out, TOptions()))
+{
+}
+
+TOpenSslClientIO::TOpenSslClientIO(IInputStream* in, IOutputStream* out, const TOptions& options)
+ : Impl_(new TImpl(in, out, options))
+{
+}
+
+TOpenSslClientIO::~TOpenSslClientIO() {
+ try {
+ Impl_->Finish();
+ } catch (...) {
+ }
+}
+
+void TOpenSslClientIO::DoWrite(const void* buf, size_t len) {
+ Impl_->Write((const char*)buf, len);
+}
+
+size_t TOpenSslClientIO::DoRead(void* buf, size_t len) {
+ return Impl_->Read(buf, len);
+}
+
+namespace NPrivate {
+ void TSslDestroy::Destroy(x509_store_st* x509) noexcept {
+ X509_STORE_free(x509);
+ }
+}
+
+class TBuiltinCerts {
+public:
+ TBuiltinCerts() {
+ TString c = NResource::Find("/builtin/cacert");
+
+ TBioPtr cbio(BIO_new_mem_buf(c.data(), c.size()));
+ Y_ENSURE_EX(cbio, TSslError() << "BIO_new_mem_buf");
+
+ while (true) {
+ TX509Ptr cert(PEM_read_bio_X509(cbio.Get(), nullptr, nullptr, nullptr));
+ if (!cert) {
+ break;
+ }
+ Certs.push_back(std::move(cert));
+ }
+
+ int err = GetLastSslError();
+ if (!Certs.empty() && ERR_GET_LIB(err) == ERR_LIB_PEM && ERR_GET_REASON(err) == PEM_R_NO_START_LINE) {
+ ERR_clear_error();
+ } else {
+ ythrow TSslError() << "can't load provided bundle: " << ERR_reason_error_string(err);
+ }
+
+ Y_ENSURE_EX(!Certs.empty(), TSslError());
+ }
+
+ TOpenSslX509StorePtr GetX509Store() const {
+ TOpenSslX509StorePtr store(X509_STORE_new());
+
+ for (const TX509Ptr& c : Certs) {
+ if (0 == X509_STORE_add_cert(store.Get(), c.Get())) {
+ int err = GetLastSslError();
+ if (ERR_GET_LIB(err) == ERR_LIB_X509 && ERR_GET_REASON(err) == X509_R_CERT_ALREADY_IN_HASH_TABLE) {
+ ERR_clear_error();
+ } else {
+ ythrow TSslError() << "can't load provided bundle: " << ERR_reason_error_string(err);
+ }
+ }
+ }
+
+ return store;
+ }
+
+private:
+ TDeque<TX509Ptr> Certs;
+};
+
+TOpenSslX509StorePtr GetBuiltinOpenSslX509Store() {
+ return Singleton<TBuiltinCerts>()->GetX509Store();
+}