aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/skiff/skiff_validator.cpp
blob: a2e9e1db90451e455761a611b60227108a095aac (plain) (tree)



































                                                                                                                  
                  











































                                                                                              
                                                 













                                                     
                                                 







































































































































































































































































                                                                                                              
                               




                                
                                

















                                                                                                                                    
              



                                                                                
#include "skiff.h"
#include "skiff_validator.h"

#include <vector>
#include <stack>

namespace NSkiff {

////////////////////////////////////////////////////////////////////////////////

// Declared ahead of its definition so that the aliases and factory prototypes
// below can refer to it.
struct IValidatorNode;

// Owning list of validator nodes (used for the children of composite nodes).
using TValidatorNodeList = std::vector<std::shared_ptr<IValidatorNode>>;
// List of schemas from which a list of validator nodes is built.
using TSkiffSchemaList = std::vector<std::shared_ptr<TSkiffSchema>>;

// Factories defined at the end of this file: build a validator node
// (respectively a list of nodes) from a schema (respectively a list of schemas).
static std::shared_ptr<IValidatorNode> CreateUsageValidatorNode(const std::shared_ptr<TSkiffSchema>& skiffSchema);
static TValidatorNodeList CreateUsageValidatorNodeList(const TSkiffSchemaList& skiffSchemaList);

////////////////////////////////////////////////////////////////////////////////

template <typename T>
inline void ThrowUnexpectedParseWrite(T wireType)
{
    ythrow TSkiffException() << "Unexpected parse/write of \"" << ::ToString(wireType) << "\" token";
}

////////////////////////////////////////////////////////////////////////////////

// Base interface of a node in the usage-validation tree.
//
// Every handler has a default implementation that rejects the event, so a
// concrete node overrides only the events that are legal for its wire type.
struct IValidatorNode
{
    virtual ~IValidatorNode() = default;

    // Called by TValidatorNodeStack::PushValidator right after this node has
    // been pushed. Does nothing by default.
    virtual void OnBegin(TValidatorNodeStack* /*nodeStack*/)
    { }

    // Called by TValidatorNodeStack::PopValidator on the node that becomes the
    // top of the stack after a pop. Nodes that never push children are not
    // expected to receive it, hence the abort.
    virtual void OnChildDone(TValidatorNodeStack* /*nodeStack*/)
    {
        Y_ABORT();
    }

    // A simple (non-composite) value has been parsed/written.
    // Rejected by default.
    virtual void OnSimpleType(TValidatorNodeStack* /*nodeStack*/, EWireType wireType)
    {
        ThrowUnexpectedParseWrite(wireType);
    }

    // An 8-bit variant tag is about to be parsed/written. Rejected by default.
    virtual void BeforeVariant8Tag()
    {
        ThrowUnexpectedParseWrite(EWireType::Variant8);
    }

    // An 8-bit variant tag has been parsed/written. Rejected by default with
    // the same error as BeforeVariant8Tag.
    virtual void OnVariant8Tag(TValidatorNodeStack* /*nodeStack*/, ui8 /*tag*/)
    {
        ThrowUnexpectedParseWrite(EWireType::Variant8);
    }

    // A 16-bit variant tag is about to be parsed/written. Rejected by default.
    virtual void BeforeVariant16Tag()
    {
        ThrowUnexpectedParseWrite(EWireType::Variant16);
    }

    // A 16-bit variant tag has been parsed/written. Rejected by default with
    // the same error as BeforeVariant16Tag.
    virtual void OnVariant16Tag(TValidatorNodeStack* /*nodeStack*/, ui16 /*tag*/)
    {
        ThrowUnexpectedParseWrite(EWireType::Variant16);
    }
};

////////////////////////////////////////////////////////////////////////////////

// Stack of the validator nodes that are currently "open".
//
// The root node is held through a shared pointer; the stack itself stores raw
// pointers to nodes.
class TValidatorNodeStack
{
public:
    explicit TValidatorNodeStack(std::shared_ptr<IValidatorNode> validator)
        : RootValidator_(std::move(validator))
    { }

    // Pushes the node first and only then notifies it, so that OnBegin may
    // already push children or pop the node itself.
    void PushValidator(IValidatorNode* validator)
    {
        ValidatorStack_.push(validator);
        validator->OnBegin(this);
    }

    // Removes the top node and, if another node is uncovered by that,
    // tells it that its child is done.
    void PopValidator()
    {
        Y_ABORT_UNLESS(!ValidatorStack_.empty());
        ValidatorStack_.pop();
        if (ValidatorStack_.empty()) {
            return;
        }
        ValidatorStack_.top()->OnChildDone(this);
    }

    // Pushes the root node when the stack is empty; otherwise does nothing.
    void PushRootIfRequired()
    {
        if (!ValidatorStack_.empty()) {
            return;
        }
        PushValidator(RootValidator_.get());
    }

    // Returns the node on top of the stack; the stack must not be empty.
    IValidatorNode* Top() const
    {
        Y_ABORT_UNLESS(!ValidatorStack_.empty());
        return ValidatorStack_.top();
    }

    // True when no node is open, i.e. the stack is empty.
    bool IsFinished() const
    {
        return ValidatorStack_.empty();
    }

private:
    const std::shared_ptr<IValidatorNode> RootValidator_;
    std::stack<IValidatorNode*> ValidatorStack_;
};

////////////////////////////////////////////////////////////////////////////////

class TNothingTypeValidator
    : public IValidatorNode
{
public:
    void OnBegin(TValidatorNodeStack* validatorNodeStack) override
    {
        validatorNodeStack->PopValidator();
    }
};

////////////////////////////////////////////////////////////////////////////////

class TSimpleTypeUsageValidator
    : public IValidatorNode
{
public:
    explicit TSimpleTypeUsageValidator(EWireType type)
        : Type_(type)
    { }

    void OnSimpleType(TValidatorNodeStack* validatorNodeStack, EWireType type) override
    {
        if (type != Type_) {
            ThrowUnexpectedParseWrite(type);
        }
        validatorNodeStack->PopValidator();
    }

private:
    const EWireType Type_;
};

////////////////////////////////////////////////////////////////////////////////

// Shared tag handling for all variant validators.
//
// - The end-of-sequence tag pops the current top of the stack.
// - A tag that does not index into `children` raises TSkiffException.
// - Any other tag pushes the validator of the selected child.
//
// `validatorNodeStack` is the stack being driven, `tag` is the tag that has
// just been parsed/written and `children` are the validators of the variant's
// alternatives, indexed by tag.
template <typename TTag>
void ValidateVariantTag(TValidatorNodeStack* validatorNodeStack, TTag tag, const TValidatorNodeList& children)
{
    if (tag == EndOfSequenceTag<TTag>()) {
        // Root validator is pushed into the stack before variant tag
        // if the stack is empty.
        validatorNodeStack->PopValidator();
        return;
    }

    if (tag >= children.size()) {
        // The tag is widened before being streamed so that it is rendered as a
        // number even if the output stream formats 8-bit integers as
        // characters. The closing quote after the child count was missing.
        ythrow TSkiffException() << "Variant tag \"" << static_cast<ui32>(tag) << "\" "
            << "exceeds number of children \"" << children.size() << "\"";
    }

    validatorNodeStack->PushValidator(children[tag].get());
}

class TVariant8TypeUsageValidator
    : public IValidatorNode
{
public:
    explicit TVariant8TypeUsageValidator(TValidatorNodeList children)
        : Children_(std::move(children))
    { }

    void BeforeVariant8Tag() override
    { }

    void OnVariant8Tag(TValidatorNodeStack* validatorNodeStack, ui8 tag) override
    {
        ValidateVariantTag(validatorNodeStack, tag, Children_);
    }

    void OnChildDone(TValidatorNodeStack* validatorNodeStack) override
    {
        validatorNodeStack->PopValidator();
    }

private:
    const TValidatorNodeList Children_;
};

////////////////////////////////////////////////////////////////////////////////

class TVariant16TypeUsageValidator
    : public IValidatorNode
{
public:
    explicit TVariant16TypeUsageValidator(TValidatorNodeList children)
        : Children_(std::move(children))
    { }

    void BeforeVariant16Tag() override
    { }

    void OnVariant16Tag(TValidatorNodeStack* validatorNodeStack, ui16 tag) override
    {
        ValidateVariantTag(validatorNodeStack, tag, Children_);
    }

    void OnChildDone(TValidatorNodeStack* validatorNodeStack) override
    {
        validatorNodeStack->PopValidator();
    }

private:
    const TValidatorNodeList Children_;
};

////////////////////////////////////////////////////////////////////////////////

class TRepeatedVariant8TypeUsageValidator
    : public IValidatorNode
{
public:
    explicit TRepeatedVariant8TypeUsageValidator(TValidatorNodeList children)
        : Children_(std::move(children))
    { }

    void BeforeVariant8Tag() override
    { }

    void OnVariant8Tag(TValidatorNodeStack* validatorNodeStack, ui8 tag) override
    {
        ValidateVariantTag(validatorNodeStack, tag, Children_);
    }

    void OnChildDone(TValidatorNodeStack* /*validatorNodeStack*/) override
    { }

private:
    const TValidatorNodeList Children_;
};

////////////////////////////////////////////////////////////////////////////////

class TRepeatedVariant16TypeUsageValidator
    : public IValidatorNode
{
public:
    explicit TRepeatedVariant16TypeUsageValidator(TValidatorNodeList children)
        : Children_(std::move(children))
    { }

    void BeforeVariant16Tag() override
    { }

    void OnVariant16Tag(TValidatorNodeStack* validatorNodeStack, ui16 tag) override
    {
        ValidateVariantTag(validatorNodeStack, tag, Children_);
    }

    void OnChildDone(TValidatorNodeStack* /*validatorNodeStack*/) override
    { }

private:
    const TValidatorNodeList Children_;
};

////////////////////////////////////////////////////////////////////////////////

class TTupleTypeUsageValidator
    : public IValidatorNode
{
public:
    explicit TTupleTypeUsageValidator(TValidatorNodeList children)
        : Children_(std::move(children))
    { }

    void OnBegin(TValidatorNodeStack* validatorNodeStack) override
    {
        Position_ = 0;
        if (!Children_.empty()) {
            validatorNodeStack->PushValidator(Children_[0].get());
        }
    }

    void OnChildDone(TValidatorNodeStack* validatorNodeStack) override
    {
        Position_++;
        if (Position_ < Children_.size()) {
            validatorNodeStack->PushValidator(Children_[Position_].get());
        } else {
            validatorNodeStack->PopValidator();
        }
    }

private:
    const TValidatorNodeList Children_;
    ui32 Position_ = 0;
};

////////////////////////////////////////////////////////////////////////////////

// Builds the validator tree for the given schema and wraps it into a fresh,
// empty node stack.
TSkiffValidator::TSkiffValidator(std::shared_ptr<TSkiffSchema> skiffSchema)
    // The factory takes the schema by const reference, so no move is needed.
    : Context_(std::make_unique<TValidatorNodeStack>(CreateUsageValidatorNode(skiffSchema)))
{ }

// Defined in this translation unit, after the definition of TValidatorNodeStack.
TSkiffValidator::~TSkiffValidator() = default;

// Forwards "an 8-bit variant tag is about to be parsed/written" to the current
// top node, pushing the root node first if the stack is empty.
void TSkiffValidator::BeforeVariant8Tag()
{
    auto& nodeStack = *Context_;
    nodeStack.PushRootIfRequired();
    nodeStack.Top()->BeforeVariant8Tag();
}

// Forwards a parsed/written 8-bit variant tag to the current top node,
// pushing the root node first if the stack is empty.
void TSkiffValidator::OnVariant8Tag(ui8 tag)
{
    auto& nodeStack = *Context_;
    nodeStack.PushRootIfRequired();
    nodeStack.Top()->OnVariant8Tag(&nodeStack, tag);
}

// Forwards "a 16-bit variant tag is about to be parsed/written" to the current
// top node, pushing the root node first if the stack is empty.
void TSkiffValidator::BeforeVariant16Tag()
{
    auto& nodeStack = *Context_;
    nodeStack.PushRootIfRequired();
    nodeStack.Top()->BeforeVariant16Tag();
}

// Forwards a parsed/written 16-bit variant tag to the current top node,
// pushing the root node first if the stack is empty.
void TSkiffValidator::OnVariant16Tag(ui16 tag)
{
    auto& nodeStack = *Context_;
    nodeStack.PushRootIfRequired();
    nodeStack.Top()->OnVariant16Tag(&nodeStack, tag);
}

// Forwards a parsed/written simple value of the given wire type to the current
// top node, pushing the root node first if the stack is empty.
void TSkiffValidator::OnSimpleType(EWireType value)
{
    auto& nodeStack = *Context_;
    nodeStack.PushRootIfRequired();
    nodeStack.Top()->OnSimpleType(&nodeStack, value);
}

// Throws TSkiffException if some validator node is still open, i.e. the node
// stack is not empty.
void TSkiffValidator::ValidateFinished()
{
    if (Context_->IsFinished()) {
        return;
    }
    ythrow TSkiffException() << "Parse/write is not finished";
}

////////////////////////////////////////////////////////////////////////////////

// Builds one validator node per schema, preserving the order of the input list.
TValidatorNodeList CreateUsageValidatorNodeList(const TSkiffSchemaList& skiffSchemaList)
{
    TValidatorNodeList nodes;
    nodes.reserve(skiffSchemaList.size());
    for (const auto& childSchema : skiffSchemaList) {
        nodes.emplace_back(CreateUsageValidatorNode(childSchema));
    }
    return nodes;
}

std::shared_ptr<IValidatorNode> CreateUsageValidatorNode(const std::shared_ptr<TSkiffSchema>& skiffSchema)
{
    switch (skiffSchema->GetWireType()) {
        case EWireType::Int8:
        case EWireType::Int16:
        case EWireType::Int32:
        case EWireType::Int64:
        case EWireType::Int128:
        case EWireType::Int256:

        case EWireType::Uint8:
        case EWireType::Uint16:
        case EWireType::Uint32:
        case EWireType::Uint64:
        case EWireType::Uint128:
        case EWireType::Uint256:

        case EWireType::Double:
        case EWireType::Boolean:
        case EWireType::String32:
        case EWireType::Yson32:
            return std::make_shared<TSimpleTypeUsageValidator>(skiffSchema->GetWireType());
        case EWireType::Nothing:
            return std::make_shared<TNothingTypeValidator>();
        case EWireType::Tuple:
            return std::make_shared<TTupleTypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren()));
        case EWireType::Variant8:
            return std::make_shared<TVariant8TypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren()));
        case EWireType::Variant16:
            return std::make_shared<TVariant16TypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren()));
        case EWireType::RepeatedVariant8:
            return std::make_shared<TRepeatedVariant8TypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren()));
        case EWireType::RepeatedVariant16:
            return std::make_shared<TRepeatedVariant16TypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren()));
    }
    Y_ABORT();
}

////////////////////////////////////////////////////////////////////////////////

} // namespace NSkiff