aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/clickhouse/src/Processors/Formats/OutputFormatWithUTF8ValidationAdaptor.h
blob: 4c5c3ef72e9da4fa4d51d36f86948cfa04ef0a38 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#pragma once

#include <Processors/Formats/IOutputFormat.h>
#include <Processors/Formats/IRowOutputFormat.h>

#include <IO/WriteBuffer.h>
#include <IO/WriteBufferValidUTF8.h>

#include <Common/logger_useful.h>

namespace DB
{

/// Output-format adaptor that can route everything the derived format writes through a
/// WriteBufferValidUTF8, which replaces invalid UTF-8 sequences with the replacement character.
///
/// Validation is active only when the constructor's `validate_utf8` flag is set AND at least one
/// data type of the main port header reports that its text may not be valid UTF-8.
/// Derived formats must write through getWriteBufferPtr() rather than `out` directly, otherwise
/// the validating layer is bypassed.
template <typename Base>
class OutputFormatWithUTF8ValidationAdaptorBase : public Base
{
public:
    /// @param header         Header forwarded to Base.
    /// @param out_           Destination buffer forwarded to Base.
    /// @param validate_utf8  When false, no validating layer is ever created.
    OutputFormatWithUTF8ValidationAdaptorBase(const Block & header, WriteBuffer & out_, bool validate_utf8)
        : Base(header, out_)
    {
        /// Test the cheap flag first: scanning the header is pointless when validation is disabled.
        if (validate_utf8 && headerCanContainInvalidUTF8())
            validating_ostr = std::make_unique<WriteBufferValidUTF8>(*Base::getWriteBufferPtr());
    }

    /// Pushes data held by the validating layer on to the underlying buffer, then lets Base flush.
    void flush() override
    {
        if (validating_ostr)
            validating_ostr->next();
        Base::flush();
    }

    /// Finalizes the validating layer before Base finalizes the buffer it writes into.
    void finalizeBuffers() override
    {
        if (validating_ostr)
            validating_ostr->finalize();
        Base::finalizeBuffers();
    }

    /// Resets Base, then rebuilds the validating layer on top of the buffer Base exposes afterwards.
    /// NOTE(review): the previous validating buffer is destroyed by the assignment below without an
    /// explicit finalize() — confirm callers flush/finalize before resetting the formatter.
    void resetFormatterImpl() override
    {
        /// Log under this adaptor's own name (was copy-pasted from an unrelated adaptor).
        LOG_DEBUG(&Poco::Logger::get("OutputFormatWithUTF8ValidationAdaptor"), "resetFormatterImpl");
        Base::resetFormatterImpl();
        if (validating_ostr)
            validating_ostr = std::make_unique<WriteBufferValidUTF8>(*Base::getWriteBufferPtr());
    }

protected:
    /// Returns buffer that should be used in derived classes instead of out:
    /// the validating layer when it exists, otherwise Base's buffer.
    WriteBuffer * getWriteBufferPtr() override
    {
        if (validating_ostr)
            return validating_ostr.get();
        return Base::getWriteBufferPtr();
    }

private:
    /// Returns true if any data type of the main port header may produce text that is not valid UTF-8.
    /// Stops at the first such type.
    bool headerCanContainInvalidUTF8()
    {
        for (const auto & type : this->getPort(IOutputFormat::PortKind::Main).getHeader().getDataTypes())
        {
            if (!type->textCanContainOnlyValidUTF8())
                return true;
        }
        return false;
    }

    /// Validates UTF-8 sequences, replaces bad sequences with replacement character.
    /// Null when validation is not active.
    std::unique_ptr<WriteBuffer> validating_ostr;
};

/// Row-based output formats (IRowOutputFormat) with optional UTF-8 validation of written text.
using RowOutputFormatWithUTF8ValidationAdaptor = OutputFormatWithUTF8ValidationAdaptorBase<IRowOutputFormat>;
/// Output formats deriving directly from IOutputFormat with optional UTF-8 validation of written text.
using OutputFormatWithUTF8ValidationAdaptor = OutputFormatWithUTF8ValidationAdaptorBase<IOutputFormat>;

}