aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/actors/interconnect/interconnect_stream.h
blob: b9ca804e0e5b28b184d5c64befeaa66eaf288eab (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#pragma once

#include <util/generic/string.h>
#include <util/generic/noncopyable.h>
#include <util/network/address.h>
#include <util/network/init.h>
#include <util/system/defaults.h>

#include "poller.h"

#include "interconnect_address.h"

#include <memory>

#include <sys/uio.h>

// Forward declaration of the poller token type. Within this header it is only
// used by reference, as a parameter of the Request*() socket methods.
namespace NActors {
    class TPollerToken;
}

namespace NInterconnect {
    // Common base of the socket wrappers declared in this header.
    //
    // Stores one SOCKET descriptor. The class is non-copyable (TNonCopyable)
    // and its constructor and destructor are protected, so it is only usable
    // as a base class. Only declarations are visible here; the semantics of
    // the out-of-line members live in the implementation file.
    class TSocket
        : public NActors::TSharedDescriptor
        , public TNonCopyable
    {
    protected:
        // Constructs the wrapper from the descriptor `fd`.
        // NOTE(review): presumably takes ownership of `fd` -- confirm in the .cpp.
        TSocket(SOCKET fd);

        // Overrides a virtual destructor inherited from a base class.
        ~TSocket() override;

        // The wrapped descriptor; reset to INVALID_SOCKET by ReleaseDescriptor().
        SOCKET Descriptor;

        // Overrides a base-class virtual (presumably declared by
        // NActors::TSharedDescriptor in poller.h -- TODO confirm).
        int GetDescriptor() override;

    private:
        // TSecureSocket is the only class allowed to call ReleaseDescriptor().
        friend class TSecureSocket;

        // Returns the current descriptor and leaves INVALID_SOCKET in its place.
        SOCKET ReleaseDescriptor() {
            return std::exchange(Descriptor, INVALID_SOCKET);
        }

    public:
        // Implicit conversion to the stored descriptor value.
        operator SOCKET() const {
            return Descriptor;
        }

        // The following are declared here and defined elsewhere; the meaning of
        // the int results is not visible in this header.
        int Bind(const TAddress& addr) const;
        int Shutdown(int how) const;
        int GetConnectStatus() const;
    };

    class TStreamSocket: public TSocket {
    public:
        TStreamSocket(SOCKET fd);

        static TIntrusivePtr<TStreamSocket> Make(int domain, int *error = nullptr);

        virtual ssize_t Send(const void* msg, size_t len, TString *err = nullptr) const;
        virtual ssize_t Recv(void* buf, size_t len, TString *err = nullptr) const;

        virtual ssize_t WriteV(const struct iovec* iov, int iovcnt) const;
        virtual ssize_t ReadV(const struct iovec* iov, int iovcnt) const;

        int Connect(const TAddress& addr) const;
        int Connect(const NAddr::IRemoteAddr* addr) const;
        int Listen(int backlog) const;
        int Accept(TAddress& acceptedAddr) const;

        ssize_t GetUnsentQueueSize() const;

        void SetSendBufferSize(i32 len) const;
        ui32 GetSendBufferSize() const;

        virtual void Request(NActors::TPollerToken& token, bool read, bool write);
        virtual bool RequestReadNotificationAfterWouldBlock(NActors::TPollerToken& token);
        virtual bool RequestWriteNotificationAfterWouldBlock(NActors::TPollerToken& token);

        virtual size_t ExpectedWriteLength() const;
    };

    // Configuration object handed to TSecureSocket.
    //
    // All state is hidden behind the private TImpl class, which is defined
    // outside this header; TSecureSocket is a friend and can reach it.
    class TSecureSocketContext {
    private:
        class TImpl;
        THolder<TImpl> Impl;

        friend class TSecureSocket;

    public:
        // Shared-ownership handle type used to pass contexts around.
        using TPtr = std::shared_ptr<TSecureSocketContext>;

        // Builds a context from the given settings.
        // NOTE(review): `caFilePath` is clearly a path; whether `certificate`
        // and `privateKey` are file paths or inline contents is not visible
        // here -- confirm against the implementation.
        TSecureSocketContext(const TString& certificate, const TString& privateKey, const TString& caFilePath,
            const TString& ciphers);

        // Defined out of line, where TImpl is a complete type.
        ~TSecureSocketContext();
    };

    // Stream socket that routes I/O through a secure layer configured by a
    // TSecureSocketContext.
    //
    // It overrides every virtual I/O and poller hook of TStreamSocket. The
    // implementation details are hidden in the private TImpl class, which is
    // defined outside this header.
    class TSecureSocket : public TStreamSocket {
    private:
        // Shared context this socket was created with.
        TSecureSocketContext::TPtr Context;

        class TImpl;
        THolder<TImpl> Impl;

    public:
        // Outcome of Establish().
        enum class EStatus {
            SUCCESS,
            ERROR,
            WANT_READ,
            WANT_WRITE,
        };

        // Creates a secure socket from an existing stream socket.
        // NOTE(review): TSocket befriends this class and exposes a private
        // ReleaseDescriptor(), so the constructor presumably takes over the
        // descriptor of `socket` -- confirm in the .cpp.
        TSecureSocket(TStreamSocket& socket, TSecureSocketContext::TPtr context);
        ~TSecureSocket();

        // Drives session establishment and reports its status; `err` is an
        // output string. The roles of `server` and `authOnly` are defined by
        // the implementation (not visible here).
        EStatus Establish(bool server, bool authOnly, TString& err) const;

        // Returns a plain stream socket.
        // NOTE(review): presumably hands the descriptor back -- confirm.
        TIntrusivePtr<TStreamSocket> Detach();

        // Overrides of the single-buffer I/O. Unlike the base-class
        // declarations, `err` has no default argument here, so calls made
        // through a TSecureSocket-typed expression must pass it explicitly.
        ssize_t Send(const void* msg, size_t len, TString* err) const override;
        ssize_t Recv(void* msg, size_t len, TString* err) const override;

        // Overrides of the vectored I/O.
        ssize_t WriteV(const struct iovec* iov, int iovcnt) const override;
        ssize_t ReadV(const struct iovec* iov, int iovcnt) const override;

        // Accessors describing the session and the peer.
        TString GetCipherName() const;
        int GetCipherBits() const;
        TString GetProtocolName() const;
        TString GetPeerCommonName() const;

        // Query which direction the secure layer is waiting on.
        bool WantRead() const;
        bool WantWrite() const;

        // Overrides of the poller integration hooks.
        void Request(NActors::TPollerToken& token, bool read, bool write) override;
        bool RequestReadNotificationAfterWouldBlock(NActors::TPollerToken& token) override;
        bool RequestWriteNotificationAfterWouldBlock(NActors::TPollerToken& token) override;
        size_t ExpectedWriteLength() const override;
    };

    // Socket wrapper offering addressed send/receive operations.
    //
    // Only declarations are visible here; return value conventions are defined
    // by the implementation file.
    class TDatagramSocket: public TSocket {
    public:
        // Shared-ownership handle returned by Make().
        using TPtr = std::shared_ptr<TDatagramSocket>;

        // Constructs a datagram socket from the descriptor `fd`.
        TDatagramSocket(SOCKET fd);

        // Factory returning a shared pointer to a new datagram socket.
        // `domain` is presumably the address family -- TODO confirm.
        static TPtr Make(int domain);

        // Sends `len` bytes from `msg` to `toAddr`.
        ssize_t SendTo(const void* msg, size_t len, const TAddress& toAddr) const;

        // Receives up to `len` bytes into `buf`; `fromAddr` is an output
        // parameter (presumably the sender's address -- confirm).
        ssize_t RecvFrom(void* buf, size_t len, TAddress& fromAddr) const;
    };

}