summaryrefslogtreecommitdiffstats
path: root/yt/cpp/mapreduce/common/abortable_stream.cpp
blob: 6336d792f1438652cfbfc42a4bd4ce81417b43cc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include "abortable_stream.h"

#include <yt/cpp/mapreduce/interface/errors.h>
#include <yt/cpp/mapreduce/interface/logging/yt_log.h>

#include <yt/yt/core/concurrency/async_stream.h>
#include <yt/yt/core/concurrency/scheduler_api.h>

#include <library/cpp/yt/logging/logger.h>
#include <library/cpp/yt/memory/ref.h>

#include <util/system/spinlock.h>

namespace NYT::NDetail {

using namespace NConcurrency;

////////////////////////////////////////////////////////////////////////////////

//! Adapts an asynchronous input stream to the synchronous, abortable
//! IAbortableInputStream interface.
//!
//! Each DoRead issues a single asynchronous read on the underlying stream and
//! blocks on it with WaitFor. The abort flag and the currently pending read
//! future are guarded by Lock_, so Abort()/IsAborted() are synchronized with
//! DoRead (presumably Abort() is meant to be called from another thread while
//! a read is blocked — confirm against callers).
class TAbortableInputStreamAdapter
    : public IAbortableInputStream
{
public:
    explicit TAbortableInputStreamAdapter(IAsyncInputStreamPtr underlyingStream)
        : UnderlyingStream_(std::move(underlyingStream))
    { }

    //! Marks the stream as aborted and cancels the pending read, if there is one.
    //! Any DoRead that observes the abort throws TInputStreamAbortedError.
    void Abort() override
    {
        auto guard = Guard(Lock_);
        IsAborted_ = true;
        // CurrentFuture_ stays null until the first DoRead publishes a future;
        // TFuture::Cancel requires a non-null future, so only cancel when one exists.
        if (CurrentFuture_) {
            CurrentFuture_.Cancel(TError("Stream was aborted"));
        }
    }

    //! Returns whether Abort() has been called.
    bool IsAborted() const override
    {
        auto guard = Guard(Lock_);
        return IsAborted_;
    }

private:
    const IAsyncInputStreamPtr UnderlyingStream_;

    //! Guards IsAborted_ and CurrentFuture_.
    TAdaptiveLock Lock_;
    bool IsAborted_ = false;
    //! Future of the most recently issued underlying read (null before the first read).
    TFuture<size_t> CurrentFuture_;

    //! Reads up to #length bytes into #buffer; returns the number of bytes read.
    //! Throws TInputStreamAbortedError if the stream was aborted, or the
    //! underlying read error otherwise.
    size_t DoRead(void* buffer, size_t length) override
    {
        if (length == 0) {
            return 0;
        }

        struct TAbortableInputStreamBufferTag { };
        auto readBuffer = TSharedMutableRef::Allocate<TAbortableInputStreamBufferTag>(length);

        auto future = UnderlyingStream_->Read(readBuffer);
        {
            auto guard = Guard(Lock_);
            // Publish the future so a concurrent Abort() can cancel it.
            CurrentFuture_ = future;
            // Abort() may have run between issuing the read and taking the lock;
            // in that case it could not see this future, so cancel it here.
            if (IsAborted_) {
                future.Cancel(TError("Stream was aborted"));
            }
        }

        auto result = WaitFor(future);

        {
            auto guard = Guard(Lock_);
            // An abort takes precedence over whatever the read produced.
            if (IsAborted_) {
                ythrow TInputStreamAbortedError() << "Stream was aborted";
            }
        }

        auto bytesRead = result.ValueOrThrow();

        memcpy(buffer, readBuffer.Begin(), bytesRead);

        return bytesRead;
    }
};

//! Builds an abortable synchronous reader on top of #underlyingStream.
//! #underlyingStream must be non-null; this is enforced with YT_VERIFY.
std::unique_ptr<IAbortableInputStream> CreateAbortableInputStreamAdapter(
    IAsyncInputStreamPtr underlyingStream)
{
    YT_VERIFY(underlyingStream);
    std::unique_ptr<IAbortableInputStream> adapter =
        std::make_unique<TAbortableInputStreamAdapter>(std::move(underlyingStream));
    return adapter;
}

////////////////////////////////////////////////////////////////////////////////

class TAbortableStreamFallback
    : public IAbortableInputStream
{
public:
    explicit TAbortableStreamFallback(IInputStream* underlyingStream)
        : UnderlyingStream_(underlyingStream)
    { }

    void Abort() override
    {
        YT_LOG_WARNING("Abort for this stream type is not supported");
    }

private:
    IInputStream* const UnderlyingStream_;

    size_t DoRead(void* buffer, size_t length) override
    {
        return UnderlyingStream_->Read(buffer, length);
    }
};

//! Wraps a synchronous #underlyingStream into an IAbortableInputStream whose
//! Abort() is unsupported. The pointer is not owned and must be non-null
//! (enforced with YT_VERIFY).
std::unique_ptr<IAbortableInputStream> CreateAbortableInputStreamAdapterFallback(
    IInputStream* underlyingStream)
{
    YT_VERIFY(underlyingStream);
    std::unique_ptr<IAbortableInputStream> fallback =
        std::make_unique<TAbortableStreamFallback>(underlyingStream);
    return fallback;
}

////////////////////////////////////////////////////////////////////////////////

} // namespace NYT::NDetail