diff options
author | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
---|---|---|
committer | qrort <qrort@yandex-team.com> | 2022-11-30 23:47:12 +0300 |
commit | 22f8ae0e3f5d68b92aecccdf96c1d841a0334311 (patch) | |
tree | bffa27765faf54126ad44bcafa89fadecb7a73d7 /library/cpp/skiff/skiff_validator.cpp | |
parent | 332b99e2173f0425444abb759eebcb2fafaa9209 (diff) | |
download | ydb-22f8ae0e3f5d68b92aecccdf96c1d841a0334311.tar.gz |
validate canons without yatest_common
Diffstat (limited to 'library/cpp/skiff/skiff_validator.cpp')
-rw-r--r-- | library/cpp/skiff/skiff_validator.cpp | 396 |
1 files changed, 396 insertions, 0 deletions
diff --git a/library/cpp/skiff/skiff_validator.cpp b/library/cpp/skiff/skiff_validator.cpp new file mode 100644 index 0000000000..1b1b98d5a6 --- /dev/null +++ b/library/cpp/skiff/skiff_validator.cpp @@ -0,0 +1,396 @@ +#include "skiff.h" +#include "skiff_validator.h" + +#include <vector> +#include <stack> + +namespace NSkiff { + +//////////////////////////////////////////////////////////////////////////////// + +struct IValidatorNode; + +using TValidatorNodeList = std::vector<std::shared_ptr<IValidatorNode>>; +using TSkiffSchemaList = std::vector<std::shared_ptr<TSkiffSchema>>; + +static std::shared_ptr<IValidatorNode> CreateUsageValidatorNode(const std::shared_ptr<TSkiffSchema>& skiffSchema); +static TValidatorNodeList CreateUsageValidatorNodeList(const TSkiffSchemaList& skiffSchemaList); + +//////////////////////////////////////////////////////////////////////////////// + +template <typename T> +inline void ThrowUnexpectedParseWrite(T wireType) +{ + ythrow TSkiffException() << "Unexpected parse/write of \"" << ::ToString(wireType) << "\" token"; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct IValidatorNode +{ + virtual ~IValidatorNode() = default; + + virtual void OnBegin(TValidatorNodeStack* /*validatorNodeStack*/) + { } + + virtual void OnChildDone(TValidatorNodeStack* /*validatorNodeStack*/) + { + Y_FAIL(); + } + + virtual void OnSimpleType(TValidatorNodeStack* /*validatorNodeStack*/, EWireType wireType) + { + ThrowUnexpectedParseWrite(wireType); + } + + virtual void BeforeVariant8Tag() + { + ThrowUnexpectedParseWrite(EWireType::Variant8); + } + + virtual void OnVariant8Tag(TValidatorNodeStack* /*validatorNodeStack*/, ui8 /*tag*/) + { + IValidatorNode::BeforeVariant8Tag(); + } + + virtual void BeforeVariant16Tag() + { + ThrowUnexpectedParseWrite(EWireType::Variant16); + } + + virtual void OnVariant16Tag(TValidatorNodeStack* /*validatorNodeStack*/, ui16 /*tag*/) + { + IValidatorNode::BeforeVariant16Tag(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TValidatorNodeStack +{ +public: + explicit TValidatorNodeStack(std::shared_ptr<IValidatorNode> validator) + : RootValidator_(std::move(validator)) + { } + + void PushValidator(IValidatorNode* validator) + { + ValidatorStack_.push(validator); + validator->OnBegin(this); + } + + void PopValidator() + { + Y_VERIFY(!ValidatorStack_.empty()); + ValidatorStack_.pop(); + if (!ValidatorStack_.empty()) { + ValidatorStack_.top()->OnChildDone(this); + } + } + + void PushRootIfRequired() + { + if (ValidatorStack_.empty()) { + PushValidator(RootValidator_.get()); + } + } + + IValidatorNode* Top() const + { + Y_VERIFY(!ValidatorStack_.empty()); + return ValidatorStack_.top(); + } + + bool IsFinished() const + { + return ValidatorStack_.empty(); + } + +private: + const std::shared_ptr<IValidatorNode> RootValidator_; + std::stack<IValidatorNode*> ValidatorStack_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TNothingTypeValidator + : public IValidatorNode +{ +public: + void OnBegin(TValidatorNodeStack* validatorNodeStack) override + { + validatorNodeStack->PopValidator(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TSimpleTypeUsageValidator + : public IValidatorNode +{ +public: + explicit TSimpleTypeUsageValidator(EWireType type) + : Type_(type) + { } + + void OnSimpleType(TValidatorNodeStack* validatorNodeStack, EWireType type) override + { + if (type != Type_) { + ThrowUnexpectedParseWrite(type); + } + validatorNodeStack->PopValidator(); + } + +private: + const EWireType Type_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template <typename TTag> +void ValidateVariantTag(TValidatorNodeStack* validatorNodeStack, TTag tag, const TValidatorNodeList& children) +{ + if (tag == EndOfSequenceTag<TTag>()) { + // Root validator is pushed into the stack before variant tag + // if the stack is empty. + validatorNodeStack->PopValidator(); + } else if (tag >= children.size()) { + ythrow TSkiffException() << "Variant tag \"" << tag << "\" " + << "exceeds number of children \"" << children.size(); + } else { + validatorNodeStack->PushValidator(children[tag].get()); + } +} + +class TVariant8TypeUsageValidator + : public IValidatorNode +{ +public: + explicit TVariant8TypeUsageValidator(TValidatorNodeList children) + : Children_(std::move(children)) + { } + + void BeforeVariant8Tag() override + { } + + void OnVariant8Tag(TValidatorNodeStack* validatorNodeStack, ui8 tag) override + { + ValidateVariantTag(validatorNodeStack, tag, Children_); + } + + void OnChildDone(TValidatorNodeStack* validatorNodeStack) override + { + validatorNodeStack->PopValidator(); + } + +private: + const TValidatorNodeList Children_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TVariant16TypeUsageValidator + : public IValidatorNode +{ +public: + explicit TVariant16TypeUsageValidator(TValidatorNodeList children) + : Children_(std::move(children)) + { } + + void BeforeVariant16Tag() override + { } + + void OnVariant16Tag(TValidatorNodeStack* validatorNodeStack, ui16 tag) override + { + ValidateVariantTag(validatorNodeStack, tag, Children_); + } + + void OnChildDone(TValidatorNodeStack* validatorNodeStack) override + { + validatorNodeStack->PopValidator(); + } + +private: + const TValidatorNodeList Children_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TRepeatedVariant8TypeUsageValidator + : public IValidatorNode +{ +public: + explicit TRepeatedVariant8TypeUsageValidator(TValidatorNodeList children) + : Children_(std::move(children)) + { } + + void BeforeVariant8Tag() override + { } + + void OnVariant8Tag(TValidatorNodeStack* validatorNodeStack, ui8 tag) override + { + ValidateVariantTag(validatorNodeStack, tag, Children_); + } + + void OnChildDone(TValidatorNodeStack* /*validatorNodeStack*/) override + { } + +private: + const TValidatorNodeList Children_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TRepeatedVariant16TypeUsageValidator + : public IValidatorNode +{ +public: + explicit TRepeatedVariant16TypeUsageValidator(TValidatorNodeList children) + : Children_(std::move(children)) + { } + + void BeforeVariant16Tag() override + { } + + void OnVariant16Tag(TValidatorNodeStack* validatorNodeStack, ui16 tag) override + { + ValidateVariantTag(validatorNodeStack, tag, Children_); + } + + void OnChildDone(TValidatorNodeStack* /*validatorNodeStack*/) override + { } + +private: + const TValidatorNodeList Children_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TTupleTypeUsageValidator + : public IValidatorNode +{ +public: + explicit TTupleTypeUsageValidator(TValidatorNodeList children) + : Children_(std::move(children)) + { } + + void OnBegin(TValidatorNodeStack* validatorNodeStack) override + { + Position_ = 0; + if (!Children_.empty()) { + validatorNodeStack->PushValidator(Children_[0].get()); + } + } + + void OnChildDone(TValidatorNodeStack* validatorNodeStack) override + { + Position_++; + if (Position_ < Children_.size()) { + validatorNodeStack->PushValidator(Children_[Position_].get()); + } else { + validatorNodeStack->PopValidator(); + } + } + +private: + const TValidatorNodeList Children_; + ui32 Position_ = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +TSkiffValidator::TSkiffValidator(std::shared_ptr<TSkiffSchema> skiffSchema) + : Context_(std::make_unique<TValidatorNodeStack>(CreateUsageValidatorNode(std::move(skiffSchema)))) +{ } + +TSkiffValidator::~TSkiffValidator() +{ } + +void TSkiffValidator::BeforeVariant8Tag() +{ + Context_->PushRootIfRequired(); + Context_->Top()->BeforeVariant8Tag(); +} + +void TSkiffValidator::OnVariant8Tag(ui8 tag) +{ + Context_->PushRootIfRequired(); + Context_->Top()->OnVariant8Tag(Context_.get(), tag); +} + +void TSkiffValidator::BeforeVariant16Tag() +{ + Context_->PushRootIfRequired(); + Context_->Top()->BeforeVariant16Tag(); +} + +void TSkiffValidator::OnVariant16Tag(ui16 tag) +{ + Context_->PushRootIfRequired(); + Context_->Top()->OnVariant16Tag(Context_.get(), tag); +} + +void TSkiffValidator::OnSimpleType(EWireType value) +{ + Context_->PushRootIfRequired(); + Context_->Top()->OnSimpleType(Context_.get(), value); +} + +void TSkiffValidator::ValidateFinished() +{ + if (!Context_->IsFinished()) { + ythrow TSkiffException() << "Parse/write is not finished"; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +TValidatorNodeList CreateUsageValidatorNodeList(const TSkiffSchemaList& skiffSchemaList) +{ + TValidatorNodeList result; + result.reserve(skiffSchemaList.size()); + for (const auto& skiffSchema : skiffSchemaList) { + result.push_back(CreateUsageValidatorNode(skiffSchema)); + } + return result; +} + +std::shared_ptr<IValidatorNode> CreateUsageValidatorNode(const std::shared_ptr<TSkiffSchema>& skiffSchema) +{ + switch (skiffSchema->GetWireType()) { + case EWireType::Int8: + case EWireType::Int16: + case EWireType::Int32: + case EWireType::Int64: + case EWireType::Int128: + + case EWireType::Uint8: + case EWireType::Uint16: + case EWireType::Uint32: + case EWireType::Uint64: + case EWireType::Uint128: + + case EWireType::Double: + case EWireType::Boolean: + case EWireType::String32: + case EWireType::Yson32: + return std::make_shared<TSimpleTypeUsageValidator>(skiffSchema->GetWireType()); + case EWireType::Nothing: + return std::make_shared<TNothingTypeValidator>(); + case EWireType::Tuple: + return std::make_shared<TTupleTypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren())); + case EWireType::Variant8: + return std::make_shared<TVariant8TypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren())); + case EWireType::Variant16: + return std::make_shared<TVariant16TypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren())); + case EWireType::RepeatedVariant8: + return std::make_shared<TRepeatedVariant8TypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren())); + case EWireType::RepeatedVariant16: + return std::make_shared<TRepeatedVariant16TypeUsageValidator>(CreateUsageValidatorNodeList(skiffSchema->GetChildren())); + } + Y_FAIL(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NSkiff |