aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/openssl/io/stream.cpp
blob: 0b4be38c0e3f24d758c2c186f8286f85e8889260 (plain) (tree)
1
2
3
4
5
6
7
8
9
                   
                               

                                    

                                          


                        
                           










                                            
                                    

                                     
                                                  
                                                
     
                                               








                                               
                                                              

                              
                                                          

                          
                                                          
                          


                                                            

                      
                                                  
 


                                                     
 
                                                                     





                                                        
                                                        




                                                                      







                                                                            
 
                            
         

                                                                       
         

                                            
         


                                                
         
                               
         
                           

                                                             
                                                                                 
                            
                         




                                        































                                                                                                                          





                                                


                                            
                                                                                                    


                       

                                                                                         
                                                                                                                       












                                                                                      































                                                               
                     




                                               
                                                                            



                               
                                                                        


                                           
                                                                                                 
















                                                             










                                                             
                                                          

                                                            
                                                                                    















                                                                                                              
                                                     




















                                                                                                                   
#include "stream.h"

#include <util/generic/deque.h>
#include <util/generic/singleton.h>
#include <util/generic/yexception.h>

#include <library/cpp/openssl/init/init.h>
#include <library/cpp/openssl/method/io.h>
#include <library/cpp/resource/resource.h>

#include <openssl/bio.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/tls1.h>
#include <openssl/x509v3.h>

using TOptions = TOpenSslClientIO::TOptions;

namespace {
    struct TSslIO;

    // Helper base: its constructor calls InitOpenSSL(). Because base classes are
    // constructed before members, a type deriving from it gets that call made
    // before any of its own members are initialized.
    struct TSslInitOnDemand {
        TSslInitOnDemand() {
            InitOpenSSL();
        }
    };

    // Returns the value of ERR_peek_last_error() converted to int.
    // The error is only peeked at, so the OpenSSL error queue is left untouched.
    int GetLastSslError() noexcept {
        const auto lastCode = ERR_peek_last_error();
        return static_cast<int>(lastCode);
    }

    const char* SslErrorText(int error) noexcept {
        return ERR_error_string(error, nullptr);
    }

    // Textual description of the error reported by GetLastSslError().
    inline TStringBuf SslLastError() noexcept {
        const int lastCode = GetLastSslError();
        return TStringBuf(SslErrorText(lastCode));
    }

    // Exception whose message starts with the description of the last OpenSSL
    // error, captured at construction time. Call sites append their own context
    // with operator<<.
    struct TSslError: yexception {
        TSslError() {
            (*this) << SslLastError();
        }
    };

    // Deleter policy for THolder: one Destroy() overload per OpenSSL handle type,
    // each forwarding to the matching OpenSSL *_free function.
    struct TSslDestroy {
        // Releases an SSL_CTX.
        static void Destroy(ssl_ctx_st* sslCtx) noexcept {
            SSL_CTX_free(sslCtx);
        }

        // Releases an SSL connection object.
        static void Destroy(ssl_st* connection) noexcept {
            SSL_free(connection);
        }

        // Releases a BIO.
        static void Destroy(bio_st* handle) noexcept {
            BIO_free(handle);
        }

        // Releases an X509 certificate.
        static void Destroy(x509_st* certificate) noexcept {
            X509_free(certificate);
        }
    };

    // Owning pointer for an OpenSSL handle; the handle is released through the
    // TSslDestroy::Destroy() overload matching T.
    template <class T>
    using TSslHolderPtr = THolder<T, TSslDestroy>;

    // Convenience aliases for the handle types used in this file.
    using TSslContextPtr = TSslHolderPtr<ssl_ctx_st>;
    using TSslPtr = TSslHolderPtr<ssl_st>;
    using TBioPtr = TSslHolderPtr<bio_st>;
    using TX509Ptr = TSslHolderPtr<x509_st>;

    // Creates an SSL_CTX for the given method and applies the option flags used
    // by this file (SSLv2/SSLv3 disabled plus two legacy bug-workaround flags).
    // Throws TSslError if SSL_CTX_new() fails.
    inline TSslContextPtr CreateSslCtx(const ssl_method_st* method) {
        TSslContextPtr context(SSL_CTX_new(method));
        if (context.Get() == nullptr) {
            ythrow TSslError() << "SSL_CTX_new";
        }

        // SSL_CTX_set_options() adds the given bits to the current option set, so
        // a single call with the combined mask sets the same flags.
        SSL_CTX_set_options(context.Get(),
                            SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_MICROSOFT_SESS_ID_BUG | SSL_OP_NETSCAPE_CHALLENGE_BUG);

        return context;
    }

    struct TStreamIO : public NOpenSSL::TAbstractIO {
        inline TStreamIO(IInputStream* in, IOutputStream* out)
            : In(in)
            , Out(out)
        {
        }

        int Write(const char* data, size_t dlen, size_t* written) override {
            Out->Write(data, dlen);

            *written = dlen;
            return 1;
        }

        int Read(char* data, size_t dlen, size_t* readbytes) override {
            *readbytes = In->Read(data, dlen);
            return 1;
        }

        int Puts(const char* buf) override {
            Y_UNUSED(buf);
            return -1;
        }

        int Gets(char* buf, int size) override {
            Y_UNUSED(buf);
            Y_UNUSED(size);
            return -1;
        }

        void Flush() override {
        }

        IInputStream* In;
        IOutputStream* Out;
    };

    // Client-side TLS session layered over a pair of util streams.
    //
    // The constructor copies the options into the TOptions base, builds the
    // SSL_CTX, creates the SSL object bound to the stream-backed BIO and then
    // performs the handshake. Any failure along the way is reported by throwing.
    // TSslInitOnDemand is the first base, so InitOpenSSL() runs before any of the
    // members below are initialized.
    struct TSslIO: public TSslInitOnDemand, public TOptions {
        inline TSslIO(IInputStream* in, IOutputStream* out, const TOptions& opts)
            : TOptions(opts)
            , Io(in, out)
            , Ctx(CreateClientContext())
            , Ssl(ConstructSsl())
        {
            Connect();
        }

        // Creates the client SSL_CTX. When ClientCert_ is set, also loads the
        // certificate chain and the private key and checks that they match.
        // Throws yexception if only one of certificate/key is given, and
        // TSslError if any OpenSSL call fails.
        inline TSslContextPtr CreateClientContext() {
            TSslContextPtr ctx = CreateSslCtx(SSLv23_client_method());
            if (ClientCert_) {
                if (!ClientCert_->CertificateFile_ || !ClientCert_->PrivateKeyFile_) {
                    ythrow yexception() << "both client certificate and private key are required";
                }
                if (ClientCert_->PrivateKeyPassword_) {
                    // Password callback: copies the configured password into the
                    // buffer supplied by OpenSSL and returns the number of
                    // characters copied, or -1 if there is no user data or the
                    // buffer is too small.
                    SSL_CTX_set_default_passwd_cb(ctx.Get(), [](char* buf, int size, int rwflag, void* userData) -> int {
                        Y_UNUSED(rwflag);
                        auto io = static_cast<TSslIO*>(userData);
                        if (!io) {
                            return -1;
                        }
                        if (size < static_cast<int>(io->ClientCert_->PrivateKeyPassword_.size())) {
                            return -1;
                        }
                        return io->ClientCert_->PrivateKeyPassword_.copy(buf, size, 0);
                    });
                    SSL_CTX_set_default_passwd_cb_userdata(ctx.Get(), this);
                }
                if (1 != SSL_CTX_use_certificate_chain_file(ctx.Get(), ClientCert_->CertificateFile_.c_str())) {
                    ythrow TSslError() << "SSL_CTX_use_certificate_chain_file";
                }
                if (1 != SSL_CTX_use_PrivateKey_file(ctx.Get(), ClientCert_->PrivateKeyFile_.c_str(), SSL_FILETYPE_PEM)) {
                    ythrow TSslError() << "SSL_CTX_use_PrivateKey_file";
                }
                if (1 != SSL_CTX_check_private_key(ctx.Get())) {
                    ythrow TSslError() << "SSL_CTX_check_private_key (client)";
                }
            }
            return ctx;
        }

        // Creates the SSL object from Ctx, configures peer verification when
        // VerifyCert_ is set, and attaches Io as both the read and the write BIO.
        // Throws TSslError if SSL_new() fails.
        inline TSslPtr ConstructSsl() {
            TSslPtr ssl(SSL_new(Ctx.Get()));

            if (!ssl) {
                ythrow TSslError() << "SSL_new";
            }

            if (VerifyCert_) {
                InitVerification(ssl.Get());
            }

            BIO_up_ref(Io); // SSL_set_bio consumes only one reference if rbio and wbio are the same
            SSL_set_bio(ssl.Get(), Io, Io);

            return ssl;
        }

        // Enables certificate verification for `ssl`: sets the expected host
        // name (partial wildcards disallowed), sends it as SNI, installs the
        // built-in certificate store into Ctx, adds the default and the
        // /etc/ssl/certs locations, and requests peer verification.
        inline void InitVerification(ssl_st* ssl) {
            X509_VERIFY_PARAM* param = SSL_get0_param(ssl);
            X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS);
            Y_ENSURE(X509_VERIFY_PARAM_set1_host(param, VerifyCert_->Hostname_.data(), VerifyCert_->Hostname_.size()));
            SSL_set_tlsext_host_name(ssl, VerifyCert_->Hostname_.data()); // TLS extension: SNI

            // Release(): the store pointer is handed over to the context.
            SSL_CTX_set_cert_store(Ctx.Get(), GetBuiltinOpenSslX509Store().Release());

            Y_ENSURE_EX(1 == SSL_CTX_set_default_verify_paths(Ctx.Get()),
                        TSslError());
            // it is OK to ignore result of SSL_CTX_load_verify_locations():
            // Dir "/etc/ssl/certs/" may be missing
            SSL_CTX_load_verify_locations(Ctx.Get(),
                                          "/etc/ssl/certs/ca-certificates.crt",
                                          "/etc/ssl/certs/");

            SSL_set_verify(ssl, SSL_VERIFY_PEER, nullptr);
        }

        // Performs the TLS handshake; throws TSslError unless SSL_connect() returns 1.
        inline void Connect() {
            if (SSL_connect(Ssl.Get()) != 1) {
                ythrow TSslError() << "SSL_connect";
            }
        }

        // Calls SSL_shutdown(); its result is ignored.
        inline void Finish() const {
            SSL_shutdown(Ssl.Get());
        }

        // Reads up to `len` bytes into `buf` and returns the number of bytes read.
        // Throws TSslError when SSL_read() returns a negative value; a return of
        // 0 is passed through to the caller unchanged.
        inline size_t Read(void* buf, size_t len) {
            const int ret = SSL_read(Ssl.Get(), buf, len);

            if (ret < 0) {
                ythrow TSslError() << "SSL_read";
            }

            return ret;
        }

        // Writes all `len` bytes from `buf`, calling SSL_write() repeatedly until
        // everything has been accepted. Throws TSslError on failure.
        inline void Write(const char* buf, size_t len) {
            while (len) {
                const int ret = SSL_write(Ssl.Get(), buf, len);

                // SSL_write() signals failure with any value <= 0. A zero result
                // must not be treated as progress: buf/len would stay unchanged
                // and this loop would never terminate.
                if (ret <= 0) {
                    ythrow TSslError() << "SSL_write";
                }

                buf += static_cast<size_t>(ret);
                len -= static_cast<size_t>(ret);
            }
        }

        // Declaration order matters: Ctx is needed to construct Ssl, and members
        // are destroyed in reverse order, so Ssl goes away before Ctx and Io.
        TStreamIO Io;
        TSslContextPtr Ctx;
        TSslPtr Ssl;
    };
}

// Private implementation of TOpenSslClientIO: a TSslIO with nothing added.
struct TOpenSslClientIO::TImpl: public TSslIO {
    // Same (in, out, opts) constructor as TSslIO.
    using TSslIO::TSslIO;
};

// Creates a TLS client over the given streams using default-constructed options.
TOpenSslClientIO::TOpenSslClientIO(IInputStream* in, IOutputStream* out)
    : TOpenSslClientIO(in, out, TOptions())
{
}

TOpenSslClientIO::TOpenSslClientIO(IInputStream* in, IOutputStream* out, const TOptions& options)
    : Impl_(new TImpl(in, out, options))
{
}

// Attempts to finish the TLS session. Anything thrown while doing so is
// swallowed so that no exception escapes the destructor.
TOpenSslClientIO::~TOpenSslClientIO() {
    try {
        const TImpl& session = *Impl_;
        session.Finish();
    } catch (...) {
    }
}

// Writes `len` bytes from `buf` through the TLS session.
void TOpenSslClientIO::DoWrite(const void* buf, size_t len) {
    const char* bytes = static_cast<const char*>(buf);
    Impl_->Write(bytes, len);
}

// Reads up to `len` bytes from the TLS session into `buf`; returns the number
// of bytes read.
size_t TOpenSslClientIO::DoRead(void* buf, size_t len) {
    const size_t received = Impl_->Read(buf, len);
    return received;
}

namespace NPrivate {
    void TSslDestroy::Destroy(x509_store_st* x509) noexcept {
        X509_STORE_free(x509);
    }
}

// Parsed copy of the CA certificates stored in the "/builtin/cacert" resource.
class TBuiltinCerts {
public:
    // Reads every PEM certificate from the "/builtin/cacert" resource into Certs.
    // Throws TSslError if the memory BIO cannot be created, if no certificate
    // could be read, or if reading stopped for any reason other than running
    // out of PEM blocks.
    TBuiltinCerts() {
        TString c = NResource::Find("/builtin/cacert");

        TBioPtr cbio(BIO_new_mem_buf(c.data(), c.size()));
        Y_ENSURE_EX(cbio, TSslError() << "BIO_new_mem_buf");

        while (true) {
            TX509Ptr cert(PEM_read_bio_X509(cbio.Get(), nullptr, nullptr, nullptr));
            if (!cert) {
                break;
            }
            Certs.push_back(std::move(cert));
        }

        // The loop only stops on a failed read. The one acceptable failure is
        // "no start line" (no further PEM block) after at least one certificate
        // has been loaded; everything else, including an empty bundle, is an error.
        const int err = GetLastSslError();
        const bool endOfBundle = ERR_GET_LIB(err) == ERR_LIB_PEM && ERR_GET_REASON(err) == PEM_R_NO_START_LINE;
        if (Certs.empty() || !endOfBundle) {
            ythrow TSslError() << "can't load provided bundle: " << ERR_reason_error_string(err);
        }
        ERR_clear_error();
    }

    // Builds a new X509_STORE holding all loaded certificates and returns
    // ownership of it. An "already in hash table" error for a certificate is
    // ignored; any other failure throws TSslError.
    TOpenSslX509StorePtr GetX509Store() const {
        TOpenSslX509StorePtr store(X509_STORE_new());
        // X509_STORE_new() returns nullptr on failure; do not pass that on to
        // X509_STORE_add_cert() or back to the caller.
        Y_ENSURE_EX(store, TSslError() << "X509_STORE_new");

        for (const TX509Ptr& c : Certs) {
            if (0 == X509_STORE_add_cert(store.Get(), c.Get())) {
                int err = GetLastSslError();
                if (ERR_GET_LIB(err) == ERR_LIB_X509 && ERR_GET_REASON(err) == X509_R_CERT_ALREADY_IN_HASH_TABLE) {
                    ERR_clear_error();
                } else {
                    ythrow TSslError() << "can't load provided bundle: " << ERR_reason_error_string(err);
                }
            }
        }

        return store;
    }

private:
    // Certificates parsed by the constructor.
    TDeque<TX509Ptr> Certs;
};

// Returns a newly built X509_STORE filled from the TBuiltinCerts singleton.
TOpenSslX509StorePtr GetBuiltinOpenSslX509Store() {
    const TBuiltinCerts* builtin = Singleton<TBuiltinCerts>();
    return builtin->GetX509Store();
}