#include "halting_stream.h"

// NOTE(review): the targets of the two original angle-bracket includes were
// lost; these are reconstructed from the names used below (BIND, MakeStrong,
// NewPromise, TFuture/TPromise) — verify the paths against the repository.
#include <yt/yt/core/actions/bind.h>
#include <yt/yt/core/actions/future.h>

#include <algorithm>

namespace NYT::NDetail {

using namespace NConcurrency;

////////////////////////////////////////////////////////////////////////////////

//! Wraps an async input stream and lets at most #bytesBeforeHalt bytes through.
//! Once that many bytes have been read, every further #Read returns a future
//! that this class never sets (the stream "halts" instead of reporting EOF).
class THaltingAsyncStream final
    : public IAsyncInputStream
{
public:
    //! \param underlying       Stream the bytes are actually read from.
    //! \param bytesBeforeHalt  Total number of bytes to forward before halting;
    //!                         a value <= 0 makes the very first #Read halt.
    THaltingAsyncStream(
        IAsyncInputStreamPtr underlying,
        i64 bytesBeforeHalt)
        : Underlying_(std::move(underlying))
        , BytesBeforeHalt_(bytesBeforeHalt)
    { }

    //! Reads into #buffer, never requesting more than the remaining byte budget
    //! from the underlying stream. After the budget is exhausted, returns a
    //! future that is not set by this class.
    TFuture<size_t> Read(const TSharedMutableRef& buffer) override
    {
        if (BytesRead_ >= BytesBeforeHalt_) {
            // The promise is stored in a member only to keep it alive; it is
            // never set here.
            // NOTE(review): each halted Read replaces HaltPromise_, dropping the
            // previously stored promise; depending on TPromise abandonment
            // semantics this may complete an earlier halted future with an
            // error. Presumably fine if callers keep at most one Read in flight.
            HaltPromise_ = NewPromise<size_t>();
            return HaltPromise_.ToFuture();
        }

        // Clamp the request so the total forwarded never exceeds the budget.
        // The difference is positive here thanks to the check above.
        auto limit = std::min(
            buffer.Size(),
            static_cast<size_t>(BytesBeforeHalt_ - BytesRead_));

        auto promise = NewPromise<size_t>();
        auto future = promise.ToFuture();
        Underlying_->Read(buffer.Slice(0, limit)).Subscribe(
            BIND(&THaltingAsyncStream::OnRead, MakeStrong(this), std::move(promise)));
        return future;
    }

private:
    IAsyncInputStreamPtr Underlying_;
    const i64 BytesBeforeHalt_;

    //! Bytes successfully read from #Underlying_ so far (failed reads add nothing).
    i64 BytesRead_ = 0;
    //! Promise handed out by the most recent halted #Read; never set here.
    TPromise<size_t> HaltPromise_;

    //! Completion handler for an underlying read: accounts the bytes on success,
    //! then forwards the result (value or error) to the caller's promise.
    void OnRead(TPromise<size_t> promise, const TErrorOr<size_t>& result)
    {
        if (result.IsOK()) {
            // Updated before the promise is set, so the caller's next Read
            // observes the new count.
            BytesRead_ += result.Value();
        }
        promise.TrySet(result);
    }
};

////////////////////////////////////////////////////////////////////////////////

//! Creates a stream that forwards up to #bytesBeforeHalt bytes from #underlying
//! and then halts (see THaltingAsyncStream).
IAsyncInputStreamPtr CreateHaltingAsyncStream(
    IAsyncInputStreamPtr underlying,
    i64 bytesBeforeHalt)
{
    return New<THaltingAsyncStream>(std::move(underlying), bytesBeforeHalt);
}

////////////////////////////////////////////////////////////////////////////////

} // namespace NYT::NDetail