diff options
author | primorial <primorial@yandex-team.com> | 2022-09-28 16:57:04 +0300 |
---|---|---|
committer | primorial <primorial@yandex-team.com> | 2022-09-28 16:57:04 +0300 |
commit | b327caf7cfb59302e973938a4fa27c45d92a00eb (patch) | |
tree | fd0b62b01f77d5277b45cc482cc4bed4602f772b | |
parent | 4524f6bdbb266ac2004ba894f90bc0ffd4785e7f (diff) | |
download | ydb-b327caf7cfb59302e973938a4fa27c45d92a00eb.tar.gz |
Update contrib/libs/apache/orc to 1.8.0
56 files changed, 6342 insertions, 1028 deletions
diff --git a/contrib/libs/apache/orc/CMakeLists.txt b/contrib/libs/apache/orc/CMakeLists.txt index b6fc10cf929..a567c021735 100644 --- a/contrib/libs/apache/orc/CMakeLists.txt +++ b/contrib/libs/apache/orc/CMakeLists.txt @@ -61,6 +61,12 @@ target_sources(libs-apache-orc PRIVATE ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/Writer.cc ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/io/InputStream.cc ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/io/OutputStream.cc + ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/sargs/ExpressionTree.cc + ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/sargs/Literal.cc + ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/sargs/PredicateLeaf.cc + ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/sargs/SargsApplier.cc + ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/sargs/SearchArgument.cc + ${CMAKE_SOURCE_DIR}/contrib/libs/apache/orc/c++/src/sargs/TruthValue.cc ) target_proto_addincls(libs-apache-orc ./ diff --git a/contrib/libs/apache/orc/README.md b/contrib/libs/apache/orc/README.md index 0668ee07a55..a7d959247e1 100644 --- a/contrib/libs/apache/orc/README.md +++ b/contrib/libs/apache/orc/README.md @@ -15,18 +15,18 @@ lists, maps, and unions. ## ORC File Library -This project includes both a Java library and a C++ library for reading and writing the _Optimized Row Columnar_ (ORC) file format. The C++ and Java libraries are completely independent of each other and will each read all versions of ORC files. But the C++ library only writes the original (Hive 0.11) version of ORC files, and will be extended in the future. +This project includes both a Java library and a C++ library for reading and writing the _Optimized Row Columnar_ (ORC) file format. The C++ and Java libraries are completely independent of each other and will each read all versions of ORC files. Releases: * Latest: <a href="http://orc.apache.org/releases">Apache ORC releases</a> * Maven Central: <a href="http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.orc%22"></a> * Downloads: <a href="http://orc.apache.org/downloads">Apache ORC downloads</a> +* Release tags: <a href="https://github.com/apache/orc/releases">Apache ORC release tags</a> +* Plan: <a href="https://github.com/apache/orc/milestones">Apache ORC future release plan</a> The current build status: -* Master branch <a href="https://travis-ci.org/apache/orc/branches"> -</a> -* <a href="https://travis-ci.org/apache/orc/pull_requests">Pull Requests</a> - +* Main branch <a href="https://github.com/apache/orc/actions/workflows/build_and_test.yml?query=branch%3Amain"> +</a> Bug tracking: <a href="http://orc.apache.org/bugs">Apache Jira</a> @@ -39,13 +39,12 @@ The subdirectories are: * java - the java reader and writer * proto - the protocol buffer definition for the ORC metadata * site - the website and documentation -* snap - the script to build [snaps](https://snapcraft.io/) of the ORC tools * tools - the c++ tools for reading and inspecting ORC files ### Building * Install java 1.8 or higher -* Install maven 3 or higher +* Install maven 3.8.6 or higher * Install cmake To build a release version with debug information: diff --git a/contrib/libs/apache/orc/c++/include/orc/BloomFilter.hh b/contrib/libs/apache/orc/c++/include/orc/BloomFilter.hh index 86c1288b625..91277392c7b 100644 --- a/contrib/libs/apache/orc/c++/include/orc/BloomFilter.hh +++ b/contrib/libs/apache/orc/c++/include/orc/BloomFilter.hh @@ -40,6 +40,6 @@ namespace orc { std::vector<std::shared_ptr<BloomFilter>> entries; }; -}; +} #endif //ORC_BLOOMFILTER_HH diff --git a/contrib/libs/apache/orc/c++/include/orc/Common.hh b/contrib/libs/apache/orc/c++/include/orc/Common.hh index 4aa4a85118c..e51e37e7107 100644 --- a/contrib/libs/apache/orc/c++/include/orc/Common.hh +++ b/contrib/libs/apache/orc/c++/include/orc/Common.hh @@ -34,6 +34,7 @@ namespace orc { public: static const FileVersion& v_0_11(); static const FileVersion& v_0_12(); + static const FileVersion& UNSTABLE_PRE_2_0(); FileVersion(uint32_t major, uint32_t minor) : majorVersion(major), minorVersion(minor) { @@ -123,6 +124,17 @@ namespace orc { }; /** + * Specific read intention when selecting a certain TypeId. + * This enum currently only being utilized by LIST, MAP, and UNION type selection. + */ + enum ReadIntent { + ReadIntent_ALL = 0, + + // Only read the offsets of selected type. Do not read the children types. + ReadIntent_OFFSETS = 1 + }; + + /** * Get the string representation of the StreamKind. */ std::string streamKindToString(StreamKind kind); @@ -281,6 +293,30 @@ namespace orc { FUTURE = INT32_MAX }; + inline bool operator<(const Decimal& lhs, const Decimal& rhs) { + return compare(lhs, rhs); + } + + inline bool operator>(const Decimal& lhs, const Decimal& rhs) { + return rhs < lhs; + } + + inline bool operator<=(const Decimal& lhs, const Decimal& rhs) { + return !(lhs > rhs); + } + + inline bool operator>=(const Decimal& lhs, const Decimal& rhs) { + return !(lhs < rhs); + } + + inline bool operator!=(const Decimal& lhs, const Decimal& rhs) { + return lhs < rhs || rhs < lhs; + } + + inline bool operator==(const Decimal& lhs, const Decimal& rhs) { + return !(lhs != rhs); + } + } #endif diff --git a/contrib/libs/apache/orc/c++/include/orc/Int128.hh b/contrib/libs/apache/orc/c++/include/orc/Int128.hh index f86d8f08a64..1f68b2b119f 100644 --- a/contrib/libs/apache/orc/c++/include/orc/Int128.hh +++ b/contrib/libs/apache/orc/c++/include/orc/Int128.hh @@ -311,8 +311,13 @@ namespace orc { /** * Return the base 10 string representation with a decimal point, * the given number of places after the decimal. + * + * @param scale scale of the Int128 to be interpreted as a decimal value + * @param trimTrailingZeros whether or not to trim trailing zeros + * @return converted string representation */ - std::string toDecimalString(int32_t scale=0) const; + std::string toDecimalString(int32_t scale = 0, + bool trimTrailingZeros = false) const; /** * Return the base 16 string representation of the two's complement with diff --git a/contrib/libs/apache/orc/c++/include/orc/Reader.hh b/contrib/libs/apache/orc/c++/include/orc/Reader.hh index 5d9a532c11d..ddc8b550554 100644 --- a/contrib/libs/apache/orc/c++/include/orc/Reader.hh +++ b/contrib/libs/apache/orc/c++/include/orc/Reader.hh @@ -23,6 +23,7 @@ #include "orc/Common.hh" #include "orc/orc-config.hh" #include "orc/Statistics.hh" +#include "orc/sargs/SearchArgument.hh" #include "orc/Type.hh" #include "orc/Vector.hh" @@ -149,6 +150,24 @@ namespace orc { RowReaderOptions& includeTypes(const std::list<uint64_t>& types); /** + * A map type of <typeId, ReadIntent>. + */ + typedef std::map<uint64_t, ReadIntent> IdReadIntentMap; + + /** + * Selects which type ids to read and specific ReadIntents for each + * type id. The ancestor types are automatically selected, but the children + * are not. + * + * This option clears any previous setting of the selected columns or + * types. + * @param idReadIntentMap a map of IdReadIntentMap. + * @return this + */ + RowReaderOptions& + includeTypesWithIntents(const IdReadIntentMap& idReadIntentMap); + + /** * Set the section of the file to process. * @param offset the starting byte offset * @param length the number of bytes to read @@ -192,6 +211,11 @@ namespace orc { RowReaderOptions& setEnableLazyDecoding(bool enable); /** + * Set search argument for predicate push down + */ + RowReaderOptions& searchArgument(std::unique_ptr<SearchArgument> sargs); + + /** * Should enable encoding block mode */ bool getEnableLazyDecoding() const; @@ -245,6 +269,26 @@ namespace orc { * What scale should all Hive 0.11 decimals be normalized to? */ int32_t getForcedScaleOnHive11Decimal() const; + + /** + * Get search argument for predicate push down + */ + std::shared_ptr<SearchArgument> getSearchArgument() const; + + /** + * Set desired timezone to return data of timestamp type + */ + RowReaderOptions& setTimezoneName(const std::string& zoneName); + + /** + * Get desired timezone to return data of timestamp type + */ + const std::string& getTimezoneName() const; + + /** + * Get the IdReadIntentMap map that was supplied by client. + */ + const IdReadIntentMap getIdReadIntentMap() const; }; diff --git a/contrib/libs/apache/orc/c++/include/orc/Statistics.hh b/contrib/libs/apache/orc/c++/include/orc/Statistics.hh index 1d4b0b6558b..4d7caeab3d8 100644 --- a/contrib/libs/apache/orc/c++/include/orc/Statistics.hh +++ b/contrib/libs/apache/orc/c++/include/orc/Statistics.hh @@ -305,26 +305,26 @@ namespace orc { virtual ~TimestampColumnStatistics(); /** - * Check whether column minimum. + * Check whether minimum timestamp exists. * @return true if has minimum */ virtual bool hasMinimum() const = 0; /** - * Check whether column maximum. + * Check whether maximum timestamp exists. * @return true if has maximum */ virtual bool hasMaximum() const = 0; /** - * Get the minimum value for the column. - * @return minimum value + * Get the millisecond of minimum timestamp in UTC. + * @return minimum value in millisecond */ virtual int64_t getMinimum() const = 0; /** - * Get the maximum value for the column. - * @return maximum value + * Get the millisecond of maximum timestamp in UTC. + * @return maximum value in millisecond */ virtual int64_t getMaximum() const = 0; @@ -352,7 +352,17 @@ namespace orc { */ virtual int64_t getUpperBound() const = 0; + /** + * Get the last 6 digits of nanosecond of minimum timestamp. + * @return last 6 digits of nanosecond of minimum timestamp. + */ + virtual int32_t getMinimumNanos() const = 0; + /** + * Get the last 6 digits of nanosecond of maximum timestamp. + * @return last 6 digits of nanosecond of maximum timestamp. + */ + virtual int32_t getMaximumNanos() const = 0; }; class Statistics { @@ -374,6 +384,74 @@ namespace orc { virtual uint32_t getNumberOfColumns() const = 0; }; + /** + * Statistics for all of collections such as Map and List. + */ + class CollectionColumnStatistics : public ColumnStatistics { + public: + virtual ~CollectionColumnStatistics(); + + /** + * check whether column has minimum number of children + * @return true if has minimum children count + */ + virtual bool hasMinimumChildren() const = 0; + + /** + * check whether column has maximum number of children + * @return true if has maximum children count + */ + virtual bool hasMaximumChildren() const = 0; + + /** + * check whether column has total number of children + * @return true if has total children count + */ + virtual bool hasTotalChildren() const = 0; + + /** + * set hasTotalChildren value + * @param newHasTotalChildren hasTotalChildren value + */ + virtual void setHasTotalChildren(bool newHasTotalChildren) = 0; + + /** + * Get minimum number of children in the collection. + * @return the minimum children count + */ + virtual uint64_t getMinimumChildren() const = 0; + + /** + * set new minimum children count + * @param min new minimum children count + */ + virtual void setMinimumChildren(uint64_t min) = 0; + + /** + * Get maximum number of children in the collection. + * @return the maximum children count + */ + virtual uint64_t getMaximumChildren() const = 0; + + /** + * set new maximum children count + * @param max new maximum children count + */ + virtual void setMaximumChildren(uint64_t max) = 0; + + /** + * Get the total number of children in the collection. + * @return the total number of children + */ + virtual uint64_t getTotalChildren() const = 0; + + /** + * set new total children count + * @param newTotalChildrenCount total children count to be set + */ + virtual void setTotalChildren(uint64_t newTotalChildrenCount) = 0; + }; + class StripeStatistics : public Statistics { public: virtual ~StripeStatistics(); diff --git a/contrib/libs/apache/orc/c++/include/orc/Type.hh b/contrib/libs/apache/orc/c++/include/orc/Type.hh index c0cbf2d6716..a7df8307e69 100644 --- a/contrib/libs/apache/orc/c++/include/orc/Type.hh +++ b/contrib/libs/apache/orc/c++/include/orc/Type.hh @@ -43,7 +43,8 @@ namespace orc { DECIMAL = 14, DATE = 15, VARCHAR = 16, - CHAR = 17 + CHAR = 17, + TIMESTAMP_INSTANT = 18 }; class Type { @@ -58,6 +59,12 @@ namespace orc { virtual uint64_t getMaximumLength() const = 0; virtual uint64_t getPrecision() const = 0; virtual uint64_t getScale() const = 0; + virtual Type& setAttribute(const std::string& key, + const std::string& value) = 0; + virtual bool hasAttributeKey(const std::string& key) const = 0; + virtual Type& removeAttribute(const std::string& key) = 0; + virtual std::vector<std::string> getAttributeKeys() const = 0; + virtual std::string getAttributeValue(const std::string& key) const = 0; virtual std::string toString() const = 0; /** diff --git a/contrib/libs/apache/orc/c++/include/orc/Vector.hh b/contrib/libs/apache/orc/c++/include/orc/Vector.hh index 629c0b7f6bd..752e1af78a8 100644 --- a/contrib/libs/apache/orc/c++/include/orc/Vector.hh +++ b/contrib/libs/apache/orc/c++/include/orc/Vector.hh @@ -134,7 +134,7 @@ namespace orc { DataBuffer<int64_t> dictionaryOffset; void getValueByIndex(int64_t index, char*& valPtr, int64_t& length) { - if (index < 0 || static_cast<uint64_t>(index) >= dictionaryOffset.size()) { + if (index < 0 || static_cast<uint64_t>(index) + 1 >= dictionaryOffset.size()) { throw std::out_of_range("index out of range."); } @@ -154,6 +154,7 @@ namespace orc { EncodedStringVectorBatch(uint64_t capacity, MemoryPool& pool); virtual ~EncodedStringVectorBatch(); std::string toString() const; + void resize(uint64_t capacity); std::shared_ptr<StringDictionary> dictionary; // index for dictionary entry @@ -240,7 +241,7 @@ namespace orc { explicit Decimal(const std::string& value); Decimal(); - std::string toString() const; + std::string toString(bool trimTrailingZeros = false) const; Int128 value; int32_t scale; }; diff --git a/contrib/libs/apache/orc/c++/include/orc/Writer.hh b/contrib/libs/apache/orc/c++/include/orc/Writer.hh index 5b333861b1e..78b0b97d25f 100644 --- a/contrib/libs/apache/orc/c++/include/orc/Writer.hh +++ b/contrib/libs/apache/orc/c++/include/orc/Writer.hh @@ -217,6 +217,24 @@ namespace orc { * Get version of BloomFilter */ BloomFilterVersion getBloomFilterVersion() const; + + /** + * Get writer timezone + * @return writer timezone + */ + const Timezone& getTimezone() const; + + /** + * Get writer timezone name + * @return writer timezone name + */ + const std::string& getTimezoneName() const; + + /** + * Set writer timezone + * @param zone writer timezone name + */ + WriterOptions& setTimezoneName(const std::string& zone); }; class Writer { diff --git a/contrib/libs/apache/orc/c++/include/orc/orc-config.hh b/contrib/libs/apache/orc/c++/include/orc/orc-config.hh index 18bbbd78e12..b8fb9fbd4ee 100644 --- a/contrib/libs/apache/orc/c++/include/orc/orc-config.hh +++ b/contrib/libs/apache/orc/c++/include/orc/orc-config.hh @@ -15,7 +15,7 @@ #ifndef ORC_CONFIG_HH #define ORC_CONFIG_HH -#define ORC_VERSION "1.6.12" +#define ORC_VERSION "1.8.0" #define ORC_CXX_HAS_CSTDINT #define ORC_CXX_HAS_INITIALIZER_LIST diff --git a/contrib/libs/apache/orc/c++/include/orc/sargs/Literal.hh b/contrib/libs/apache/orc/c++/include/orc/sargs/Literal.hh new file mode 100644 index 00000000000..36c9b37e3f2 --- /dev/null +++ b/contrib/libs/apache/orc/c++/include/orc/sargs/Literal.hh @@ -0,0 +1,160 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_LITERAL_HH +#define ORC_LITERAL_HH + +#include "orc/Int128.hh" +#include "orc/Vector.hh" + +namespace orc { + + /** + * Possible data types for predicates + */ + enum class PredicateDataType { + LONG = 0, FLOAT, STRING, DATE, DECIMAL, TIMESTAMP, BOOLEAN + }; + + /** + * Represents a literal value in a predicate + */ + class Literal { + public: + struct Timestamp { + Timestamp() = default; + Timestamp(const Timestamp&) = default; + Timestamp(Timestamp&&) = default; + ~Timestamp() = default; + Timestamp(int64_t second_, int32_t nanos_): second(second_), nanos(nanos_) { + // PASS + } + Timestamp& operator=(const Timestamp&) = default; + Timestamp& operator=(Timestamp&&) = default; + bool operator==(const Timestamp& r) const { + return second == r.second && nanos == r.nanos; + } + bool operator<(const Timestamp& r) const { + return second < r.second || (second == r.second && nanos < r.nanos); + } + bool operator<=(const Timestamp& r) const { + return second < r.second || (second == r.second && nanos <= r.nanos); + } + bool operator!=(const Timestamp& r) const { return !(*this == r); } + bool operator>(const Timestamp& r) const { return r < *this; } + bool operator>=(const Timestamp& r) const { return r <= *this; } + int64_t getMillis() const { return second * 1000 + nanos / 1000000; } + int64_t second; + int32_t nanos; + }; + + Literal(const Literal &r); + ~Literal(); + Literal& operator=(const Literal& r); + bool operator==(const Literal& r) const; + bool operator!=(const Literal& r) const; + + /** + * Create a literal of null value for a specific type + */ + Literal(PredicateDataType type); + + /** + * Create a literal of LONG type + */ + Literal(int64_t val); + + /** + * Create a literal of FLOAT type + */ + Literal(double val); + + /** + * Create a literal of BOOLEAN type + */ + Literal(bool val); + + /** + * Create a literal of DATE type + */ + Literal(PredicateDataType type, int64_t val); + + /** + * Create a literal of TIMESTAMP type + */ + Literal(int64_t second, int32_t nanos); + + /** + * Create a literal of STRING type + */ + Literal(const char * str, size_t size); + + /** + * Create a literal of DECIMAL type + */ + Literal(Int128 val, int32_t precision, int32_t scale); + + /** + * Getters of a specific data type for not-null literals + */ + int64_t getLong() const; + int64_t getDate() const; + Timestamp getTimestamp() const; + double getFloat() const; + std::string getString() const; + bool getBool() const; + Decimal getDecimal() const; + + /** + * Check if a literal is null + */ + bool isNull() const { return mIsNull; } + + PredicateDataType getType() const { return mType; } + std::string toString() const; + size_t getHashCode() const { return mHashCode; } + + private: + size_t hashCode() const; + + union LiteralVal { + int64_t IntVal; + double DoubleVal; + int64_t DateVal; + char * Buffer; + Timestamp TimeStampVal; + Int128 DecimalVal; + bool BooleanVal; + + // explicitly define default constructor + LiteralVal(): DecimalVal(0) {} + }; + + private: + LiteralVal mValue; // data value for this literal if not null + PredicateDataType mType; // data type of the literal + size_t mSize; // size of mValue if it is Buffer + int32_t mPrecision; // precision of decimal type + int32_t mScale; // scale of decimal type + bool mIsNull; // whether this literal is null + size_t mHashCode; // precomputed hash code for the literal + }; + +} // namespace orc + +#endif //ORC_LITERAL_HH diff --git a/contrib/libs/apache/orc/c++/include/orc/sargs/SearchArgument.hh b/contrib/libs/apache/orc/c++/include/orc/sargs/SearchArgument.hh new file mode 100644 index 00000000000..44fde8f5e90 --- /dev/null +++ b/contrib/libs/apache/orc/c++/include/orc/sargs/SearchArgument.hh @@ -0,0 +1,284 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_SEARCHARGUMENT_HH +#define ORC_SEARCHARGUMENT_HH + +#include "orc/sargs/Literal.hh" +#include "orc/sargs/TruthValue.hh" + +namespace orc { + + /** + * Primary interface for a search argument, which are the subset of predicates + * that can be pushed down to the RowReader. Each SearchArgument consists + * of a series of search clauses that must each be true for the row to be + * accepted by the filter. + * + * This requires that the filter be normalized into conjunctive normal form + * (<a href="http://en.wikipedia.org/wiki/Conjunctive_normal_form">CNF</a>). + */ + class SearchArgument { + public: + virtual ~SearchArgument(); + + /** + * Evaluate the entire predicate based on the values for the leaf predicates. + * @param leaves the value of each leaf predicate + * @return the value of hte entire predicate + */ + virtual TruthValue evaluate(const std::vector<TruthValue>& leaves) const = 0; + + virtual std::string toString() const = 0; + }; + + /** + * A builder object to create a SearchArgument from expressions. The user + * must call startOr, startAnd, or startNot before adding any leaves. + */ + class SearchArgumentBuilder { + public: + virtual ~SearchArgumentBuilder(); + + /** + * Start building an or operation and push it on the stack. + * @return this + */ + virtual SearchArgumentBuilder& startOr() = 0; + + /** + * Start building an and operation and push it on the stack. + * @return this + */ + virtual SearchArgumentBuilder& startAnd() = 0; + + /** + * Start building a not operation and push it on the stack. + * @return this + */ + virtual SearchArgumentBuilder& startNot() = 0; + + /** + * Finish the current operation and pop it off of the stack. Each start + * call must have a matching end. + * @return this + */ + virtual SearchArgumentBuilder& end() = 0; + + /** + * Add a less than leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + virtual SearchArgumentBuilder& lessThan(const std::string& column, + PredicateDataType type, + Literal literal) = 0; + + /** + * Add a less than leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + virtual SearchArgumentBuilder& lessThan(uint64_t columnId, + PredicateDataType type, + Literal literal) = 0; + + /** + * Add a less than equals leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + virtual SearchArgumentBuilder& lessThanEquals(const std::string& column, + PredicateDataType type, + Literal literal) = 0; + + /** + * Add a less than equals leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + virtual SearchArgumentBuilder& lessThanEquals(uint64_t columnId, + PredicateDataType type, + Literal literal) = 0; + + /** + * Add an equals leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + virtual SearchArgumentBuilder& equals(const std::string& column, + PredicateDataType type, + Literal literal) = 0; + + /** + * Add an equals leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + virtual SearchArgumentBuilder& equals(uint64_t columnId, + PredicateDataType type, + Literal literal) = 0; + + /** + * Add a null safe equals leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + virtual SearchArgumentBuilder& nullSafeEquals(const std::string& column, + PredicateDataType type, + Literal literal) = 0; + + /** + * Add a null safe equals leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + virtual SearchArgumentBuilder& nullSafeEquals(uint64_t columnId, + PredicateDataType type, + Literal literal) = 0; + + /** + * Add an in leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literals the literals + * @return this + */ + virtual SearchArgumentBuilder& in(const std::string& column, + PredicateDataType type, + const std::initializer_list<Literal>& literals) = 0; + + /** + * Add an in leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literals the literals + * @return this + */ + virtual SearchArgumentBuilder& in(uint64_t columnId, + PredicateDataType type, + const std::initializer_list<Literal>& literals) = 0; + + /** + * Add an in leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literals the literals + * @return this + */ + virtual SearchArgumentBuilder& in(const std::string& column, + PredicateDataType type, + const std::vector<Literal>& literals) = 0; + + /** + * Add an in leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literals the literals + * @return this + */ + virtual SearchArgumentBuilder& in(uint64_t columnId, + PredicateDataType type, + const std::vector<Literal>& literals) = 0; + + /** + * Add an is null leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @return this + */ + virtual SearchArgumentBuilder& isNull(const std::string& column, + PredicateDataType type) = 0; + + /** + * Add an is null leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @return this + */ + virtual SearchArgumentBuilder& isNull(uint64_t columnId, + PredicateDataType type) = 0; + + /** + * Add a between leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param lower the literal + * @param upper the literal + * @return this + */ + virtual SearchArgumentBuilder& between(const std::string& column, + PredicateDataType type, + Literal lower, + Literal upper) = 0; + + /** + * Add a between leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param lower the literal + * @param upper the literal + * @return this + */ + virtual SearchArgumentBuilder& between(uint64_t columnId, + PredicateDataType type, + Literal lower, + Literal upper) = 0; + + /** + * Add a truth value to the expression. + * @param truth truth value + * @return this + */ + virtual SearchArgumentBuilder& literal(TruthValue truth) = 0; + + /** + * Build and return the SearchArgument that has been defined. All of the + * starts must have been ended before this call. + * @return the new SearchArgument + */ + virtual std::unique_ptr<SearchArgument> build() = 0; + }; + + /** + * Factory to create SearchArgumentBuilder which builds SearchArgument + */ + class SearchArgumentFactory { + public: + static std::unique_ptr<SearchArgumentBuilder> newBuilder(); + }; + +} // namespace orc + +#endif //ORC_SEARCHARGUMENT_HH diff --git a/contrib/libs/apache/orc/c++/include/orc/sargs/TruthValue.hh b/contrib/libs/apache/orc/c++/include/orc/sargs/TruthValue.hh new file mode 100644 index 00000000000..b3ea6b76ce4 --- /dev/null +++ b/contrib/libs/apache/orc/c++/include/orc/sargs/TruthValue.hh @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_TRUTHVALUE_HH +#define ORC_TRUTHVALUE_HH + +namespace orc { + + /** + * The potential result sets of logical operations. + */ + enum class TruthValue { + YES, // all rows satisfy the predicate + NO, // all rows dissatisfy the predicate + IS_NULL, // all rows are null value + YES_NULL, // null values exist, not-null rows satisfy the predicate + NO_NULL, // null values exist, not-null rows dissatisfy the predicate + YES_NO, // some rows satisfy the predicate and the others not + YES_NO_NULL // null values exist, some rows satisfy predicate and some not + }; + + // Compute logical or between the two values. + TruthValue operator||(TruthValue left, TruthValue right); + + // Compute logical AND between the two values. + TruthValue operator&&(TruthValue left, TruthValue right); + + // Compute logical NOT for one value. + TruthValue operator!(TruthValue val); + + // Do we need to read the data based on the TruthValue? + bool isNeeded(TruthValue val); + +} // namespace orc + +#endif //ORC_TRUTHVALUE_HH diff --git a/contrib/libs/apache/orc/c++/src/Adaptor.hh b/contrib/libs/apache/orc/c++/src/Adaptor.hh index a91b9c894db..1b13ec0ca6e 100644 --- a/contrib/libs/apache/orc/c++/src/Adaptor.hh +++ b/contrib/libs/apache/orc/c++/src/Adaptor.hh @@ -30,7 +30,7 @@ #define HAS_DOUBLE_TO_STRING #define HAS_INT64_TO_STRING #define HAS_PRE_1970 -#define HAS_POST_2038 +/* #undef HAS_POST_2038 */ #define HAS_STD_ISNAN #define HAS_STD_MUTEX #ifndef _MSC_VER diff --git a/contrib/libs/apache/orc/c++/src/ByteRLE.cc b/contrib/libs/apache/orc/c++/src/ByteRLE.cc index ee1a4575dc6..1c4a645167e 100644 --- a/contrib/libs/apache/orc/c++/src/ByteRLE.cc +++ b/contrib/libs/apache/orc/c++/src/ByteRLE.cc @@ -61,6 +61,13 @@ namespace orc { virtual void recordPosition(PositionRecorder* recorder) const override; + virtual void suppress() override; + + /** + * Reset to initial state + */ + void reset(); + protected: std::unique_ptr<BufferedOutputStream> outputStream; char* literals; @@ -80,12 +87,7 @@ namespace orc { std::unique_ptr<BufferedOutputStream> output) : outputStream(std::move(output)) { literals = new char[MAX_LITERAL_SIZE]; - numLiterals = 0; - tailRunLength = 0; - repeat = false; - bufferPosition = 0; - bufferLength = 0; - buffer = nullptr; + reset(); } ByteRleEncoderImpl::~ByteRleEncoderImpl() { @@ -203,6 +205,21 @@ namespace orc { recorder->add(static_cast<uint64_t>(numLiterals)); } + void ByteRleEncoderImpl::reset() { + numLiterals = 0; + tailRunLength = 0; + repeat = false; + bufferPosition = 0; + bufferLength = 0; + buffer = nullptr; + } + + void ByteRleEncoderImpl::suppress() { + // written data can be just ignored because they are only flushed in memory + outputStream->suppress(); + reset(); + } + std::unique_ptr<ByteRleEncoder> createByteRleEncoder (std::unique_ptr<BufferedOutputStream> output) { return std::unique_ptr<ByteRleEncoder>(new ByteRleEncoderImpl diff --git a/contrib/libs/apache/orc/c++/src/ByteRLE.hh b/contrib/libs/apache/orc/c++/src/ByteRLE.hh index 71ca579cd3b..2f6e2eb4df6 100644 --- a/contrib/libs/apache/orc/c++/src/ByteRLE.hh +++ b/contrib/libs/apache/orc/c++/src/ByteRLE.hh @@ -55,6 +55,11 @@ namespace orc { * @param recorder use the recorder to record current positions */ virtual void recordPosition(PositionRecorder* recorder) const = 0; + + /** + * suppress the data and reset to initial state + */ + virtual void suppress() = 0; }; class ByteRleDecoder { diff --git a/contrib/libs/apache/orc/c++/src/ColumnPrinter.cc b/contrib/libs/apache/orc/c++/src/ColumnPrinter.cc index b4b5860cad8..ab6b690c572 100644 --- a/contrib/libs/apache/orc/c++/src/ColumnPrinter.cc +++ b/contrib/libs/apache/orc/c++/src/ColumnPrinter.cc @@ -169,22 +169,20 @@ namespace orc { private: const unsigned char *tags; const uint64_t* offsets; - std::vector<ColumnPrinter*> fieldPrinter; + std::vector<std::unique_ptr<ColumnPrinter>> fieldPrinter; public: UnionColumnPrinter(std::string&, const Type& type); - virtual ~UnionColumnPrinter() override; void printRow(uint64_t rowId) override; void reset(const ColumnVectorBatch& batch) override; }; class StructColumnPrinter: public ColumnPrinter { private: - std::vector<ColumnPrinter*> fieldPrinter; + std::vector<std::unique_ptr<ColumnPrinter>> fieldPrinter; std::vector<std::string> fieldNames; public: StructColumnPrinter(std::string&, const Type& type); - virtual ~StructColumnPrinter() override; void printRow(uint64_t rowId) override; void reset(const ColumnVectorBatch& batch) override; }; @@ -251,6 +249,7 @@ namespace orc { break; case TIMESTAMP: + case TIMESTAMP_INSTANT: result = new TimestampColumnPrinter(buffer); break; @@ -540,14 +539,7 @@ namespace orc { tags(nullptr), offsets(nullptr) { for(unsigned int i=0; i < type.getSubtypeCount(); ++i) { - fieldPrinter.push_back(createColumnPrinter(buffer, type.getSubtype(i)) - .release()); - } - } - - UnionColumnPrinter::~UnionColumnPrinter() { - for (size_t i = 0; i < fieldPrinter.size(); i++) { - delete fieldPrinter[i]; + fieldPrinter.push_back(createColumnPrinter(buffer, type.getSubtype(i))); } } @@ -582,15 +574,7 @@ namespace orc { ): ColumnPrinter(_buffer) { for(unsigned int i=0; i < type.getSubtypeCount(); ++i) { fieldNames.push_back(type.getFieldName(i)); - fieldPrinter.push_back(createColumnPrinter(buffer, - type.getSubtype(i)) - .release()); - } - } - - StructColumnPrinter::~StructColumnPrinter() { - for (size_t i = 0; i < fieldPrinter.size(); i++) { - delete fieldPrinter[i]; + fieldPrinter.push_back(createColumnPrinter(buffer, type.getSubtype(i))); } } diff --git a/contrib/libs/apache/orc/c++/src/ColumnReader.cc b/contrib/libs/apache/orc/c++/src/ColumnReader.cc index 8cf660be11a..f4a4df92486 100644 --- a/contrib/libs/apache/orc/c++/src/ColumnReader.cc +++ b/contrib/libs/apache/orc/c++/src/ColumnReader.cc @@ -305,10 +305,14 @@ namespace orc { std::unique_ptr<orc::RleDecoder> secondsRle; std::unique_ptr<orc::RleDecoder> nanoRle; const Timezone& writerTimezone; + const Timezone& readerTimezone; const int64_t epochOffset; + const bool sameTimezone; public: - TimestampColumnReader(const Type& type, StripeStreams& stripe); + TimestampColumnReader(const Type& type, + StripeStreams& stripe, + bool isInstantType); ~TimestampColumnReader() override; uint64_t skip(uint64_t numValues) override; @@ -323,10 +327,17 @@ namespace orc { TimestampColumnReader::TimestampColumnReader(const Type& type, - StripeStreams& stripe + StripeStreams& stripe, + bool isInstantType ): ColumnReader(type, stripe), - writerTimezone(stripe.getWriterTimezone()), - epochOffset(writerTimezone.getEpoch()) { + writerTimezone(isInstantType ? + getTimezoneByName("GMT") : + stripe.getWriterTimezone()), + readerTimezone(isInstantType ? + getTimezoneByName("GMT") : + stripe.getReaderTimezone()), + epochOffset(writerTimezone.getEpoch()), + sameTimezone(&writerTimezone == &readerTimezone){ RleVersion vers = convertRleVersion(stripe.getEncoding(columnId).kind()); std::unique_ptr<SeekableInputStream> stream = stripe.getStream(columnId, proto::Stream_Kind_DATA, true); @@ -373,7 +384,20 @@ namespace orc { } } int64_t writerTime = secsBuffer[i] + epochOffset; - secsBuffer[i] = writerTimezone.convertToUTC(writerTime); + if (!sameTimezone) { + // adjust timestamp value to same wall clock time if writer and reader + // time zones have different rules, which is required for Apache Orc. + const auto& wv = writerTimezone.getVariant(writerTime); + const auto& rv = readerTimezone.getVariant(writerTime); + if (!wv.hasSameTzRule(rv)) { + // If the timezone adjustment moves the millis across a DST boundary, + // we need to reevaluate the offsets. + int64_t adjustedTime = writerTime + wv.gmtOffset - rv.gmtOffset; + const auto& adjustedReader = readerTimezone.getVariant(adjustedTime); + writerTime = writerTime + wv.gmtOffset - adjustedReader.gmtOffset; + } + } + secsBuffer[i] = writerTime; if (secsBuffer[i] < 0 && nanoBuffer[i] > 999999) { secsBuffer[i] -= 1; } @@ -388,10 +412,11 @@ namespace orc { nanoRle->seek(positions.at(columnId)); } + template<TypeKind columnKind, bool isLittleEndian> class DoubleColumnReader: public ColumnReader { public: DoubleColumnReader(const Type& type, StripeStreams& stripe); - ~DoubleColumnReader() override; + ~DoubleColumnReader() override {} uint64_t skip(uint64_t numValues) override; @@ -404,8 +429,7 @@ namespace orc { private: std::unique_ptr<SeekableInputStream> inputStream; - TypeKind columnKind; - const uint64_t bytesPerValue ; + const uint64_t bytesPerValue = (columnKind == FLOAT) ? 4 : 8; const char *bufferPointer; const char *bufferEnd; @@ -423,8 +447,24 @@ namespace orc { double readDouble() { int64_t bits = 0; - for (uint64_t i=0; i < 8; i++) { - bits |= static_cast<int64_t>(readByte()) << (i*8); + if (bufferEnd - bufferPointer >= 8) { + if (isLittleEndian) { + bits = *(reinterpret_cast<const int64_t*>(bufferPointer)); + } else { + bits = static_cast<int64_t>(static_cast<unsigned char>(bufferPointer[0])); + bits |= static_cast<int64_t>(static_cast<unsigned char>(bufferPointer[1])) << 8; + bits |= static_cast<int64_t>(static_cast<unsigned char>(bufferPointer[2])) << 16; + bits |= static_cast<int64_t>(static_cast<unsigned char>(bufferPointer[3])) << 24; + bits |= static_cast<int64_t>(static_cast<unsigned char>(bufferPointer[4])) << 32; + bits |= static_cast<int64_t>(static_cast<unsigned char>(bufferPointer[5])) << 40; + bits |= static_cast<int64_t>(static_cast<unsigned char>(bufferPointer[6])) << 48; + bits |= static_cast<int64_t>(static_cast<unsigned char>(bufferPointer[7])) << 56; + } + bufferPointer += 8; + } else { + for (uint64_t i = 0; i < 8; i++) { + bits |= static_cast<int64_t>(readByte()) << (i * 8); + } } double *result = reinterpret_cast<double*>(&bits); return *result; @@ -432,32 +472,40 @@ namespace orc { double readFloat() { int32_t bits = 0; - for (uint64_t i=0; i < 4; i++) { - bits |= readByte() << (i*8); + if (bufferEnd - bufferPointer >= 4) { + if (isLittleEndian) { + bits = *(reinterpret_cast<const int32_t*>(bufferPointer)); + } else { + bits = static_cast<unsigned char>(bufferPointer[0]); + bits |= static_cast<unsigned char>(bufferPointer[1]) << 8; + bits |= static_cast<unsigned char>(bufferPointer[2]) << 16; + bits |= static_cast<unsigned char>(bufferPointer[3]) << 24; + } + bufferPointer += 4; + } else { + for (uint64_t i = 0; i < 4; i++) { + bits |= readByte() << (i * 8); + } } float *result = reinterpret_cast<float*>(&bits); return static_cast<double>(*result); } }; - DoubleColumnReader::DoubleColumnReader(const Type& type, - StripeStreams& stripe - ): ColumnReader(type, stripe), - columnKind(type.getKind()), - bytesPerValue((type.getKind() == - FLOAT) ? 4 : 8), - bufferPointer(nullptr), - bufferEnd(nullptr) { + template<TypeKind columnKind, bool isLittleEndian> + DoubleColumnReader<columnKind, isLittleEndian>::DoubleColumnReader( + const Type& type, + StripeStreams& stripe + ): ColumnReader(type, stripe), + bufferPointer(nullptr), + bufferEnd(nullptr) { inputStream = stripe.getStream(columnId, proto::Stream_Kind_DATA, true); if (inputStream == nullptr) throw ParseError("DATA stream not found in Double column"); } - DoubleColumnReader::~DoubleColumnReader() { - // PASS - } - - uint64_t DoubleColumnReader::skip(uint64_t numValues) { + template<TypeKind columnKind, bool isLittleEndian> + uint64_t DoubleColumnReader<columnKind, isLittleEndian>::skip(uint64_t numValues) { numValues = ColumnReader::skip(numValues); if (static_cast<size_t>(bufferEnd - bufferPointer) >= @@ -479,9 +527,11 @@ namespace orc { return numValues; } - void DoubleColumnReader::next(ColumnVectorBatch& rowBatch, - uint64_t numValues, - char *notNull) { + template<TypeKind columnKind, bool isLittleEndian> + void DoubleColumnReader<columnKind, isLittleEndian>::next( + ColumnVectorBatch& rowBatch, + uint64_t numValues, + char *notNull) { ColumnReader::next(rowBatch, numValues, notNull); // update the notNull from the parent class notNull = rowBatch.hasNulls ? rowBatch.notNull.data() : nullptr; @@ -507,13 +557,33 @@ namespace orc { } } } else { - for(size_t i=0; i < numValues; ++i) { + // Number of values in the buffer that we can copy directly. + // Only viable when the machine is little-endian. + uint64_t bufferNum = 0; + if (isLittleEndian) { + bufferNum = std::min(numValues, + static_cast<size_t>(bufferEnd - bufferPointer) / bytesPerValue); + uint64_t bufferBytes = bufferNum * bytesPerValue; + memcpy(outArray, bufferPointer, bufferBytes); + bufferPointer += bufferBytes; + } + for (size_t i = bufferNum; i < numValues; ++i) { outArray[i] = readDouble(); } } } } + template<TypeKind columnKind, bool isLittleEndian> + void DoubleColumnReader<columnKind, isLittleEndian>::seekToRowGroup( + std::unordered_map<uint64_t, PositionProvider>& positions) { + ColumnReader::seekToRowGroup(positions); + inputStream->seek(positions.at(columnId)); + // clear buffer state after seek + bufferEnd = nullptr; + bufferPointer = nullptr; + } + void readFully(char* buffer, int64_t bufferSize, SeekableInputStream* stream) { int64_t posn = 0; while (posn < bufferSize) { @@ -530,12 +600,6 @@ namespace orc { } } - void DoubleColumnReader::seekToRowGroup( - std::unordered_map<uint64_t, PositionProvider>& positions) { - ColumnReader::seekToRowGroup(positions); - inputStream->seek(positions.at(columnId)); - } - class StringDictionaryColumnReader: public ColumnReader { private: std::shared_ptr<StringDictionary> dictionary; @@ -567,30 +631,37 @@ namespace orc { RleVersion rleVersion = convertRleVersion(stripe.getEncoding(columnId) .kind()); uint32_t dictSize = stripe.getEncoding(columnId).dictionarysize(); - rle = createRleDecoder(stripe.getStream(columnId, - proto::Stream_Kind_DATA, - true), - false, rleVersion, memoryPool); + std::unique_ptr<SeekableInputStream> stream = + stripe.getStream(columnId, proto::Stream_Kind_DATA, true); + if (stream == nullptr) { + throw ParseError("DATA stream not found in StringDictionaryColumn"); + } + rle = createRleDecoder(std::move(stream), false, rleVersion, memoryPool); + stream = stripe.getStream(columnId, proto::Stream_Kind_LENGTH, false); + if (dictSize > 0 && stream == nullptr) { + throw ParseError("LENGTH stream not found in StringDictionaryColumn"); + } std::unique_ptr<RleDecoder> lengthDecoder = - createRleDecoder(stripe.getStream(columnId, - proto::Stream_Kind_LENGTH, - false), - false, rleVersion, memoryPool); + createRleDecoder(std::move(stream), false, rleVersion, memoryPool); dictionary->dictionaryOffset.resize(dictSize + 1); int64_t* lengthArray = dictionary->dictionaryOffset.data(); lengthDecoder->next(lengthArray + 1, dictSize, nullptr); lengthArray[0] = 0; for(uint32_t i = 1; i < dictSize + 1; ++i) { + if (lengthArray[i] < 0) { + throw ParseError("Negative dictionary entry length"); + } lengthArray[i] += lengthArray[i - 1]; } - dictionary->dictionaryBlob.resize( - static_cast<uint64_t>(lengthArray[dictSize])); + int64_t blobSize = lengthArray[dictSize]; + dictionary->dictionaryBlob.resize(static_cast<uint64_t>(blobSize)); std::unique_ptr<SeekableInputStream> blobStream = stripe.getStream(columnId, proto::Stream_Kind_DICTIONARY_DATA, false); - readFully( - dictionary->dictionaryBlob.data(), - lengthArray[dictSize], - blobStream.get()); + if (blobSize > 0 && blobStream == nullptr) { + throw ParseError( + "DICTIONARY_DATA stream not found in StringDictionaryColumn"); + } + readFully(dictionary->dictionaryBlob.data(), blobSize, blobStream.get()); } StringDictionaryColumnReader::~StringDictionaryColumnReader() { @@ -831,15 +902,17 @@ namespace orc { ColumnReader::seekToRowGroup(positions); blobStream->seek(positions.at(columnId)); lengthRle->seek(positions.at(columnId)); + // clear buffer state after seek + lastBuffer = nullptr; + lastBufferLength = 0; } class StructColumnReader: public ColumnReader { private: - std::vector<ColumnReader*> children; + std::vector<std::unique_ptr<ColumnReader>> children; public: StructColumnReader(const Type& type, StripeStreams& stipe); - ~StructColumnReader() override; uint64_t skip(uint64_t numValues) override; @@ -871,7 +944,7 @@ namespace orc { for(unsigned int i=0; i < type.getSubtypeCount(); ++i) { const Type& child = *type.getSubtype(i); if (selectedColumns[static_cast<uint64_t>(child.getColumnId())]) { - children.push_back(buildReader(child, stripe).release()); + children.push_back(buildReader(child, stripe)); } } break; @@ -883,16 +956,10 @@ namespace orc { } } - StructColumnReader::~StructColumnReader() { - for (size_t i=0; i<children.size(); i++) { - delete children[i]; - } - } - uint64_t StructColumnReader::skip(uint64_t numValues) { numValues = ColumnReader::skip(numValues); - for(std::vector<ColumnReader*>::iterator ptr=children.begin(); ptr != children.end(); ++ptr) { - (*ptr)->skip(numValues); + for(auto& ptr : children) { + ptr->skip(numValues); } return numValues; } @@ -916,13 +983,12 @@ namespace orc { ColumnReader::next(rowBatch, numValues, notNull); uint64_t i=0; notNull = rowBatch.hasNulls? rowBatch.notNull.data() : nullptr; - for(std::vector<ColumnReader*>::iterator ptr=children.begin(); - ptr != children.end(); ++ptr, ++i) { + for(auto iter = children.begin(); iter != children.end(); ++iter, ++i) { if (encoded) { - (*ptr)->nextEncoded(*(dynamic_cast<StructVectorBatch&>(rowBatch).fields[i]), + (*iter)->nextEncoded(*(dynamic_cast<StructVectorBatch&>(rowBatch).fields[i]), numValues, notNull); } else { - (*ptr)->next(*(dynamic_cast<StructVectorBatch&>(rowBatch).fields[i]), + (*iter)->next(*(dynamic_cast<StructVectorBatch&>(rowBatch).fields[i]), numValues, notNull); } } @@ -932,10 +998,8 @@ namespace orc { std::unordered_map<uint64_t, PositionProvider>& positions) { ColumnReader::seekToRowGroup(positions); - for(std::vector<ColumnReader*>::iterator ptr = children.begin(); - ptr != children.end(); - ++ptr) { - (*ptr)->seekToRowGroup(positions); + for(auto& ptr : children) { + ptr->seekToRowGroup(positions); } } @@ -1230,13 +1294,12 @@ namespace orc { class UnionColumnReader: public ColumnReader { private: std::unique_ptr<ByteRleDecoder> rle; - std::vector<ColumnReader*> childrenReader; + std::vector<std::unique_ptr<ColumnReader>> childrenReader; std::vector<int64_t> childrenCounts; uint64_t numChildren; public: UnionColumnReader(const Type& type, StripeStreams& stipe); - ~UnionColumnReader() override; uint64_t skip(uint64_t numValues) override; @@ -1275,18 +1338,11 @@ namespace orc { for(unsigned int i=0; i < numChildren; ++i) { const Type &child = *type.getSubtype(i); if (selectedColumns[static_cast<size_t>(child.getColumnId())]) { - childrenReader[i] = buildReader(child, stripe).release(); + childrenReader[i] = buildReader(child, stripe); } } } - UnionColumnReader::~UnionColumnReader() { - for(std::vector<ColumnReader*>::iterator itr = childrenReader.begin(); - itr != childrenReader.end(); ++itr) { - delete *itr; - } - } - uint64_t UnionColumnReader::skip(uint64_t numValues) { numValues = ColumnReader::skip(numValues); const uint64_t BUFFER_SIZE = 1024; @@ -1564,6 +1620,9 @@ namespace orc { ColumnReader::seekToRowGroup(positions); valueStream->seek(positions.at(columnId)); scaleDecoder->seek(positions.at(columnId)); + // clear buffer state after seek + buffer = nullptr; + bufferEnd = nullptr; } class Decimal128ColumnReader: public Decimal64ColumnReader { @@ -1634,6 +1693,60 @@ namespace orc { } } + class Decimal64ColumnReaderV2: public ColumnReader { + protected: + std::unique_ptr<RleDecoder> valueDecoder; + int32_t precision; + int32_t scale; + + public: + Decimal64ColumnReaderV2(const Type& type, StripeStreams& stripe); + ~Decimal64ColumnReaderV2() override; + + uint64_t skip(uint64_t numValues) override; + + void next(ColumnVectorBatch& rowBatch, + uint64_t numValues, + char *notNull) override; + }; + + Decimal64ColumnReaderV2::Decimal64ColumnReaderV2(const Type& type, + StripeStreams& stripe + ): ColumnReader(type, stripe) { + scale = static_cast<int32_t>(type.getScale()); + precision = static_cast<int32_t>(type.getPrecision()); + std::unique_ptr<SeekableInputStream> stream = + stripe.getStream(columnId, proto::Stream_Kind_DATA, true); + if (stream == nullptr) { + std::stringstream ss; + ss << "DATA stream not found in Decimal64V2 column. ColumnId=" << columnId; + throw ParseError(ss.str()); + } + valueDecoder = createRleDecoder(std::move(stream), true, RleVersion_2, memoryPool); + } + + Decimal64ColumnReaderV2::~Decimal64ColumnReaderV2() { + // PASS + } + + uint64_t Decimal64ColumnReaderV2::skip(uint64_t numValues) { + numValues = ColumnReader::skip(numValues); + valueDecoder->skip(numValues); + return numValues; + } + + void Decimal64ColumnReaderV2::next(ColumnVectorBatch& rowBatch, + uint64_t numValues, + char *notNull) { + ColumnReader::next(rowBatch, numValues, notNull); + notNull = rowBatch.hasNulls ? rowBatch.notNull.data() : nullptr; + Decimal64VectorBatch &batch = + dynamic_cast<Decimal64VectorBatch&>(rowBatch); + valueDecoder->next(batch.values.data(), numValues, notNull); + batch.precision = precision; + batch.scale = scale; + } + class DecimalHive11ColumnReader: public Decimal64ColumnReader { private: bool throwOnOverflow; @@ -1748,6 +1861,11 @@ namespace orc { } } + static bool isLittleEndian() { + static union { uint32_t i; char c[4]; } num = { 0x01020304 }; + return num.c[0] == 4; + } + /** * Create a reader for the given stripe. */ @@ -1802,31 +1920,47 @@ namespace orc { new StructColumnReader(type, stripe)); case FLOAT: + if (isLittleEndian()) { + return std::unique_ptr<ColumnReader>( + new DoubleColumnReader<FLOAT, true>(type, stripe)); + } + return std::unique_ptr<ColumnReader>( + new DoubleColumnReader<FLOAT, false>(type, stripe)); + case DOUBLE: + if (isLittleEndian()) { + return std::unique_ptr<ColumnReader>( + new DoubleColumnReader<DOUBLE, true>(type, stripe)); + } return std::unique_ptr<ColumnReader>( - new DoubleColumnReader(type, stripe)); + new DoubleColumnReader<DOUBLE, false>(type, stripe)); case TIMESTAMP: return std::unique_ptr<ColumnReader> - (new TimestampColumnReader(type, stripe)); + (new TimestampColumnReader(type, stripe, false)); + + case TIMESTAMP_INSTANT: + return std::unique_ptr<ColumnReader> + (new TimestampColumnReader(type, stripe, true)); case DECIMAL: // is this a Hive 0.11 or 0.12 file? if (type.getPrecision() == 0) { return std::unique_ptr<ColumnReader> (new DecimalHive11ColumnReader(type, stripe)); - + } // can we represent the values using int64_t? - } else if (type.getPrecision() <= - Decimal64ColumnReader::MAX_PRECISION_64) { + if (type.getPrecision() <= Decimal64ColumnReader::MAX_PRECISION_64) { + if (stripe.isDecimalAsLong()) { + return std::unique_ptr<ColumnReader> + (new Decimal64ColumnReaderV2(type, stripe)); + } return std::unique_ptr<ColumnReader> (new Decimal64ColumnReader(type, stripe)); - - // otherwise we use the Int128 implementation - } else { - return std::unique_ptr<ColumnReader> - (new Decimal128ColumnReader(type, stripe)); } + // otherwise we use the Int128 implementation + return std::unique_ptr<ColumnReader> + (new Decimal128ColumnReader(type, stripe)); default: throw NotImplementedYet("buildReader unhandled type"); diff --git a/contrib/libs/apache/orc/c++/src/ColumnReader.hh b/contrib/libs/apache/orc/c++/src/ColumnReader.hh index 0c64e5b80f3..80b59de2c12 100644 --- a/contrib/libs/apache/orc/c++/src/ColumnReader.hh +++ b/contrib/libs/apache/orc/c++/src/ColumnReader.hh @@ -69,6 +69,11 @@ namespace orc { virtual const Timezone& getWriterTimezone() const = 0; /** + * Get the reader's timezone, so that we can convert their dates correctly. + */ + virtual const Timezone& getReaderTimezone() const = 0; + + /** * Get the error stream. * @return a pointer to the stream that should get error messages */ @@ -86,6 +91,12 @@ namespace orc { * @return the number of scale digits */ virtual int32_t getForcedScaleOnHive11Decimal() const = 0; + + /** + * Whether decimals that have precision <=18 are encoded as fixed scale and values + * encoded in RLE. + */ + virtual bool isDecimalAsLong() const = 0; }; /** diff --git a/contrib/libs/apache/orc/c++/src/ColumnWriter.cc b/contrib/libs/apache/orc/c++/src/ColumnWriter.cc index 1408a15457c..32b68af3490 100644 --- a/contrib/libs/apache/orc/c++/src/ColumnWriter.cc +++ b/contrib/libs/apache/orc/c++/src/ColumnWriter.cc @@ -100,7 +100,8 @@ namespace orc { enableBloomFilter(false), memPool(*options.getMemoryPool()), indexStream(), - bloomFilterStream() { + bloomFilterStream(), + hasNullValue(false) { std::unique_ptr<BufferedOutputStream> presentStream = factory.createStream(proto::Stream_Kind_PRESENT); @@ -139,10 +140,22 @@ namespace orc { uint64_t offset, uint64_t numValues, const char* incomingMask) { - notNullEncoder->add(batch.notNull.data() + offset, numValues, incomingMask); + const char* notNull = batch.notNull.data() + offset; + notNullEncoder->add(notNull, numValues, incomingMask); + hasNullValue |= batch.hasNulls; + for (uint64_t i = 0; !hasNullValue && i < numValues; ++i) { + if (!notNull[i]) { + hasNullValue = true; + } + } } void ColumnWriter::flush(std::vector<proto::Stream>& streams) { + if (!hasNullValue) { + // supress the present stream + notNullEncoder->suppress(); + return; + } proto::Stream stream; stream.set_kind(proto::Stream_Kind_PRESENT); stream.set_column(static_cast<uint32_t>(columnId)); @@ -199,6 +212,21 @@ namespace orc { } void ColumnWriter::writeIndex(std::vector<proto::Stream> &streams) const { + if (!hasNullValue) { + // remove positions of present stream + int presentCount = indexStream->isCompressed() ? 4 : 3; + for (int i = 0; i != rowIndex->entry_size(); ++i) { + proto::RowIndexEntry* entry = rowIndex->mutable_entry(i); + std::vector<uint64_t> positions; + for (int j = presentCount; j < entry->positions_size(); ++j) { + positions.push_back(entry->positions(j)); + } + entry->clear_positions(); + for (size_t j = 0; j != positions.size(); ++j) { + entry->add_positions(positions[j]); + } + } + } // write row index to output stream rowIndex->SerializeToZeroCopyStream(indexStream.get()); @@ -252,7 +280,6 @@ namespace orc { const Type& type, const StreamsFactory& factory, const WriterOptions& options); - ~StructColumnWriter() override; virtual void add(ColumnVectorBatch& rowBatch, uint64_t offset, @@ -285,7 +312,7 @@ namespace orc { virtual void reset() override; private: - std::vector<ColumnWriter *> children; + std::vector<std::unique_ptr<ColumnWriter>> children; }; StructColumnWriter::StructColumnWriter( @@ -295,7 +322,7 @@ namespace orc { ColumnWriter(type, factory, options) { for(unsigned int i = 0; i < type.getSubtypeCount(); ++i) { const Type& child = *type.getSubtype(i); - children.push_back(buildWriter(child, factory, options).release()); + children.push_back(buildWriter(child, factory, options)); } if (enableIndex) { @@ -303,12 +330,6 @@ namespace orc { } } - StructColumnWriter::~StructColumnWriter() { - for (uint32_t i = 0; i < children.size(); ++i) { - delete children[i]; - } - } - void StructColumnWriter::add( ColumnVectorBatch& rowBatch, uint64_t offset, @@ -1690,6 +1711,9 @@ namespace orc { if (!notNull || notNull[i]) { directDataStream->write(data[i], unsignedLength); + if (enableBloomFilter) { + bloomFilter->addBytes(data[i], length[i]); + } binStats->update(unsignedLength); ++count; } @@ -1705,7 +1729,8 @@ namespace orc { public: TimestampColumnWriter(const Type& type, const StreamsFactory& factory, - const WriterOptions& options); + const WriterOptions& options, + bool isInstantType); virtual void add(ColumnVectorBatch& rowBatch, uint64_t offset, @@ -1727,15 +1752,21 @@ namespace orc { private: RleVersion rleVersion; const Timezone& timezone; + const bool isUTC; }; TimestampColumnWriter::TimestampColumnWriter( const Type& type, const StreamsFactory& factory, - const WriterOptions& options) : + const WriterOptions& options, + bool isInstantType) : ColumnWriter(type, factory, options), rleVersion(options.getRleVersion()), - timezone(getTimezoneByName("GMT")){ + timezone(isInstantType ? + getTimezoneByName("GMT") : + options.getTimezone()), + isUTC(isInstantType || + options.getTimezoneName() == "GMT") { std::unique_ptr<BufferedOutputStream> dataStream = factory.createStream(proto::Stream_Kind_DATA); std::unique_ptr<BufferedOutputStream> secondaryStream = @@ -1805,11 +1836,14 @@ namespace orc { if (notNull == nullptr || notNull[i]) { // TimestampVectorBatch already stores data in UTC int64_t millsUTC = secs[i] * 1000 + nanos[i] / 1000000; + if (!isUTC) { + millsUTC = timezone.convertToUTC(secs[i]) * 1000 + nanos[i] / 1000000; + } ++count; if (enableBloomFilter) { bloomFilter->addLong(millsUTC); } - tsStats->update(millsUTC); + tsStats->update(millsUTC, static_cast<int32_t>(nanos[i] % 1000000)); if (secs[i] < 0 && nanos[i] > 999999) { secs[i] += 1; @@ -2026,7 +2060,7 @@ namespace orc { ++count; if (enableBloomFilter) { std::string decimal = Decimal( - values[i], static_cast<int32_t>(scale)).toString(); + values[i], static_cast<int32_t>(scale)).toString(true); bloomFilter->addBytes( decimal.c_str(), static_cast<int64_t>(decimal.size())); } @@ -2081,6 +2115,127 @@ namespace orc { scaleEncoder->recordPosition(rowIndexPosition.get()); } + class Decimal64ColumnWriterV2 : public ColumnWriter { + public: + Decimal64ColumnWriterV2(const Type& type, + const StreamsFactory& factory, + const WriterOptions& options); + + virtual void add(ColumnVectorBatch& rowBatch, + uint64_t offset, + uint64_t numValues, + const char* incomingMask) override; + + virtual void flush(std::vector<proto::Stream>& streams) override; + + virtual uint64_t getEstimatedSize() const override; + + virtual void getColumnEncoding( + std::vector<proto::ColumnEncoding>& encodings) const override; + + virtual void recordPosition() const override; + + protected: + uint64_t precision; + uint64_t scale; + std::unique_ptr<RleEncoder> valueEncoder; + }; + + Decimal64ColumnWriterV2::Decimal64ColumnWriterV2( + const Type& type, + const StreamsFactory& factory, + const WriterOptions& options) : + ColumnWriter(type, factory, options), + precision(type.getPrecision()), + scale(type.getScale()) { + std::unique_ptr<BufferedOutputStream> dataStream = + factory.createStream(proto::Stream_Kind_DATA); + valueEncoder = createRleEncoder(std::move(dataStream), + true, + RleVersion_2, + memPool, + options.getAlignedBitpacking()); + + if (enableIndex) { + recordPosition(); + } + } + + void Decimal64ColumnWriterV2::add(ColumnVectorBatch& rowBatch, + uint64_t offset, + uint64_t numValues, + const char* incomingMask) { + const Decimal64VectorBatch* decBatch = + dynamic_cast<const Decimal64VectorBatch*>(&rowBatch); + if (decBatch == nullptr) { + throw InvalidArgument("Failed to cast to Decimal64VectorBatch"); + } + + DecimalColumnStatisticsImpl* decStats = + dynamic_cast<DecimalColumnStatisticsImpl*>(colIndexStatistics.get()); + if (decStats == nullptr) { + throw InvalidArgument("Failed to cast to DecimalColumnStatisticsImpl"); + } + + ColumnWriter::add(rowBatch, offset, numValues, incomingMask); + + const int64_t* data = decBatch->values.data() + offset; + const char* notNull = decBatch->hasNulls ? + decBatch->notNull.data() + offset : nullptr; + + valueEncoder->add(data, numValues, notNull); + + uint64_t count = 0; + for (uint64_t i = 0; i < numValues; ++i) { + if (!notNull || notNull[i]) { + ++count; + if (enableBloomFilter) { + std::string decimal = Decimal( + data[i], static_cast<int32_t>(scale)).toString(true); + bloomFilter->addBytes( + decimal.c_str(), static_cast<int64_t>(decimal.size())); + } + decStats->update(Decimal(data[i], static_cast<int32_t>(scale))); + } + } + decStats->increase(count); + if (count < numValues) { + decStats->setHasNull(true); + } + } + + void Decimal64ColumnWriterV2::flush(std::vector<proto::Stream>& streams) { + ColumnWriter::flush(streams); + + proto::Stream dataStream; + dataStream.set_kind(proto::Stream_Kind_DATA); + dataStream.set_column(static_cast<uint32_t>(columnId)); + dataStream.set_length(valueEncoder->flush()); + streams.push_back(dataStream); + } + + uint64_t Decimal64ColumnWriterV2::getEstimatedSize() const { + uint64_t size = ColumnWriter::getEstimatedSize(); + size += valueEncoder->getBufferSize(); + return size; + } + + void Decimal64ColumnWriterV2::getColumnEncoding( + std::vector<proto::ColumnEncoding>& encodings) const { + proto::ColumnEncoding encoding; + encoding.set_kind(RleVersionMapper(RleVersion_2)); + encoding.set_dictionarysize(0); + if (enableBloomFilter) { + encoding.set_bloomencoding(BloomFilterVersion::UTF8); + } + encodings.push_back(encoding); + } + + void Decimal64ColumnWriterV2::recordPosition() const { + ColumnWriter::recordPosition(); + valueEncoder->recordPosition(rowIndexPosition.get()); + } + class Decimal128ColumnWriter : public Decimal64ColumnWriter { public: Decimal128ColumnWriter(const Type& type, @@ -2160,7 +2315,7 @@ namespace orc { ++count; if (enableBloomFilter) { std::string decimal = Decimal( - values[i], static_cast<int32_t>(scale)).toString(); + values[i], static_cast<int32_t>(scale)).toString(true); bloomFilter->addBytes( decimal.c_str(), static_cast<int64_t>(decimal.size())); } @@ -2256,6 +2411,11 @@ namespace orc { if (listBatch == nullptr) { throw InvalidArgument("Failed to cast to ListVectorBatch"); } + CollectionColumnStatisticsImpl* collectionStats = + dynamic_cast<CollectionColumnStatisticsImpl*>(colIndexStatistics.get()); + if (collectionStats == nullptr) { + throw InvalidArgument("Failed to cast to CollectionColumnStatisticsImpl"); + } ColumnWriter::add(rowBatch, offset, numValues, incomingMask); @@ -2279,20 +2439,21 @@ namespace orc { if (enableIndex) { if (!notNull) { - colIndexStatistics->increase(numValues); + collectionStats->increase(numValues); } else { uint64_t count = 0; for (uint64_t i = 0; i < numValues; ++i) { if (notNull[i]) { ++count; + collectionStats->update(static_cast<uint64_t>(offsets[i])); if (enableBloomFilter) { bloomFilter->addLong(offsets[i]); } } } - colIndexStatistics->increase(count); + collectionStats->increase(count); if (count < numValues) { - colIndexStatistics->setHasNull(true); + collectionStats->setHasNull(true); } } } @@ -2482,6 +2643,11 @@ namespace orc { if (mapBatch == nullptr) { throw InvalidArgument("Failed to cast to MapVectorBatch"); } + CollectionColumnStatisticsImpl* collectionStats = + dynamic_cast<CollectionColumnStatisticsImpl*>(colIndexStatistics.get()); + if (collectionStats == nullptr) { + throw InvalidArgument("Failed to cast to CollectionColumnStatisticsImpl"); + } ColumnWriter::add(rowBatch, offset, numValues, incomingMask); @@ -2509,20 +2675,21 @@ namespace orc { if (enableIndex) { if (!notNull) { - colIndexStatistics->increase(numValues); + collectionStats->increase(numValues); } else { uint64_t count = 0; for (uint64_t i = 0; i < numValues; ++i) { if (notNull[i]) { ++count; + collectionStats->update(static_cast<uint64_t>(offsets[i])); if (enableBloomFilter) { bloomFilter->addLong(offsets[i]); } } } - colIndexStatistics->increase(count); + collectionStats->increase(count); if (count < numValues) { - colIndexStatistics->setHasNull(true); + collectionStats->setHasNull(true); } } } @@ -2666,7 +2833,6 @@ namespace orc { UnionColumnWriter(const Type& type, const StreamsFactory& factory, const WriterOptions& options); - ~UnionColumnWriter() override; virtual void add(ColumnVectorBatch& rowBatch, uint64_t offset, @@ -2703,7 +2869,7 @@ namespace orc { private: std::unique_ptr<ByteRleEncoder> rleEncoder; - std::vector<ColumnWriter*> children; + std::vector<std::unique_ptr<ColumnWriter>> children; }; UnionColumnWriter::UnionColumnWriter(const Type& type, @@ -2718,7 +2884,7 @@ namespace orc { for (uint64_t i = 0; i != type.getSubtypeCount(); ++i) { children.push_back(buildWriter(*type.getSubtype(i), factory, - options).release()); + options)); } if (enableIndex) { @@ -2726,12 +2892,6 @@ namespace orc { } } - UnionColumnWriter::~UnionColumnWriter() { - for (uint32_t i = 0; i < children.size(); ++i) { - delete children[i]; - } - } - void UnionColumnWriter::add(ColumnVectorBatch& rowBatch, uint64_t offset, uint64_t numValues, @@ -2969,9 +3129,24 @@ namespace orc { new TimestampColumnWriter( type, factory, - options)); + options, + false)); + case TIMESTAMP_INSTANT: + return std::unique_ptr<ColumnWriter>( + new TimestampColumnWriter( + type, + factory, + options, + true)); case DECIMAL: if (type.getPrecision() <= Decimal64ColumnWriter::MAX_PRECISION_64) { + if (options.getFileVersion() == FileVersion::UNSTABLE_PRE_2_0()) { + return std::unique_ptr<ColumnWriter>( + new Decimal64ColumnWriterV2( + type, + factory, + options)); + } return std::unique_ptr<ColumnWriter>( new Decimal64ColumnWriter( type, diff --git a/contrib/libs/apache/orc/c++/src/ColumnWriter.hh b/contrib/libs/apache/orc/c++/src/ColumnWriter.hh index cbbb5d00dc7..20983774c4c 100644 --- a/contrib/libs/apache/orc/c++/src/ColumnWriter.hh +++ b/contrib/libs/apache/orc/c++/src/ColumnWriter.hh @@ -207,6 +207,7 @@ namespace orc { MemoryPool& memPool; std::unique_ptr<BufferedOutputStream> indexStream; std::unique_ptr<BufferedOutputStream> bloomFilterStream; + bool hasNullValue; }; /** diff --git a/contrib/libs/apache/orc/c++/src/Common.cc b/contrib/libs/apache/orc/c++/src/Common.cc index dbf073797ef..477bfd3b4c8 100644 --- a/contrib/libs/apache/orc/c++/src/Common.cc +++ b/contrib/libs/apache/orc/c++/src/Common.cc @@ -131,8 +131,11 @@ namespace orc { } std::string FileVersion::toString() const { + if (majorVersion == 1 && minorVersion == 9999) { + return "UNSTABLE-PRE-2.0"; + } std::stringstream ss; - ss << getMajor() << '.' << getMinor(); + ss << majorVersion << '.' << minorVersion; return ss.str(); } @@ -145,4 +148,17 @@ namespace orc { static FileVersion version(0,12); return version; } + + /** + * Do not use this format except for testing. It will not be compatible + * with other versions of the software. While we iterate on the ORC 2.0 + * format, we will make incompatible format changes under this version + * without providing any forward or backward compatibility. + * + * When 2.0 is released, this version identifier will be completely removed. + */ + const FileVersion& FileVersion::UNSTABLE_PRE_2_0() { + static FileVersion version(1, 9999); + return version; + } } diff --git a/contrib/libs/apache/orc/c++/src/Compression.cc b/contrib/libs/apache/orc/c++/src/Compression.cc index 4278ed7aaec..ea101715078 100644 --- a/contrib/libs/apache/orc/c++/src/Compression.cc +++ b/contrib/libs/apache/orc/c++/src/Compression.cc @@ -36,6 +36,15 @@ #define ZSTD_CLEVEL_DEFAULT 3 #endif +/* These macros are defined in lz4.c */ +#ifndef LZ4_ACCELERATION_DEFAULT +#define LZ4_ACCELERATION_DEFAULT 1 +#endif + +#ifndef LZ4_ACCELERATION_MAX +#define LZ4_ACCELERATION_MAX 65537 +#endif + namespace orc { class CompressionStreamBase: public BufferedOutputStream { @@ -312,152 +321,166 @@ DIAGNOSTIC_PUSH DECOMPRESS_ORIGINAL, DECOMPRESS_EOF}; - class ZlibDecompressionStream: public SeekableInputStream { + std::string decompressStateToString(DecompressState state) { + switch (state) { + case DECOMPRESS_HEADER: return "DECOMPRESS_HEADER"; + case DECOMPRESS_START: return "DECOMPRESS_START"; + case DECOMPRESS_CONTINUE: return "DECOMPRESS_CONTINUE"; + case DECOMPRESS_ORIGINAL: return "DECOMPRESS_ORIGINAL"; + case DECOMPRESS_EOF: return "DECOMPRESS_EOF"; + } + return "unknown"; + } + + class DecompressionStream : public SeekableInputStream { public: - ZlibDecompressionStream(std::unique_ptr<SeekableInputStream> inStream, - size_t blockSize, - MemoryPool& pool); - virtual ~ZlibDecompressionStream() override; + DecompressionStream(std::unique_ptr<SeekableInputStream> inStream, + size_t bufferSize, + MemoryPool& pool); + virtual ~DecompressionStream() override {} virtual bool Next(const void** data, int*size) override; virtual void BackUp(int count) override; virtual bool Skip(int count) override; virtual int64_t ByteCount() const override; virtual void seek(PositionProvider& position) override; - virtual std::string getName() const override; + virtual std::string getName() const override = 0; - private: - void readBuffer(bool failOnEof) { - int length; - if (!input->Next(reinterpret_cast<const void**>(&inputBuffer), - &length)) { - if (failOnEof) { - throw ParseError("Read past EOF in " - "ZlibDecompressionStream::readBuffer"); - } - state = DECOMPRESS_EOF; - inputBuffer = nullptr; - inputBufferEnd = nullptr; - } else { - inputBufferEnd = inputBuffer + length; - } - } + protected: + virtual void NextDecompress(const void** data, + int*size, + size_t availableSize) = 0; - uint32_t readByte(bool failOnEof) { - if (inputBuffer == inputBufferEnd) { - readBuffer(failOnEof); - if (state == DECOMPRESS_EOF) { - return 0; - } - } - return static_cast<unsigned char>(*(inputBuffer++)); - } - - void readHeader() { - uint32_t header = readByte(false); - if (state != DECOMPRESS_EOF) { - header |= readByte(true) << 8; - header |= readByte(true) << 16; - if (header & 1) { - state = DECOMPRESS_ORIGINAL; - } else { - state = DECOMPRESS_START; - } - remainingLength = header >> 1; - } else { - remainingLength = 0; - } - } + std::string getStreamName() const; + void readBuffer(bool failOnEof); + uint32_t readByte(bool failOnEof); + void readHeader(); MemoryPool& pool; - const size_t blockSize; std::unique_ptr<SeekableInputStream> input; - z_stream zstream; - DataBuffer<char> buffer; + + // uncompressed output + DataBuffer<char> outputDataBuffer; // the current state DecompressState state; - // the start of the current buffer - // This pointer is not owned by us. It is either owned by zstream or - // the underlying stream. - const char* outputBuffer; - // the size of the current buffer + // The starting and current position of the buffer for the uncompressed + // data. It either points to the data buffer or the underlying input stream. + const char *outputBufferStart; + const char *outputBuffer; size_t outputBufferLength; - // the size of the current chunk + // The uncompressed buffer length. For compressed chunk, it's the original + // (ie. the overall) and the actual length of the decompressed data. + // For uncompressed chunk, it's the length of the loaded data of this chunk. + size_t uncompressedBufferLength; + + // The remaining size of the current chunk that is not yet consumed + // ie. decompressed or returned in output if state==DECOMPRESS_ORIGINAL size_t remainingLength; // the last buffer returned from the input + const char *inputBufferStart; const char *inputBuffer; const char *inputBufferEnd; + // Variables for saving the position of the header and the start of the + // buffer. Used when we have to seek a position. + size_t headerPosition; + size_t inputBufferStartPosition; + // roughly the number of bytes returned off_t bytesReturned; }; -DIAGNOSTIC_PUSH + DecompressionStream::DecompressionStream( + std::unique_ptr<SeekableInputStream> inStream, + size_t bufferSize, + MemoryPool& _pool + ) : pool(_pool), + input(std::move(inStream)), + outputDataBuffer(pool, bufferSize), + state(DECOMPRESS_HEADER), + outputBufferStart(nullptr), + outputBuffer(nullptr), + outputBufferLength(0), + uncompressedBufferLength(0), + remainingLength(0), + inputBufferStart(nullptr), + inputBuffer(nullptr), + inputBufferEnd(nullptr), + headerPosition(0), + inputBufferStartPosition(0), + bytesReturned(0) { + } -#if defined(__GNUC__) || defined(__clang__) - DIAGNOSTIC_IGNORE("-Wold-style-cast") -#endif + std::string DecompressionStream::getStreamName() const { + return input->getName(); + } - ZlibDecompressionStream::ZlibDecompressionStream - (std::unique_ptr<SeekableInputStream> inStream, - size_t _blockSize, - MemoryPool& _pool - ): pool(_pool), - blockSize(_blockSize), - buffer(pool, _blockSize) { - input.reset(inStream.release()); - zstream.next_in = nullptr; - zstream.avail_in = 0; - zstream.zalloc = nullptr; - zstream.zfree = nullptr; - zstream.opaque = nullptr; - zstream.next_out = reinterpret_cast<Bytef*>(buffer.data()); - zstream.avail_out = static_cast<uInt>(blockSize); - int64_t result = inflateInit2(&zstream, -15); - switch (result) { - case Z_OK: - break; - case Z_MEM_ERROR: - throw std::logic_error("Memory error from inflateInit2"); - case Z_VERSION_ERROR: - throw std::logic_error("Version error from inflateInit2"); - case Z_STREAM_ERROR: - throw std::logic_error("Stream error from inflateInit2"); - default: - throw std::logic_error("Unknown error from inflateInit2"); + void DecompressionStream::readBuffer(bool failOnEof) { + int length; + if (!input->Next(reinterpret_cast<const void**>(&inputBuffer), + &length)) { + if (failOnEof) { + throw ParseError("Read past EOF in DecompressionStream::readBuffer"); + } + state = DECOMPRESS_EOF; + inputBuffer = nullptr; + inputBufferEnd = nullptr; + inputBufferStart = nullptr; + } else { + inputBufferEnd = inputBuffer + length; + inputBufferStartPosition + = static_cast<size_t>(input->ByteCount() - length); + inputBufferStart = inputBuffer; } - outputBuffer = nullptr; - outputBufferLength = 0; - remainingLength = 0; - state = DECOMPRESS_HEADER; - inputBuffer = nullptr; - inputBufferEnd = nullptr; - bytesReturned = 0; } -DIAGNOSTIC_POP + uint32_t DecompressionStream::readByte(bool failOnEof) { + if (inputBuffer == inputBufferEnd) { + readBuffer(failOnEof); + if (state == DECOMPRESS_EOF) { + return 0; + } + } + return static_cast<unsigned char>(*(inputBuffer++)); + } - ZlibDecompressionStream::~ZlibDecompressionStream() { - int64_t result = inflateEnd(&zstream); - if (result != Z_OK) { - // really can't throw in destructors - std::cout << "Error in ~ZlibDecompressionStream() " << result << "\n"; + void DecompressionStream::readHeader() { + uint32_t header = readByte(false); + if (state != DECOMPRESS_EOF) { + header |= readByte(true) << 8; + header |= readByte(true) << 16; + if (header & 1) { + state = DECOMPRESS_ORIGINAL; + } else { + state = DECOMPRESS_START; + } + remainingLength = header >> 1; + } else { + remainingLength = 0; } } - bool ZlibDecompressionStream::Next(const void** data, int*size) { - // if the user pushed back, return them the partial buffer + bool DecompressionStream::Next(const void** data, int*size) { + // If we are starting a new header, we will have to store its positions + // after decompressing. + bool saveBufferPositions = false; + // If the user pushed back or seeked within the same chunk. if (outputBufferLength) { *data = outputBuffer; *size = static_cast<int>(outputBufferLength); outputBuffer += outputBufferLength; + bytesReturned += static_cast<off_t>(outputBufferLength); outputBufferLength = 0; return true; } if (state == DECOMPRESS_HEADER || remainingLength == 0) { readHeader(); + // Here we already read the three bytes of the header. + headerPosition = inputBufferStartPosition + + static_cast<size_t>(inputBuffer - inputBufferStart) - 3; + saveBufferPositions = true; } if (state == DECOMPRESS_EOF) { return false; @@ -465,83 +488,44 @@ DIAGNOSTIC_POP if (inputBuffer == inputBufferEnd) { readBuffer(true); } - size_t availSize = + size_t availableSize = std::min(static_cast<size_t>(inputBufferEnd - inputBuffer), remainingLength); if (state == DECOMPRESS_ORIGINAL) { *data = inputBuffer; - *size = static_cast<int>(availSize); - outputBuffer = inputBuffer + availSize; + *size = static_cast<int>(availableSize); + outputBuffer = inputBuffer + availableSize; outputBufferLength = 0; + inputBuffer += availableSize; + remainingLength -= availableSize; } else if (state == DECOMPRESS_START) { - zstream.next_in = - reinterpret_cast<Bytef*>(const_cast<char*>(inputBuffer)); - zstream.avail_in = static_cast<uInt>(availSize); - outputBuffer = buffer.data(); - zstream.next_out = - reinterpret_cast<Bytef*>(const_cast<char*>(outputBuffer)); - zstream.avail_out = static_cast<uInt>(blockSize); - if (inflateReset(&zstream) != Z_OK) { - throw std::logic_error("Bad inflateReset in " - "ZlibDecompressionStream::Next"); - } - int64_t result; - do { - result = inflate(&zstream, availSize == remainingLength ? Z_FINISH : - Z_SYNC_FLUSH); - switch (result) { - case Z_OK: - remainingLength -= availSize; - inputBuffer += availSize; - readBuffer(true); - availSize = - std::min(static_cast<size_t>(inputBufferEnd - inputBuffer), - remainingLength); - zstream.next_in = - reinterpret_cast<Bytef*>(const_cast<char*>(inputBuffer)); - zstream.avail_in = static_cast<uInt>(availSize); - break; - case Z_STREAM_END: - break; - case Z_BUF_ERROR: - throw std::logic_error("Buffer error in " - "ZlibDecompressionStream::Next"); - case Z_DATA_ERROR: - throw std::logic_error("Data error in " - "ZlibDecompressionStream::Next"); - case Z_STREAM_ERROR: - throw std::logic_error("Stream error in " - "ZlibDecompressionStream::Next"); - default: - throw std::logic_error("Unknown error in " - "ZlibDecompressionStream::Next"); - } - } while (result != Z_STREAM_END); - *size = static_cast<int>(blockSize - zstream.avail_out); - *data = outputBuffer; - outputBufferLength = 0; - outputBuffer += *size; + NextDecompress(data, size, availableSize); } else { throw std::logic_error("Unknown compression state in " - "ZlibDecompressionStream::Next"); + "DecompressionStream::Next"); + } + bytesReturned += static_cast<off_t>(*size); + if (saveBufferPositions) { + uncompressedBufferLength = static_cast<size_t>(*size); + outputBufferStart = reinterpret_cast<const char*>(*data); } - inputBuffer += availSize; - remainingLength -= availSize; - bytesReturned += *size; return true; } - void ZlibDecompressionStream::BackUp(int count) { + void DecompressionStream::BackUp(int count) { if (outputBuffer == nullptr || outputBufferLength != 0) { - throw std::logic_error("Backup without previous Next in " - "ZlibDecompressionStream"); + throw std::logic_error("Backup without previous Next in " + getName()); } outputBuffer -= static_cast<size_t>(count); outputBufferLength = static_cast<size_t>(count); bytesReturned -= count; } - bool ZlibDecompressionStream::Skip(int count) { + int64_t DecompressionStream::ByteCount() const { + return bytesReturned; + } + + bool DecompressionStream::Skip(int count) { bytesReturned += count; // this is a stupid implementation for now. // should skip entire blocks without decompressing @@ -561,271 +545,266 @@ DIAGNOSTIC_POP return true; } - int64_t ZlibDecompressionStream::ByteCount() const { - return bytesReturned; - } - - void ZlibDecompressionStream::seek(PositionProvider& position) { - // clear state to force seek to read from the right position + /** There are four possible scenarios when seeking a position: + * 1. The chunk of the seeked position is the current chunk that has been read and + * decompressed. For uncompressed chunk, it could be partially read. So there are two + * sub-cases: + * a. The seeked position is inside the uncompressed buffer. + * b. The seeked position is outside the uncompressed buffer. + * 2. The chunk of the seeked position is read from the input stream, but has not been + * decompressed yet, ie. it's not in the output stream. + * 3. The chunk of the seeked position is not read yet from the input stream. + */ + void DecompressionStream::seek(PositionProvider& position) { + size_t seekedHeaderPosition = position.current(); + // Case 1: the seeked position is in the current chunk and it's buffered and + // decompressed/uncompressed. Note that after the headerPosition comes the 3 bytes of + // the header. + if (headerPosition == seekedHeaderPosition + && inputBufferStartPosition <= headerPosition + 3 && inputBufferStart) { + position.next(); // Skip the input level position, i.e. seekedHeaderPosition. + size_t posInChunk = position.next(); // Chunk level position. + // Case 1.a: The position is in the decompressed/uncompressed buffer. Here we only + // need to set the output buffer's pointer to the seeked position. + if (uncompressedBufferLength >= posInChunk) { + outputBufferLength = uncompressedBufferLength - posInChunk; + outputBuffer = outputBufferStart + posInChunk; + return; + } + // Case 1.b: The position is outside the decompressed/uncompressed buffer. + // Skip bytes to seek. + if (!Skip(static_cast<int>(posInChunk - uncompressedBufferLength))) { + std::ostringstream ss; + ss << "Bad seek to (chunkHeader=" << seekedHeaderPosition << ", posInChunk=" + << posInChunk << ") in " << getName() << ". DecompressionState: " + << decompressStateToString(state); + throw ParseError(ss.str()); + } + return; + } + // Clear state to prepare reading from a new chunk header. state = DECOMPRESS_HEADER; outputBuffer = nullptr; outputBufferLength = 0; remainingLength = 0; - inputBuffer = nullptr; - inputBufferEnd = nullptr; - - input->seek(position); + if (seekedHeaderPosition < static_cast<uint64_t>(input->ByteCount()) && + seekedHeaderPosition >= inputBufferStartPosition) { + // Case 2: The input is buffered, but not yet decompressed. No need to + // force re-reading the inputBuffer, we just have to move it to the + // seeked position. + position.next(); // Skip the input level position. + inputBuffer + = inputBufferStart + (seekedHeaderPosition - inputBufferStartPosition); + } else { + // Case 3: The seeked position is not in the input buffer, here we are + // forcing to read it. + inputBuffer = nullptr; + inputBufferEnd = nullptr; + input->seek(position); // Actually use the input level position. + } bytesReturned = static_cast<off_t>(input->ByteCount()); if (!Skip(static_cast<int>(position.next()))) { - throw ParseError("Bad skip in ZlibDecompressionStream::seek"); + throw ParseError("Bad skip in " + getName()); + } + } + + class ZlibDecompressionStream : public DecompressionStream { + public: + ZlibDecompressionStream(std::unique_ptr<SeekableInputStream> inStream, + size_t blockSize, + MemoryPool& pool); + virtual ~ZlibDecompressionStream() override; + virtual std::string getName() const override; + + protected: + virtual void NextDecompress(const void** data, + int* size, + size_t availableSize) override; + private: + z_stream zstream; + }; + +DIAGNOSTIC_PUSH + +#if defined(__GNUC__) || defined(__clang__) + DIAGNOSTIC_IGNORE("-Wold-style-cast") +#endif + + ZlibDecompressionStream::ZlibDecompressionStream + (std::unique_ptr<SeekableInputStream> inStream, + size_t bufferSize, + MemoryPool& _pool + ): DecompressionStream + (std::move(inStream), bufferSize, _pool) { + zstream.next_in = nullptr; + zstream.avail_in = 0; + zstream.zalloc = nullptr; + zstream.zfree = nullptr; + zstream.opaque = nullptr; + zstream.next_out = reinterpret_cast<Bytef*>(outputDataBuffer.data()); + zstream.avail_out = static_cast<uInt>(outputDataBuffer.capacity()); + int64_t result = inflateInit2(&zstream, -15); + switch (result) { + case Z_OK: + break; + case Z_MEM_ERROR: + throw std::logic_error("Memory error from inflateInit2"); + case Z_VERSION_ERROR: + throw std::logic_error("Version error from inflateInit2"); + case Z_STREAM_ERROR: + throw std::logic_error("Stream error from inflateInit2"); + default: + throw std::logic_error("Unknown error from inflateInit2"); } } +DIAGNOSTIC_POP + + ZlibDecompressionStream::~ZlibDecompressionStream() { + int64_t result = inflateEnd(&zstream); + if (result != Z_OK) { + // really can't throw in destructors + std::cout << "Error in ~ZlibDecompressionStream() " << result << "\n"; + } + } + + void ZlibDecompressionStream::NextDecompress(const void** data, int* size, + size_t availableSize) { + zstream.next_in = + reinterpret_cast<Bytef*>(const_cast<char*>(inputBuffer)); + zstream.avail_in = static_cast<uInt>(availableSize); + outputBuffer = outputDataBuffer.data(); + zstream.next_out = + reinterpret_cast<Bytef*>(const_cast<char*>(outputBuffer)); + zstream.avail_out = static_cast<uInt>(outputDataBuffer.capacity()); + if (inflateReset(&zstream) != Z_OK) { + throw std::logic_error("Bad inflateReset in " + "ZlibDecompressionStream::NextDecompress"); + } + int64_t result; + do { + result = inflate(&zstream, availableSize == remainingLength ? Z_FINISH : + Z_SYNC_FLUSH); + switch (result) { + case Z_OK: + remainingLength -= availableSize; + inputBuffer += availableSize; + readBuffer(true); + availableSize = + std::min(static_cast<size_t>(inputBufferEnd - inputBuffer), + remainingLength); + zstream.next_in = + reinterpret_cast<Bytef*>(const_cast<char*>(inputBuffer)); + zstream.avail_in = static_cast<uInt>(availableSize); + break; + case Z_STREAM_END: + break; + case Z_BUF_ERROR: + throw std::logic_error("Buffer error in " + "ZlibDecompressionStream::NextDecompress"); + case Z_DATA_ERROR: + throw std::logic_error("Data error in " + "ZlibDecompressionStream::NextDecompress"); + case Z_STREAM_ERROR: + throw std::logic_error("Stream error in " + "ZlibDecompressionStream::NextDecompress"); + default: + throw std::logic_error("Unknown error in " + "ZlibDecompressionStream::NextDecompress"); + } + } while (result != Z_STREAM_END); + *size = static_cast<int>(outputDataBuffer.capacity() - zstream.avail_out); + *data = outputBuffer; + outputBufferLength = 0; + outputBuffer += *size; + inputBuffer += availableSize; + remainingLength -= availableSize; + } + std::string ZlibDecompressionStream::getName() const { std::ostringstream result; result << "zlib(" << input->getName() << ")"; return result.str(); } - class BlockDecompressionStream: public SeekableInputStream { + class BlockDecompressionStream: public DecompressionStream { public: BlockDecompressionStream(std::unique_ptr<SeekableInputStream> inStream, size_t blockSize, MemoryPool& pool); virtual ~BlockDecompressionStream() override {} - virtual bool Next(const void** data, int*size) override; - virtual void BackUp(int count) override; - virtual bool Skip(int count) override; - virtual int64_t ByteCount() const override; - virtual void seek(PositionProvider& position) override; virtual std::string getName() const override = 0; protected: + virtual void NextDecompress(const void** data, + int* size, + size_t availableSize) override; + virtual uint64_t decompress(const char *input, uint64_t length, char *output, size_t maxOutputLength) = 0; - - std::string getStreamName() const { - return input->getName(); - } - private: - void readBuffer(bool failOnEof) { - int length; - if (!input->Next(reinterpret_cast<const void**>(&inputBufferPtr), - &length)) { - if (failOnEof) { - throw ParseError(getName() + "read past EOF"); - } - state = DECOMPRESS_EOF; - inputBufferPtr = nullptr; - inputBufferPtrEnd = nullptr; - } else { - inputBufferPtrEnd = inputBufferPtr + length; - } - } - - uint32_t readByte(bool failOnEof) { - if (inputBufferPtr == inputBufferPtrEnd) { - readBuffer(failOnEof); - if (state == DECOMPRESS_EOF) { - return 0; - } - } - return static_cast<unsigned char>(*(inputBufferPtr++)); - } - - void readHeader() { - uint32_t header = readByte(false); - if (state != DECOMPRESS_EOF) { - header |= readByte(true) << 8; - header |= readByte(true) << 16; - if (header & 1) { - state = DECOMPRESS_ORIGINAL; - } else { - state = DECOMPRESS_START; - } - remainingLength = header >> 1; - } else { - remainingLength = 0; - } - } - - std::unique_ptr<SeekableInputStream> input; - MemoryPool& pool; - // may need to stitch together multiple input buffers; // to give snappy a contiguous block - DataBuffer<char> inputBuffer; - - // uncompressed output - DataBuffer<char> outputBuffer; - - // the current state - DecompressState state; - - // the start of the current output buffer - const char* outputBufferPtr; - // the size of the current output buffer - size_t outputBufferLength; - - // the size of the current chunk - size_t remainingLength; - - // the last buffer returned from the input - const char *inputBufferPtr; - const char *inputBufferPtrEnd; - - // bytes returned by this stream - off_t bytesReturned; + DataBuffer<char> inputDataBuffer; }; BlockDecompressionStream::BlockDecompressionStream (std::unique_ptr<SeekableInputStream> inStream, - size_t bufferSize, + size_t blockSize, MemoryPool& _pool - ) : pool(_pool), - inputBuffer(pool, bufferSize), - outputBuffer(pool, bufferSize), - state(DECOMPRESS_HEADER), - outputBufferPtr(nullptr), - outputBufferLength(0), - remainingLength(0), - inputBufferPtr(nullptr), - inputBufferPtrEnd(nullptr), - bytesReturned(0) { - input.reset(inStream.release()); - } - - bool BlockDecompressionStream::Next(const void** data, int*size) { - // if the user pushed back, return them the partial buffer - if (outputBufferLength) { - *data = outputBufferPtr; - *size = static_cast<int>(outputBufferLength); - outputBufferPtr += outputBufferLength; - bytesReturned += static_cast<off_t>(outputBufferLength); - outputBufferLength = 0; - return true; - } - if (state == DECOMPRESS_HEADER || remainingLength == 0) { - readHeader(); - } - if (state == DECOMPRESS_EOF) { - return false; - } - if (inputBufferPtr == inputBufferPtrEnd) { - readBuffer(true); - } - - size_t availSize = - std::min(static_cast<size_t>(inputBufferPtrEnd - inputBufferPtr), - remainingLength); - if (state == DECOMPRESS_ORIGINAL) { - *data = inputBufferPtr; - *size = static_cast<int>(availSize); - outputBufferPtr = inputBufferPtr + availSize; - outputBufferLength = 0; - inputBufferPtr += availSize; - remainingLength -= availSize; - } else if (state == DECOMPRESS_START) { - // Get contiguous bytes of compressed block. - const char *compressed = inputBufferPtr; - if (remainingLength == availSize) { - inputBufferPtr += availSize; - } else { - // Did not read enough from input. - if (inputBuffer.capacity() < remainingLength) { - inputBuffer.resize(remainingLength); - } - ::memcpy(inputBuffer.data(), inputBufferPtr, availSize); - inputBufferPtr += availSize; - compressed = inputBuffer.data(); - - for (size_t pos = availSize; pos < remainingLength; ) { - readBuffer(true); - size_t avail = - std::min(static_cast<size_t>(inputBufferPtrEnd - - inputBufferPtr), - remainingLength - pos); - ::memcpy(inputBuffer.data() + pos, inputBufferPtr, avail); - pos += avail; - inputBufferPtr += avail; - } - } - - outputBufferLength = decompress(compressed, remainingLength, - outputBuffer.data(), - outputBuffer.capacity()); - - remainingLength = 0; - state = DECOMPRESS_HEADER; - *data = outputBuffer.data(); - *size = static_cast<int>(outputBufferLength); - outputBufferPtr = outputBuffer.data() + outputBufferLength; - outputBufferLength = 0; - } - - bytesReturned += *size; - return true; + ) : DecompressionStream + (std::move(inStream), blockSize, _pool), + inputDataBuffer(pool, blockSize) { } - void BlockDecompressionStream::BackUp(int count) { - if (outputBufferPtr == nullptr || outputBufferLength != 0) { - throw std::logic_error("Backup without previous Next in "+getName()); - } - outputBufferPtr -= static_cast<size_t>(count); - outputBufferLength = static_cast<size_t>(count); - bytesReturned -= count; - } - bool BlockDecompressionStream::Skip(int count) { - bytesReturned += count; - // this is a stupid implementation for now. - // should skip entire blocks without decompressing - while (count > 0) { - const void *ptr; - int len; - if (!Next(&ptr, &len)) { - return false; + void BlockDecompressionStream::NextDecompress(const void** data, int* size, + size_t availableSize) { + // Get contiguous bytes of compressed block. + const char *compressed = inputBuffer; + if (remainingLength == availableSize) { + inputBuffer += availableSize; + } else { + // Did not read enough from input. + if (inputDataBuffer.capacity() < remainingLength) { + inputDataBuffer.resize(remainingLength); } - if (len > count) { - BackUp(len - count); - count = 0; - } else { - count -= len; + ::memcpy(inputDataBuffer.data(), inputBuffer, availableSize); + inputBuffer += availableSize; + compressed = inputDataBuffer.data(); + + for (size_t pos = availableSize; pos < remainingLength; ) { + readBuffer(true); + size_t avail = + std::min(static_cast<size_t>(inputBufferEnd - + inputBuffer), + remainingLength - pos); + ::memcpy(inputDataBuffer.data() + pos, inputBuffer, avail); + pos += avail; + inputBuffer += avail; } } - return true; - } - - int64_t BlockDecompressionStream::ByteCount() const { - return bytesReturned; - } - - void BlockDecompressionStream::seek(PositionProvider& position) { - // clear state to force seek to read from the right position + outputBufferLength = decompress(compressed, remainingLength, + outputDataBuffer.data(), + outputDataBuffer.capacity()); + remainingLength = 0; state = DECOMPRESS_HEADER; - outputBufferPtr = nullptr; + *data = outputDataBuffer.data(); + *size = static_cast<int>(outputBufferLength); + outputBuffer = outputDataBuffer.data() + outputBufferLength; outputBufferLength = 0; - remainingLength = 0; - inputBufferPtr = nullptr; - inputBufferPtrEnd = nullptr; - - input->seek(position); - if (!Skip(static_cast<int>(position.next()))) { - throw ParseError("Bad skip in " + getName()); - } } class SnappyDecompressionStream: public BlockDecompressionStream { public: SnappyDecompressionStream(std::unique_ptr<SeekableInputStream> inStream, size_t blockSize, - MemoryPool& pool + MemoryPool& _pool ): BlockDecompressionStream (std::move(inStream), blockSize, - pool) { + _pool) { // PASS } @@ -841,12 +820,12 @@ DIAGNOSTIC_POP ) override; }; - uint64_t SnappyDecompressionStream::decompress(const char *input, + uint64_t SnappyDecompressionStream::decompress(const char *_input, uint64_t length, char *output, size_t maxOutputLength) { size_t outLength; - if (!snappy::GetUncompressedLength(input, length, &outLength)) { + if (!snappy::GetUncompressedLength(_input, length, &outLength)) { throw ParseError("SnappyDecompressionStream choked on corrupt input"); } @@ -854,7 +833,7 @@ DIAGNOSTIC_POP throw std::logic_error("Snappy length exceeds block size"); } - if (!snappy::RawUncompress(input, length, output)) { + if (!snappy::RawUncompress(_input, length, output)) { throw ParseError("SnappyDecompressionStream choked on corrupt input"); } return outLength; @@ -864,11 +843,11 @@ DIAGNOSTIC_POP public: LzoDecompressionStream(std::unique_ptr<SeekableInputStream> inStream, size_t blockSize, - MemoryPool& pool + MemoryPool& _pool ): BlockDecompressionStream - (std::move(inStream), - blockSize, - pool) { + (std::move(inStream), + blockSize, + _pool) { // PASS } @@ -884,11 +863,11 @@ DIAGNOSTIC_POP ) override; }; - uint64_t LzoDecompressionStream::decompress(const char *input, + uint64_t LzoDecompressionStream::decompress(const char *inputPtr, uint64_t length, char *output, size_t maxOutputLength) { - return lzoDecompress(input, input + length, output, + return lzoDecompress(inputPtr, inputPtr + length, output, output + maxOutputLength); } @@ -896,11 +875,11 @@ DIAGNOSTIC_POP public: Lz4DecompressionStream(std::unique_ptr<SeekableInputStream> inStream, size_t blockSize, - MemoryPool& pool + MemoryPool& _pool ): BlockDecompressionStream (std::move(inStream), blockSize, - pool) { + _pool) { // PASS } @@ -916,11 +895,11 @@ DIAGNOSTIC_POP ) override; }; - uint64_t Lz4DecompressionStream::decompress(const char *input, + uint64_t Lz4DecompressionStream::decompress(const char *inputPtr, uint64_t length, char *output, size_t maxOutputLength) { - int result = LZ4_decompress_safe(input, output, static_cast<int>(length), + int result = LZ4_decompress_safe(inputPtr, output, static_cast<int>(length), static_cast<int>(maxOutputLength)); if (result < 0) { throw ParseError(getName() + " - failed to decompress"); @@ -1017,6 +996,113 @@ DIAGNOSTIC_POP } /** + * LZ4 block compression + */ + class Lz4CompressionSteam: public BlockCompressionStream { + public: + Lz4CompressionSteam(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) + : BlockCompressionStream(outStream, + compressionLevel, + capacity, + blockSize, + pool) { + this->init(); + } + + virtual std::string getName() const override { + return "Lz4CompressionStream"; + } + + virtual ~Lz4CompressionSteam() override { + this->end(); + } + + protected: + virtual uint64_t doBlockCompression() override; + + virtual uint64_t estimateMaxCompressionSize() override { + return static_cast<uint64_t>(LZ4_compressBound(bufferSize)); + } + + private: + void init(); + void end(); + LZ4_stream_t *state; + }; + + uint64_t Lz4CompressionSteam::doBlockCompression() { + int result = LZ4_compress_fast_extState(static_cast<void*>(state), + reinterpret_cast<const char*>(rawInputBuffer.data()), + reinterpret_cast<char*>(compressorBuffer.data()), + bufferSize, + static_cast<int>(compressorBuffer.size()), + level); + if (result == 0) { + throw std::runtime_error("Error during block compression using lz4."); + } + return static_cast<uint64_t>(result); + } + + void Lz4CompressionSteam::init() { + state = LZ4_createStream(); + if (!state) { + throw std::runtime_error("Error while allocating state for lz4."); + } + } + + void Lz4CompressionSteam::end() { + (void)LZ4_freeStream(state); + state = nullptr; + } + + /** + * Snappy block compression + */ + class SnappyCompressionStream: public BlockCompressionStream { + public: + SnappyCompressionStream(OutputStream * outStream, + int compressionLevel, + uint64_t capacity, + uint64_t blockSize, + MemoryPool& pool) + : BlockCompressionStream(outStream, + compressionLevel, + capacity, + blockSize, + pool) { + } + + virtual std::string getName() const override { + return "SnappyCompressionStream"; + } + + virtual ~SnappyCompressionStream() override { + // PASS + } + + protected: + virtual uint64_t doBlockCompression() override; + + virtual uint64_t estimateMaxCompressionSize() override { + return static_cast<uint64_t> + (snappy::MaxCompressedLength(static_cast<size_t>(bufferSize))); + } + }; + + uint64_t SnappyCompressionStream::doBlockCompression() { + size_t compressedLength; + snappy::RawCompress(reinterpret_cast<const char*>(rawInputBuffer.data()), + static_cast<size_t>(bufferSize), + reinterpret_cast<char*>(compressorBuffer.data()), + &compressedLength); + return static_cast<uint64_t>(compressedLength); + } + + /** * ZSTD block compression */ class ZSTDCompressionStream: public BlockCompressionStream { @@ -1093,10 +1179,10 @@ DIAGNOSTIC_PUSH public: ZSTDDecompressionStream(std::unique_ptr<SeekableInputStream> inStream, size_t blockSize, - MemoryPool& pool) + MemoryPool& _pool) : BlockDecompressionStream(std::move(inStream), blockSize, - pool) { + _pool) { this->init(); } @@ -1122,14 +1208,14 @@ DIAGNOSTIC_PUSH ZSTD_DCtx *dctx; }; - uint64_t ZSTDDecompressionStream::decompress(const char *input, + uint64_t ZSTDDecompressionStream::decompress(const char *inputPtr, uint64_t length, char *output, size_t maxOutputLength) { return static_cast<uint64_t>(ZSTD_decompressDCtx(dctx, output, maxOutputLength, - input, + inputPtr, length)); } @@ -1183,9 +1269,20 @@ DIAGNOSTIC_PUSH (new ZSTDCompressionStream( outStream, level, bufferCapacity, compressionBlockSize, pool)); } - case CompressionKind_SNAPPY: + case CompressionKind_LZ4: { + int level = (strategy == CompressionStrategy_SPEED) ? + LZ4_ACCELERATION_MAX : LZ4_ACCELERATION_DEFAULT; + return std::unique_ptr<BufferedOutputStream> + (new Lz4CompressionSteam( + outStream, level, bufferCapacity, compressionBlockSize, pool)); + } + case CompressionKind_SNAPPY: { + int level = 0; + return std::unique_ptr<BufferedOutputStream> + (new SnappyCompressionStream( + outStream, level, bufferCapacity, compressionBlockSize, pool)); + } case CompressionKind_LZO: - case CompressionKind_LZ4: default: throw NotImplementedYet("compression codec"); } diff --git a/contrib/libs/apache/orc/c++/src/Int128.cc b/contrib/libs/apache/orc/c++/src/Int128.cc index 433e6fa1936..4ff500fbaca 100644 --- a/contrib/libs/apache/orc/c++/src/Int128.cc +++ b/contrib/libs/apache/orc/c++/src/Int128.cc @@ -391,41 +391,51 @@ namespace orc { return buf.str(); } - std::string Int128::toDecimalString(int32_t scale) const { + std::string Int128::toDecimalString(int32_t scale, bool trimTrailingZeros) const { std::string str = toString(); + std::string result; if (scale == 0) { return str; } else if (*this < 0) { int32_t len = static_cast<int32_t>(str.length()); if (len - 1 > scale) { - return str.substr(0, static_cast<size_t>(len - scale)) + "." + - str.substr(static_cast<size_t>(len - scale), - static_cast<size_t>(scale)); + result = str.substr(0, static_cast<size_t>(len - scale)) + "." + + str.substr(static_cast<size_t>(len - scale), + static_cast<size_t>(len)); } else if (len - 1 == scale) { - return "-0." + str.substr(1, std::string::npos); + result = "-0." + str.substr(1, std::string::npos); } else { - std::string result = "-0."; - for(int32_t i=0; i < scale - len + 1; ++i) { + result = "-0."; + for (int32_t i = 0; i < scale - len + 1; ++i) { result += "0"; } - return result + str.substr(1, std::string::npos); + result += str.substr(1, std::string::npos); } } else { int32_t len = static_cast<int32_t>(str.length()); if (len > scale) { - return str.substr(0, static_cast<size_t>(len - scale)) + "." + - str.substr(static_cast<size_t>(len - scale), - static_cast<size_t>(scale)); + result = str.substr(0, static_cast<size_t>(len - scale)) + "." + + str.substr(static_cast<size_t>(len - scale), + static_cast<size_t>(len)); } else if (len == scale) { - return "0." + str; + result = "0." + str; } else { - std::string result = "0."; - for(int32_t i=0; i < scale - len; ++i) { + result = "0."; + for (int32_t i = 0; i < scale - len; ++i) { result += "0"; } - return result + str; + result += str; } } + if (trimTrailingZeros) { + size_t pos = result.find_last_not_of('0'); + if (result[pos] == '.') { + result = result.substr(0, pos); + } else { + result = result.substr(0, pos + 1); + } + } + return result; } std::string Int128::toHexString() const { diff --git a/contrib/libs/apache/orc/c++/src/LzoDecompressor.cc b/contrib/libs/apache/orc/c++/src/LzoDecompressor.cc index d1ba183aebb..21bf194fed6 100644 --- a/contrib/libs/apache/orc/c++/src/LzoDecompressor.cc +++ b/contrib/libs/apache/orc/c++/src/LzoDecompressor.cc @@ -312,13 +312,11 @@ namespace orc { output += SIZE_OF_INT; matchAddress += increment32; - *reinterpret_cast<int32_t*>(output) = - *reinterpret_cast<int32_t*>(matchAddress); + memcpy(output, matchAddress, SIZE_OF_INT); output += SIZE_OF_INT; matchAddress -= decrement64; } else { - *reinterpret_cast<int64_t*>(output) = - *reinterpret_cast<int64_t*>(matchAddress); + memcpy(output, matchAddress, SIZE_OF_LONG); matchAddress += SIZE_OF_LONG; output += SIZE_OF_LONG; } @@ -329,8 +327,7 @@ namespace orc { } while (output < fastOutputLimit) { - *reinterpret_cast<int64_t*>(output) = - *reinterpret_cast<int64_t*>(matchAddress); + memcpy(output, matchAddress, SIZE_OF_LONG); matchAddress += SIZE_OF_LONG; output += SIZE_OF_LONG; } @@ -340,8 +337,7 @@ namespace orc { } } else { while (output < matchOutputLimit) { - *reinterpret_cast<int64_t*>(output) = - *reinterpret_cast<int64_t*>(matchAddress); + memcpy(output, matchAddress, SIZE_OF_LONG); matchAddress += SIZE_OF_LONG; output += SIZE_OF_LONG; } @@ -366,8 +362,7 @@ namespace orc { // fast copy. We may over-copy but there's enough room in input // and output to not overrun them do { - *reinterpret_cast<int64_t*>(output) = - *reinterpret_cast<const int64_t*>(input); + memcpy(output, input, SIZE_OF_LONG); input += SIZE_OF_LONG; output += SIZE_OF_LONG; } while (output < literalOutputLimit); diff --git a/contrib/libs/apache/orc/c++/src/Options.hh b/contrib/libs/apache/orc/c++/src/Options.hh index 795e166138f..d8331b3c0a7 100644 --- a/contrib/libs/apache/orc/c++/src/Options.hh +++ b/contrib/libs/apache/orc/c++/src/Options.hh @@ -64,9 +64,7 @@ namespace orc { ReaderOptions::ReaderOptions(ReaderOptions& rhs) { // swap privateBits with rhs - ReaderOptionsPrivate* l = privateBits.release(); - privateBits.reset(rhs.privateBits.release()); - rhs.privateBits.reset(l); + privateBits.swap(rhs.privateBits); } ReaderOptions& ReaderOptions::operator=(const ReaderOptions& rhs) { @@ -130,6 +128,9 @@ namespace orc { bool throwOnHive11DecimalOverflow; int32_t forcedScaleOnHive11Decimal; bool enableLazyDecoding; + std::shared_ptr<SearchArgument> sargs; + std::string readerTimezone; + RowReaderOptions::IdReadIntentMap idReadIntentMap; RowReaderOptionsPrivate() { selection = ColumnSelection_NONE; @@ -138,6 +139,7 @@ namespace orc { throwOnHive11DecimalOverflow = true; forcedScaleOnHive11Decimal = 6; enableLazyDecoding = false; + readerTimezone = "GMT"; } }; @@ -155,9 +157,7 @@ namespace orc { RowReaderOptions::RowReaderOptions(RowReaderOptions& rhs) { // swap privateBits with rhs - RowReaderOptionsPrivate* l = privateBits.release(); - privateBits.reset(rhs.privateBits.release()); - rhs.privateBits.reset(l); + privateBits.swap(rhs.privateBits); } RowReaderOptions& RowReaderOptions::operator=(const RowReaderOptions& rhs) { @@ -175,6 +175,7 @@ namespace orc { privateBits->selection = ColumnSelection_FIELD_IDS; privateBits->includedColumnIndexes.assign(include.begin(), include.end()); privateBits->includedColumnNames.clear(); + privateBits->idReadIntentMap.clear(); return *this; } @@ -182,6 +183,7 @@ namespace orc { privateBits->selection = ColumnSelection_NAMES; privateBits->includedColumnNames.assign(include.begin(), include.end()); privateBits->includedColumnIndexes.clear(); + privateBits->idReadIntentMap.clear(); return *this; } @@ -189,6 +191,20 @@ namespace orc { privateBits->selection = ColumnSelection_TYPE_IDS; privateBits->includedColumnIndexes.assign(types.begin(), types.end()); privateBits->includedColumnNames.clear(); + privateBits->idReadIntentMap.clear(); + return *this; + } + + RowReaderOptions& + RowReaderOptions::includeTypesWithIntents(const IdReadIntentMap& idReadIntentMap) { + privateBits->selection = ColumnSelection_TYPE_IDS; + privateBits->includedColumnIndexes.clear(); + privateBits->idReadIntentMap.clear(); + for (const auto& typeIntentPair : idReadIntentMap) { + privateBits->idReadIntentMap[typeIntentPair.first] = typeIntentPair.second; + privateBits->includedColumnIndexes.push_back(typeIntentPair.first); + } + privateBits->includedColumnNames.clear(); return *this; } @@ -253,6 +269,29 @@ namespace orc { privateBits->enableLazyDecoding = enable; return *this; } + + RowReaderOptions& RowReaderOptions::searchArgument(std::unique_ptr<SearchArgument> sargs) { + privateBits->sargs = std::move(sargs); + return *this; + } + + std::shared_ptr<SearchArgument> RowReaderOptions::getSearchArgument() const { + return privateBits->sargs; + } + + RowReaderOptions& RowReaderOptions::setTimezoneName(const std::string& zoneName) { + privateBits->readerTimezone = zoneName; + return *this; + } + + const std::string& RowReaderOptions::getTimezoneName() const { + return privateBits->readerTimezone; + } + + const RowReaderOptions::IdReadIntentMap + RowReaderOptions::getIdReadIntentMap() const { + return privateBits->idReadIntentMap; + } } #endif diff --git a/contrib/libs/apache/orc/c++/src/RLEv2.hh b/contrib/libs/apache/orc/c++/src/RLEv2.hh index f85dabd9e6e..b1e68fb125e 100644 --- a/contrib/libs/apache/orc/c++/src/RLEv2.hh +++ b/contrib/libs/apache/orc/c++/src/RLEv2.hh @@ -25,6 +25,7 @@ #include <vector> +#define MAX_LITERAL_SIZE 512 #define MIN_REPEAT 3 #define HIST_LEN 32 namespace orc { @@ -93,6 +94,7 @@ private: int64_t* adjDeltas; uint32_t getOpCode(EncodingType encoding); + int64_t* prepareForDirectOrPatchedBase(EncodingOption& option); void determineEncoding(EncodingOption& option); void computeZigZagLiterals(EncodingOption& option); void preparePatchedBlob(EncodingOption& option); @@ -130,25 +132,18 @@ public: private: - // Used by PATCHED_BASE - void adjustGapAndPatch() { - curGap = static_cast<uint64_t>(unpackedPatch[patchIdx]) >> - patchBitSize; - curPatch = unpackedPatch[patchIdx] & patchMask; - actualGap = 0; - - // special case: gap is >255 then patch value will be 0. - // if gap is <=255 then patch value cannot be 0 - while (curGap == 255 && curPatch == 0) { - actualGap += 255; - ++patchIdx; - curGap = static_cast<uint64_t>(unpackedPatch[patchIdx]) >> - patchBitSize; - curPatch = unpackedPatch[patchIdx] & patchMask; - } - // add the left over gap - actualGap += curGap; - } + /** + * Decode the next gap and patch from 'unpackedPatch' and update the index on it. + * Used by PATCHED_BASE. + * + * @param patchBitSize bit size of the patch value + * @param patchMask mask for the patch value + * @param resGap result of gap + * @param resPatch result of patch + * @param patchIdx current index in the 'unpackedPatch' buffer + */ + void adjustGapAndPatch(uint32_t patchBitSize, int64_t patchMask, + int64_t* resGap, int64_t* resPatch, uint64_t* patchIdx); void resetReadLongs() { bitsLeft = 0; @@ -157,59 +152,25 @@ private: void resetRun() { resetReadLongs(); - bitSize = 0; } - unsigned char readByte() { - if (bufferStart == bufferEnd) { - int bufferLength; - const void* bufferPointer; - if (!inputStream->Next(&bufferPointer, &bufferLength)) { - throw ParseError("bad read in RleDecoderV2::readByte"); - } - bufferStart = static_cast<const char*>(bufferPointer); - bufferEnd = bufferStart + bufferLength; - } - - unsigned char result = static_cast<unsigned char>(*bufferStart++); - return result; -} + unsigned char readByte(); int64_t readLongBE(uint64_t bsz); int64_t readVslong(); uint64_t readVulong(); - uint64_t readLongs(int64_t *data, uint64_t offset, uint64_t len, - uint64_t fb, const char* notNull = nullptr) { - uint64_t ret = 0; - - // TODO: unroll to improve performance - for(uint64_t i = offset; i < (offset + len); i++) { - // skip null positions - if (notNull && !notNull[i]) { - continue; - } - uint64_t result = 0; - uint64_t bitsLeftToRead = fb; - while (bitsLeftToRead > bitsLeft) { - result <<= bitsLeft; - result |= curByte & ((1 << bitsLeft) - 1); - bitsLeftToRead -= bitsLeft; - curByte = readByte(); - bitsLeft = 8; - } - - // handle the left over bits - if (bitsLeftToRead > 0) { - result <<= bitsLeftToRead; - bitsLeft -= static_cast<uint32_t>(bitsLeftToRead); - result |= (curByte >> bitsLeft) & ((1 << bitsLeftToRead) - 1); - } - data[i] = static_cast<int64_t>(result); - ++ret; - } - - return ret; -} + void readLongs(int64_t *data, uint64_t offset, uint64_t len, uint64_t fbs); + void plainUnpackLongs(int64_t *data, uint64_t offset, uint64_t len, uint64_t fbs); + + void unrolledUnpack4(int64_t *data, uint64_t offset, uint64_t len); + void unrolledUnpack8(int64_t *data, uint64_t offset, uint64_t len); + void unrolledUnpack16(int64_t *data, uint64_t offset, uint64_t len); + void unrolledUnpack24(int64_t *data, uint64_t offset, uint64_t len); + void unrolledUnpack32(int64_t *data, uint64_t offset, uint64_t len); + void unrolledUnpack40(int64_t *data, uint64_t offset, uint64_t len); + void unrolledUnpack48(int64_t *data, uint64_t offset, uint64_t len); + void unrolledUnpack56(int64_t *data, uint64_t offset, uint64_t len); + void unrolledUnpack64(int64_t *data, uint64_t offset, uint64_t len); uint64_t nextShortRepeats(int64_t* data, uint64_t offset, uint64_t numValues, const char* notNull); @@ -220,31 +181,21 @@ private: uint64_t nextDelta(int64_t* data, uint64_t offset, uint64_t numValues, const char* notNull); + uint64_t copyDataFromBuffer(int64_t* data, uint64_t offset, uint64_t numValues, + const char* notNull); + const std::unique_ptr<SeekableInputStream> inputStream; const bool isSigned; unsigned char firstByte; - uint64_t runLength; - uint64_t runRead; + uint64_t runLength; // Length of the current run + uint64_t runRead; // Number of returned values of the current run const char *bufferStart; const char *bufferEnd; - int64_t deltaBase; // Used by DELTA - uint64_t byteSize; // Used by SHORT_REPEAT and PATCHED_BASE - int64_t firstValue; // Used by SHORT_REPEAT and DELTA - int64_t prevValue; // Used by DELTA - uint32_t bitSize; // Used by DIRECT, PATCHED_BASE and DELTA - uint32_t bitsLeft; // Used by anything that uses readLongs + uint32_t bitsLeft; // Used by readLongs when bitSize < 8 uint32_t curByte; // Used by anything that uses readLongs - uint32_t patchBitSize; // Used by PATCHED_BASE - uint64_t unpackedIdx; // Used by PATCHED_BASE - uint64_t patchIdx; // Used by PATCHED_BASE - int64_t base; // Used by PATCHED_BASE - uint64_t curGap; // Used by PATCHED_BASE - int64_t curPatch; // Used by PATCHED_BASE - int64_t patchMask; // Used by PATCHED_BASE - int64_t actualGap; // Used by PATCHED_BASE - DataBuffer<int64_t> unpacked; // Used by PATCHED_BASE DataBuffer<int64_t> unpackedPatch; // Used by PATCHED_BASE + DataBuffer<int64_t> literals; // Values of the current run }; } // namespace orc diff --git a/contrib/libs/apache/orc/c++/src/Reader.cc b/contrib/libs/apache/orc/c++/src/Reader.cc index f35106ee44f..6a9068f2022 100644 --- a/contrib/libs/apache/orc/c++/src/Reader.cc +++ b/contrib/libs/apache/orc/c++/src/Reader.cc @@ -35,6 +35,15 @@ #include <set> namespace orc { + // ORC files writen by these versions of cpp writers have inconsistent bloom filter + // hashing. Bloom filters of them should not be used. + static const char* BAD_CPP_BLOOM_FILTER_VERSIONS[] = { + "1.6.0", "1.6.1", "1.6.2", "1.6.3", "1.6.4", "1.6.5", "1.6.6", "1.6.7", "1.6.8", + "1.6.9", "1.6.10", "1.6.11", "1.7.0"}; + + const RowReaderOptions::IdReadIntentMap EMPTY_IDREADINTENTMAP() { + return {}; + } const WriterVersionImpl &WriterVersionImpl::VERSION_HIVE_8732() { static const WriterVersionImpl version(WriterVersion_HIVE_8732); @@ -68,13 +77,38 @@ namespace orc { return columnPath.substr(0, columnPath.length() - 1); } + WriterVersion getWriterVersionImpl(const FileContents * contents) { + if (!contents->postscript->has_writerversion()) { + return WriterVersion_ORIGINAL; + } + return static_cast<WriterVersion>(contents->postscript->writerversion()); + } void ColumnSelector::selectChildren(std::vector<bool>& selectedColumns, const Type& type) { + return selectChildren(selectedColumns, type, EMPTY_IDREADINTENTMAP()); + } + + void ColumnSelector::selectChildren( + std::vector<bool> &selectedColumns, + const Type &type, + const RowReaderOptions::IdReadIntentMap& idReadIntentMap) { size_t id = static_cast<size_t>(type.getColumnId()); + TypeKind kind = type.getKind(); if (!selectedColumns[id]) { selectedColumns[id] = true; - for(size_t c = id; c <= type.getMaximumColumnId(); ++c){ - selectedColumns[c] = true; + bool selectChild = true; + if (kind == TypeKind::LIST || kind == TypeKind::MAP || kind == TypeKind::UNION) { + auto elem = idReadIntentMap.find(id); + if (elem != idReadIntentMap.end() && + elem->second == ReadIntent_OFFSETS) { + selectChild = false; + } + } + + if (selectChild) { + for (size_t c = id; c <= type.getMaximumColumnId(); ++c) { + selectedColumns[c] = true; + } } } } @@ -86,10 +120,24 @@ namespace orc { bool ColumnSelector::selectParents(std::vector<bool>& selectedColumns, const Type& type) { size_t id = static_cast<size_t>(type.getColumnId()); bool result = selectedColumns[id]; + uint64_t numSubtypeSelected = 0; for(uint64_t c=0; c < type.getSubtypeCount(); ++c) { - result |= selectParents(selectedColumns, *type.getSubtype(c)); + if (selectParents(selectedColumns, *type.getSubtype(c))) { + result = true; + numSubtypeSelected++; + } } selectedColumns[id] = result; + + if (type.getKind() == TypeKind::UNION && selectedColumns[id]) { + if (0 < numSubtypeSelected && numSubtypeSelected < type.getSubtypeCount()) { + // Subtypes of UNION should be fully selected or not selected at all. + // Override partial subtype selections with full selections. + for (uint64_t c = 0; c < type.getSubtypeCount(); ++c) { + selectChildren(selectedColumns, *type.getSubtype(c)); + } + } + } return result; } @@ -131,9 +179,11 @@ namespace orc { updateSelectedByName(selectedColumns, *field); } } else if (options.getTypeIdsSet()) { + const RowReaderOptions::IdReadIntentMap idReadIntentMap = + options.getIdReadIntentMap(); for(std::list<uint64_t>::const_iterator typeId = options.getInclude().begin(); typeId != options.getInclude().end(); ++typeId) { - updateSelectedByTypeId(selectedColumns, *typeId); + updateSelectedByTypeId(selectedColumns, *typeId, idReadIntentMap); } } else { // default is to select all columns @@ -156,9 +206,16 @@ namespace orc { } void ColumnSelector::updateSelectedByTypeId(std::vector<bool>& selectedColumns, uint64_t typeId) { + updateSelectedByTypeId(selectedColumns, typeId, EMPTY_IDREADINTENTMAP()); + } + + void ColumnSelector::updateSelectedByTypeId( + std::vector<bool> &selectedColumns, + uint64_t typeId, + const RowReaderOptions::IdReadIntentMap& idReadIntentMap) { if (typeId < selectedColumns.size()) { const Type& type = *idTypeMap[typeId]; - selectChildren(selectedColumns, type); + selectChildren(selectedColumns, type, idReadIntentMap); } else { std::stringstream buffer; buffer << "Invalid type id selected " << typeId << " out of " @@ -173,7 +230,15 @@ namespace orc { if (ite != nameIdMap.end()) { updateSelectedByTypeId(selectedColumns, ite->second); } else { - throw ParseError("Invalid column selected " + fieldName); + bool first = true; + std::ostringstream ss; + ss << "Invalid column selected " << fieldName << ". Valid names are "; + for (auto it = nameIdMap.begin(); it != nameIdMap.end(); ++it) { + if (!first) ss << ", "; + ss << it->first; + first = false; + } + throw ParseError(ss.str()); } } @@ -189,7 +254,8 @@ namespace orc { forcedScaleOnHive11Decimal(opts.getForcedScaleOnHive11Decimal()), footer(contents->footer.get()), firstRowOfStripe(*contents->pool, 0), - enableEncodedBlock(opts.getEnableLazyDecoding()) { + enableEncodedBlock(opts.getEnableLazyDecoding()), + readerTimezone(getTimezoneByName(opts.getTimezoneName())) { uint64_t numberOfStripes; numberOfStripes = static_cast<uint64_t>(footer->stripes_size()); currentStripe = numberOfStripes; @@ -227,6 +293,43 @@ namespace orc { ColumnSelector column_selector(contents.get()); column_selector.updateSelected(selectedColumns, opts); + + // prepare SargsApplier if SearchArgument is available + if (opts.getSearchArgument() && footer->rowindexstride() > 0) { + sargs = opts.getSearchArgument(); + sargsApplier.reset(new SargsApplier(*contents->schema, + sargs.get(), + footer->rowindexstride(), + getWriterVersionImpl(_contents.get()))); + } + + skipBloomFilters = hasBadBloomFilters(); + } + + // Check if the file has inconsistent bloom filters. + bool RowReaderImpl::hasBadBloomFilters() { + // Only C++ writer in old releases could have bad bloom filters. + if (footer->writer() != ORC_CPP_WRITER) return false; + // 'softwareVersion' is added in 1.5.13, 1.6.11, and 1.7.0. + // 1.6.x releases before 1.6.11 won't have it. On the other side, the C++ writer + // supports writing bloom filters since 1.6.0. So files written by the C++ writer + // and with 'softwareVersion' unset would have bad bloom filters. + if (!footer->has_softwareversion()) return true; + + const std::string &fullVersion = footer->softwareversion(); + std::string version; + // Deal with snapshot versions, e.g. 1.6.12-SNAPSHOT. + if (fullVersion.find('-') != std::string::npos) { + version = fullVersion.substr(0, fullVersion.find('-')); + } else { + version = fullVersion; + } + for (const char *v : BAD_CPP_BLOOM_FILTER_VERSIONS) { + if (version == v) { + return true; + } + } + return false; } CompressionKind RowReaderImpl::getCompression() const { @@ -294,24 +397,35 @@ namespace orc { startNextStripe(); uint64_t rowsToSkip = currentRowInStripe; - - if (footer->rowindexstride() > 0 && - currentStripeInfo.indexlength() > 0) { - uint32_t rowGroupId = - static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride()); - rowsToSkip -= rowGroupId * footer->rowindexstride(); - - if (rowGroupId != 0) { - seekToRowGroup(rowGroupId); + auto rowIndexStride = footer->rowindexstride(); + // seek to the target row group if row indexes exists + if (rowIndexStride > 0 && currentStripeInfo.indexlength() > 0) { + // when predicate push down is enabled, above call to startNextStripe() + // will move current row to 1st matching row group; here we only need + // to deal with the case when PPD is not enabled. + if (!sargsApplier) { + if (rowIndexes.empty()) { + loadStripeIndex(); + } + auto rowGroupId = static_cast<uint32_t>(rowsToSkip / rowIndexStride); + if (rowGroupId != 0) { + seekToRowGroup(rowGroupId); + } } + // skip leading rows in the target row group + rowsToSkip %= rowIndexStride; + } + // 'reader' is reset in startNextStripe(). It could be nullptr if 'rowsToSkip' is 0, + // e.g. when startNextStripe() skips all remaining rows of the file. + if (rowsToSkip > 0) { + reader->skip(rowsToSkip); } - - reader->skip(rowsToSkip); } - void RowReaderImpl::seekToRowGroup(uint32_t rowGroupEntryId) { + void RowReaderImpl::loadStripeIndex() { // reset all previous row indexes rowIndexes.clear(); + bloomFilterIndex.clear(); // obtain row indexes for selected columns uint64_t offset = currentStripeInfo.offset(); @@ -319,7 +433,8 @@ namespace orc { const proto::Stream& pbStream = currentStripeFooter.streams(i); uint64_t colId = pbStream.column(); if (selectedColumns[colId] && pbStream.has_kind() - && pbStream.kind() == proto::Stream_Kind_ROW_INDEX) { + && (pbStream.kind() == proto::Stream_Kind_ROW_INDEX || + pbStream.kind() == proto::Stream_Kind_BLOOM_FILTER_UTF8)) { std::unique_ptr<SeekableInputStream> inStream = createDecompressor(getCompression(), std::unique_ptr<SeekableInputStream> @@ -331,18 +446,35 @@ namespace orc { getCompressionSize(), *contents->pool); - proto::RowIndex rowIndex; - if (!rowIndex.ParseFromZeroCopyStream(inStream.get())) { - throw ParseError("Failed to parse the row index"); + if (pbStream.kind() == proto::Stream_Kind_ROW_INDEX) { + proto::RowIndex rowIndex; + if (!rowIndex.ParseFromZeroCopyStream(inStream.get())) { + throw ParseError("Failed to parse the row index"); + } + rowIndexes[colId] = rowIndex; + } else if (!skipBloomFilters) { // Stream_Kind_BLOOM_FILTER_UTF8 + proto::BloomFilterIndex pbBFIndex; + if (!pbBFIndex.ParseFromZeroCopyStream(inStream.get())) { + throw ParseError("Failed to parse bloom filter index"); + } + BloomFilterIndex bfIndex; + for (int j = 0; j < pbBFIndex.bloomfilter_size(); j++) { + bfIndex.entries.push_back(BloomFilterUTF8Utils::deserialize( + pbStream.kind(), + currentStripeFooter.columns(static_cast<int>(pbStream.column())), + pbBFIndex.bloomfilter(j))); + } + // add bloom filters to result for one column + bloomFilterIndex[pbStream.column()] = bfIndex; } - - rowIndexes[colId] = rowIndex; } offset += pbStream.length(); } + } + void RowReaderImpl::seekToRowGroup(uint32_t rowGroupEntryId) { // store positions for selected columns - std::vector<std::list<uint64_t>> positions; + std::list<std::list<uint64_t>> positions; // store position providers for selected colimns std::unordered_map<uint64_t, PositionProvider> positionProviders; @@ -372,6 +504,10 @@ namespace orc { return throwOnHive11DecimalOverflow; } + bool RowReaderImpl::getIsDecimalAsLong() const { + return contents->isDecimalAsLong; + } + int32_t RowReaderImpl::getForcedScaleOnHive11Decimal() const { return forcedScaleOnHive11Decimal; } @@ -395,6 +531,13 @@ namespace orc { throw ParseError(std::string("bad StripeFooter from ") + pbStream->getName()); } + // Verify StripeFooter in case it's corrupt + if (result.columns_size() != contents.footer->types_size()) { + std::stringstream msg; + msg << "bad number of ColumnEncodings in StripeFooter: expected=" + << contents.footer->types_size() << ", actual=" << result.columns_size(); + throw ParseError(msg.str()); + } return result; } @@ -450,8 +593,8 @@ namespace orc { if (!isMetadataLoaded) { readMetadata(); } - return metadata.get() == nullptr ? 0 : - static_cast<uint64_t>(metadata->stripestats_size()); + return contents->metadata == nullptr ? 0 : + static_cast<uint64_t>(contents->metadata->stripestats_size()); } std::unique_ptr<StripeInformation> @@ -479,9 +622,7 @@ namespace orc { if (contents->postscript->version_size() != 2) { return FileVersion::v_0_11(); } - return FileVersion( - contents->postscript->version(0), - contents->postscript->version(1)); + return {contents->postscript->version(0), contents->postscript->version(1)}; } uint64_t ReaderImpl::getNumberOfRows() const { @@ -518,10 +659,7 @@ namespace orc { } WriterVersion ReaderImpl::getWriterVersion() const { - if (!contents->postscript->has_writerversion()) { - return WriterVersion_ORIGINAL; - } - return static_cast<WriterVersion>(contents->postscript->writerversion()); + return getWriterVersionImpl(contents.get()); } uint64_t ReaderImpl::getContentLength() const { @@ -631,11 +769,11 @@ namespace orc { if (!isMetadataLoaded) { readMetadata(); } - if (metadata.get() == nullptr) { + if (contents->metadata == nullptr) { throw std::logic_error("No stripe statistics in file"); } size_t num_cols = static_cast<size_t>( - metadata->stripestats( + contents->metadata->stripestats( static_cast<int>(stripeIndex)).colstats_size()); std::vector<std::vector<proto::ColumnStatistics> > indexStats(num_cols); @@ -652,7 +790,7 @@ namespace orc { getLocalTimezone(); StatContext statContext(hasCorrectStatistics(), &writerTZ); return std::unique_ptr<StripeStatistics> - (new StripeStatisticsImpl(metadata->stripestats(static_cast<int>(stripeIndex)), + (new StripeStatisticsImpl(contents->metadata->stripestats(static_cast<int>(stripeIndex)), indexStats, statContext)); } @@ -695,8 +833,8 @@ namespace orc { *contents->pool)), contents->blockSize, *contents->pool); - metadata.reset(new proto::Metadata()); - if (!metadata->ParseFromZeroCopyStream(pbStream.get())) { + contents->metadata.reset(new proto::Metadata()); + if (!contents->metadata->ParseFromZeroCopyStream(pbStream.get())) { throw ParseError("Failed to parse the metadata"); } } @@ -724,6 +862,10 @@ namespace orc { std::unique_ptr<RowReader> ReaderImpl::createRowReader( const RowReaderOptions& opts) const { + if (opts.getSearchArgument() && !isMetadataLoaded) { + // load stripe statistics for PPD + readMetadata(); + } return std::unique_ptr<RowReader>(new RowReaderImpl(contents, opts)); } @@ -746,6 +888,7 @@ namespace orc { case proto::Type_Kind_BINARY: case proto::Type_Kind_DECIMAL: case proto::Type_Kind_TIMESTAMP: + case proto::Type_Kind_TIMESTAMP_INSTANT: return 3; case proto::Type_Kind_CHAR: case proto::Type_Kind_STRING: @@ -892,42 +1035,113 @@ namespace orc { return memory + decompressorMemory ; } + // Update fields to indicate we've reached the end of file + void RowReaderImpl::markEndOfFile() { + currentStripe = lastStripe; + currentRowInStripe = 0; + rowsInCurrentStripe = 0; + if (lastStripe == 0) { + // Empty file + previousRow = 0; + } else { + previousRow = firstRowOfStripe[lastStripe - 1] + + footer->stripes(static_cast<int>(lastStripe - 1)).numberofrows(); + } + } + void RowReaderImpl::startNextStripe() { reader.reset(); // ColumnReaders use lots of memory; free old memory first - currentStripeInfo = footer->stripes(static_cast<int>(currentStripe)); - uint64_t fileLength = contents->stream->getLength(); - if (currentStripeInfo.offset() + currentStripeInfo.indexlength() + + rowIndexes.clear(); + bloomFilterIndex.clear(); + + // evaluate file statistics if it exists + if (sargsApplier && !sargsApplier->evaluateFileStatistics(*footer)) { + // skip the entire file + markEndOfFile(); + return; + } + + do { + currentStripeInfo = footer->stripes(static_cast<int>(currentStripe)); + uint64_t fileLength = contents->stream->getLength(); + if (currentStripeInfo.offset() + currentStripeInfo.indexlength() + currentStripeInfo.datalength() + currentStripeInfo.footerlength() >= fileLength) { - std::stringstream msg; - msg << "Malformed StripeInformation at stripe index " << currentStripe << ": fileLength=" - << fileLength << ", StripeInfo=(offset=" << currentStripeInfo.offset() << ", indexLength=" - << currentStripeInfo.indexlength() << ", dataLength=" << currentStripeInfo.datalength() - << ", footerLength=" << currentStripeInfo.footerlength() << ")"; - throw ParseError(msg.str()); + std::stringstream msg; + msg << "Malformed StripeInformation at stripe index " << currentStripe << ": fileLength=" + << fileLength << ", StripeInfo=(offset=" << currentStripeInfo.offset() << ", indexLength=" + << currentStripeInfo.indexlength() << ", dataLength=" << currentStripeInfo.datalength() + << ", footerLength=" << currentStripeInfo.footerlength() << ")"; + throw ParseError(msg.str()); + } + currentStripeFooter = getStripeFooter(currentStripeInfo, *contents.get()); + rowsInCurrentStripe = currentStripeInfo.numberofrows(); + + if (sargsApplier) { + bool isStripeNeeded = true; + if (contents->metadata) { + const auto& currentStripeStats = + contents->metadata->stripestats(static_cast<int>(currentStripe)); + // skip this stripe after stats fail to satisfy sargs + isStripeNeeded = sargsApplier->evaluateStripeStatistics(currentStripeStats); + } + + if (isStripeNeeded) { + // read row group statistics and bloom filters of current stripe + loadStripeIndex(); + + // select row groups to read in the current stripe + sargsApplier->pickRowGroups(rowsInCurrentStripe, + rowIndexes, + bloomFilterIndex); + if (sargsApplier->hasSelectedFrom(currentRowInStripe)) { + // current stripe has at least one row group matching the predicate + break; + } + isStripeNeeded = false; + } + if (!isStripeNeeded) { + // advance to next stripe when current stripe has no matching rows + currentStripe += 1; + currentRowInStripe = 0; + } + } + } while (sargsApplier && currentStripe < lastStripe); + + if (currentStripe < lastStripe) { + // get writer timezone info from stripe footer to help understand timestamp values. + const Timezone& writerTimezone = + currentStripeFooter.has_writertimezone() ? + getTimezoneByName(currentStripeFooter.writertimezone()) : + localTimezone; + StripeStreamsImpl stripeStreams(*this, currentStripe, currentStripeInfo, + currentStripeFooter, + currentStripeInfo.offset(), + *contents->stream, + writerTimezone, + readerTimezone); + reader = buildReader(*contents->schema, stripeStreams); + + if (sargsApplier) { + // move to the 1st selected row group when PPD is enabled. + currentRowInStripe = advanceToNextRowGroup(currentRowInStripe, + rowsInCurrentStripe, + footer->rowindexstride(), + sargsApplier->getNextSkippedRows()); + previousRow = firstRowOfStripe[currentStripe] + currentRowInStripe - 1; + if (currentRowInStripe > 0) { + seekToRowGroup(static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride())); + } + } + } else { + // All remaining stripes are skipped. + markEndOfFile(); } - currentStripeFooter = getStripeFooter(currentStripeInfo, *contents.get()); - rowsInCurrentStripe = currentStripeInfo.numberofrows(); - const Timezone& writerTimezone = - currentStripeFooter.has_writertimezone() ? - getTimezoneByName(currentStripeFooter.writertimezone()) : - localTimezone; - StripeStreamsImpl stripeStreams(*this, currentStripe, currentStripeInfo, - currentStripeFooter, - currentStripeInfo.offset(), - *(contents->stream.get()), - writerTimezone); - reader = buildReader(*contents->schema.get(), stripeStreams); } bool RowReaderImpl::next(ColumnVectorBatch& data) { if (currentStripe >= lastStripe) { data.numElements = 0; - if (lastStripe > 0) { - previousRow = firstRowOfStripe[lastStripe - 1] + - footer->stripes(static_cast<int>(lastStripe - 1)).numberofrows(); - } else { - previousRow = 0; - } + markEndOfFile(); return false; } if (currentRowInStripe == 0) { @@ -936,7 +1150,18 @@ namespace orc { uint64_t rowsToRead = std::min(static_cast<uint64_t>(data.capacity), rowsInCurrentStripe - currentRowInStripe); + if (sargsApplier && rowsToRead > 0) { + rowsToRead = computeBatchSize(rowsToRead, + currentRowInStripe, + rowsInCurrentStripe, + footer->rowindexstride(), + sargsApplier->getNextSkippedRows()); + } data.numElements = rowsToRead; + if (rowsToRead == 0) { + markEndOfFile(); + return false; + } if (enableEncodedBlock) { reader->nextEncoded(data, rowsToRead, nullptr); } @@ -946,6 +1171,22 @@ namespace orc { // update row number previousRow = firstRowOfStripe[currentStripe] + currentRowInStripe; currentRowInStripe += rowsToRead; + + // check if we need to advance to next selected row group + if (sargsApplier) { + uint64_t nextRowToRead = advanceToNextRowGroup(currentRowInStripe, + rowsInCurrentStripe, + footer->rowindexstride(), + sargsApplier->getNextSkippedRows()); + if (currentRowInStripe != nextRowToRead) { + // it is guaranteed to be at start of a row group + currentRowInStripe = nextRowToRead; + if (currentRowInStripe < rowsInCurrentStripe) { + seekToRowGroup(static_cast<uint32_t>(currentRowInStripe / footer->rowindexstride())); + } + } + } + if (currentRowInStripe >= rowsInCurrentStripe) { currentStripe += 1; currentRowInStripe = 0; @@ -953,6 +1194,52 @@ namespace orc { return rowsToRead != 0; } + uint64_t RowReaderImpl::computeBatchSize(uint64_t requestedSize, + uint64_t currentRowInStripe, + uint64_t rowsInCurrentStripe, + uint64_t rowIndexStride, + const std::vector<uint64_t>& nextSkippedRows) { + // In case of PPD, batch size should be aware of row group boundaries. If only a subset of row + // groups are selected then marker position is set to the end of range (subset of row groups + // within stripe). + uint64_t endRowInStripe = rowsInCurrentStripe; + uint64_t groupsInStripe = nextSkippedRows.size(); + if (groupsInStripe > 0) { + auto rg = static_cast<uint32_t>(currentRowInStripe / rowIndexStride); + if (rg >= groupsInStripe) return 0; + uint64_t nextSkippedRow = nextSkippedRows[rg]; + if (nextSkippedRow == 0) return 0; + endRowInStripe = nextSkippedRow; + } + return std::min(requestedSize, endRowInStripe - currentRowInStripe); + } + + uint64_t RowReaderImpl::advanceToNextRowGroup(uint64_t currentRowInStripe, + uint64_t rowsInCurrentStripe, + uint64_t rowIndexStride, + const std::vector<uint64_t>& nextSkippedRows) { + auto groupsInStripe = nextSkippedRows.size(); + if (groupsInStripe == 0) { + // No PPD, keeps using the current row in stripe + return std::min(currentRowInStripe, rowsInCurrentStripe); + } + auto rg = static_cast<uint32_t>(currentRowInStripe / rowIndexStride); + if (rg >= groupsInStripe) { + // Points to the end of the stripe + return rowsInCurrentStripe; + } + if (nextSkippedRows[rg] != 0) { + // Current row group is selected + return currentRowInStripe; + } + // Advance to the next selected row group + while (rg < groupsInStripe && nextSkippedRows[rg] == 0) ++rg; + if (rg < groupsInStripe) { + return rg * rowIndexStride; + } + return rowsInCurrentStripe; + } + std::unique_ptr<ColumnVectorBatch> RowReaderImpl::createRowBatch (uint64_t capacity) const { return getSelectedType().createRowBatch(capacity, *contents->pool, enableEncodedBlock); @@ -1017,10 +1304,11 @@ namespace orc { } /** - * Check that indices in the type tree are valid, so we won't crash - * when we convert the proto::Types to TypeImpls. + * Check that proto Types are valid. Indices in the type tree should be valid, + * so we won't crash when we convert the proto::Types to TypeImpls (ORC-317). + * For STRUCT types, fieldName size should match subTypes size (ORC-581). */ - void checkProtoTypeIds(const proto::Footer &footer) { + void checkProtoTypes(const proto::Footer &footer) { std::stringstream msg; int maxId = footer.types_size(); if (maxId <= 0) { @@ -1028,6 +1316,12 @@ namespace orc { } for (int i = 0; i < maxId; ++i) { const proto::Type& type = footer.types(i); + if (type.kind() == proto::Type_Kind_STRUCT + && type.subtypes_size() != type.fieldnames_size()) { + msg << "Footer is corrupt: STRUCT type " << i << " has " << type.subtypes_size() + << " subTypes, but has " << type.fieldnames_size() << " fieldNames"; + throw ParseError(msg.str()); + } for (int j = 0; j < type.subtypes_size(); ++j) { int subTypeId = static_cast<int>(type.subtypes(j)); if (subTypeId <= i) { @@ -1079,7 +1373,7 @@ namespace orc { stream->getName()); } - checkProtoTypeIds(*footer); + checkProtoTypes(*footer); return REDUNDANT_MOVE(footer); } @@ -1137,6 +1431,13 @@ namespace orc { contents->footer = REDUNDANT_MOVE(readFooter(stream.get(), buffer.get(), footerOffset, *contents->postscript, *contents->pool)); } + contents->isDecimalAsLong = false; + if (contents->postscript->version_size() == 2) { + FileVersion v(contents->postscript->version(0), contents->postscript->version(1)); + if (v == FileVersion::UNSTABLE_PRE_2_0()) { + contents->isDecimalAsLong = true; + } + } contents->stream = std::move(stream); return std::unique_ptr<Reader>(new ReaderImpl(std::move(contents), options, diff --git a/contrib/libs/apache/orc/c++/src/Reader.hh b/contrib/libs/apache/orc/c++/src/Reader.hh index 49e9d033d9f..ffaff4176e3 100644 --- a/contrib/libs/apache/orc/c++/src/Reader.hh +++ b/contrib/libs/apache/orc/c++/src/Reader.hh @@ -19,13 +19,14 @@ #ifndef ORC_READER_IMPL_HH #define ORC_READER_IMPL_HH +#include "orc/Exceptions.hh" #include "orc/Int128.hh" #include "orc/OrcFile.hh" #include "orc/Reader.hh" #include "ColumnReader.hh" -#include "orc/Exceptions.hh" #include "RLE.hh" +#include "sargs/SargsApplier.hh" #include "TypeImpl.hh" namespace orc { @@ -62,12 +63,17 @@ namespace orc { CompressionKind compression; MemoryPool *pool; std::ostream *errorStream; + /// Decimal64 in ORCv2 uses RLE to store values. This flag indicates whether + /// this new encoding is used. + bool isDecimalAsLong; + std::unique_ptr<proto::Metadata> metadata; }; proto::StripeFooter getStripeFooter(const proto::StripeInformation& info, const FileContents& contents); class ReaderImpl; + class Timezone; class ColumnSelector { private: @@ -87,13 +93,22 @@ namespace orc { void updateSelectedByFieldId(std::vector<bool>& selectedColumns, uint64_t fieldId); // Select a type by id void updateSelectedByTypeId(std::vector<bool>& selectedColumns, uint64_t typeId); + // Select a type by id and read intent map. + void updateSelectedByTypeId(std::vector<bool>& selectedColumns, uint64_t typeId, + const RowReaderOptions::IdReadIntentMap& idReadIntentMap); // Select all of the recursive children of the given type. void selectChildren(std::vector<bool>& selectedColumns, const Type& type); + // Select a type id of the given type. + // This function may also select all of the recursive children of the given type + // depending on the read intent of that type in idReadIntentMap. + void selectChildren(std::vector<bool>& selectedColumns, const Type& type, + const RowReaderOptions::IdReadIntentMap& idReadIntentMap); // For each child of type, select it if one of its children // is selected. bool selectParents(std::vector<bool>& selectedColumns, const Type& type); + /** * Constructor that selects columns. * @param contents of the file @@ -140,9 +155,37 @@ namespace orc { bool enableEncodedBlock; // internal methods void startNextStripe(); + inline void markEndOfFile(); // row index of current stripe with column id as the key std::unordered_map<uint64_t, proto::RowIndex> rowIndexes; + std::map<uint32_t, BloomFilterIndex> bloomFilterIndex; + std::shared_ptr<SearchArgument> sargs; + std::unique_ptr<SargsApplier> sargsApplier; + + // desired timezone to return data of timestamp types. + const Timezone& readerTimezone; + + // load stripe index if not done so + void loadStripeIndex(); + + // In case of PPD, batch size should be aware of row group boundaries. + // If only a subset of row groups are selected then the next read should + // stop at the end of selected range. + static uint64_t computeBatchSize(uint64_t requestedSize, + uint64_t currentRowInStripe, + uint64_t rowsInCurrentStripe, + uint64_t rowIndexStride, + const std::vector<uint64_t>& nextSkippedRows); + + // Skip non-selected rows + static uint64_t advanceToNextRowGroup(uint64_t currentRowInStripe, + uint64_t rowsInCurrentStripe, + uint64_t rowIndexStride, + const std::vector<uint64_t>& nextSkippedRows); + + friend class TestRowReader_advanceToNextRowGroup_Test; + friend class TestRowReader_computeBatchSize_Test; /** * Seek to the start of a row group in the current stripe @@ -167,7 +210,6 @@ namespace orc { const RowReaderOptions& options); // Select the columns from the options object - void updateSelected(); const std::vector<bool> getSelectedColumns() const override; const Type& getSelectedType() const override; @@ -187,6 +229,7 @@ namespace orc { const FileContents& getFileContents() const; bool getThrowOnHive11DecimalOverflow() const; + bool getIsDecimalAsLong() const; int32_t getForcedScaleOnHive11Decimal() const; }; @@ -213,7 +256,6 @@ namespace orc { std::vector<std::vector<proto::ColumnStatistics> >* indexStats) const; // metadata - mutable std::unique_ptr<proto::Metadata> metadata; mutable bool isMetadataLoaded; public: /** diff --git a/contrib/libs/apache/orc/c++/src/RleDecoderV2.cc b/contrib/libs/apache/orc/c++/src/RleDecoderV2.cc index c5c6f6a8017..8ab57b1f6e3 100644 --- a/contrib/libs/apache/orc/c++/src/RleDecoderV2.cc +++ b/contrib/libs/apache/orc/c++/src/RleDecoderV2.cc @@ -23,6 +23,21 @@ namespace orc { +unsigned char RleDecoderV2::readByte() { + if (bufferStart == bufferEnd) { + int bufferLength; + const void* bufferPointer; + if (!inputStream->Next(&bufferPointer, &bufferLength)) { + throw ParseError("bad read in RleDecoderV2::readByte"); + } + bufferStart = static_cast<const char*>(bufferPointer); + bufferEnd = bufferStart + bufferLength; + } + + unsigned char result = static_cast<unsigned char>(*bufferStart++); + return result; +} + int64_t RleDecoderV2::readLongBE(uint64_t bsz) { int64_t ret = 0, val; uint64_t n = bsz; @@ -49,6 +64,332 @@ uint64_t RleDecoderV2::readVulong() { return ret; } +void RleDecoderV2::readLongs(int64_t *data, uint64_t offset, uint64_t len, uint64_t fbs) { + switch (fbs) { + case 4: + unrolledUnpack4(data, offset, len); + return; + case 8: + unrolledUnpack8(data, offset, len); + return; + case 16: + unrolledUnpack16(data, offset, len); + return; + case 24: + unrolledUnpack24(data, offset, len); + return; + case 32: + unrolledUnpack32(data, offset, len); + return; + case 40: + unrolledUnpack40(data, offset, len); + return; + case 48: + unrolledUnpack48(data, offset, len); + return; + case 56: + unrolledUnpack56(data, offset, len); + return; + case 64: + unrolledUnpack64(data, offset, len); + return; + default: + // Fallback to the default implementation for deprecated bit size. + plainUnpackLongs(data, offset, len, fbs); + return; + } +} + +void RleDecoderV2::unrolledUnpack4(int64_t* data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Make sure bitsLeft is 0 before the loop. bitsLeft can only be 0, 4, or 8. + while (bitsLeft > 0 && curIdx < offset + len) { + bitsLeft -= 4; + data[curIdx++] = (curByte >> bitsLeft) & 15; + } + if (curIdx == offset + len) return; + + // Exhaust the buffer + uint64_t numGroups = (offset + len - curIdx) / 2; + numGroups = std::min(numGroups, static_cast<uint64_t>(bufferEnd - bufferStart)); + // Avoid updating 'bufferStart' inside the loop. + const auto *buffer = reinterpret_cast<const unsigned char*>(bufferStart); + uint32_t localByte; + for (uint64_t i = 0; i < numGroups; ++i) { + localByte = *buffer++; + data[curIdx] = (localByte >> 4) & 15; + data[curIdx + 1] = localByte & 15; + curIdx += 2; + } + bufferStart = reinterpret_cast<const char*>(buffer); + if (curIdx == offset + len) return; + + // readByte() will update 'bufferStart' and 'bufferEnd' + curByte = readByte(); + bitsLeft = 8; + } +} + +void RleDecoderV2::unrolledUnpack8(int64_t* data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Exhaust the buffer + int64_t bufferNum = bufferEnd - bufferStart; + bufferNum = std::min(bufferNum, static_cast<int64_t>(offset + len - curIdx)); + // Avoid updating 'bufferStart' inside the loop. + const auto* buffer = reinterpret_cast<const unsigned char*>(bufferStart); + for (int i = 0; i < bufferNum; ++i) { + data[curIdx++] = *buffer++; + } + bufferStart = reinterpret_cast<const char*>(buffer); + if (curIdx == offset + len) return; + + // readByte() will update 'bufferStart' and 'bufferEnd'. + data[curIdx++] = readByte(); + } +} + +void RleDecoderV2::unrolledUnpack16(int64_t* data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Exhaust the buffer + int64_t bufferNum = (bufferEnd - bufferStart) / 2; + bufferNum = std::min(bufferNum, static_cast<int64_t>(offset + len - curIdx)); + uint16_t b0, b1; + // Avoid updating 'bufferStart' inside the loop. + const auto* buffer = reinterpret_cast<const unsigned char*>(bufferStart); + for (int i = 0; i < bufferNum; ++i) { + b0 = static_cast<uint16_t>(*buffer); + b1 = static_cast<uint16_t>(*(buffer + 1)); + buffer += 2; + data[curIdx++] = (b0 << 8) | b1; + } + bufferStart = reinterpret_cast<const char*>(buffer); + if (curIdx == offset + len) return; + + // One of the following readByte() will update 'bufferStart' and 'bufferEnd'. + b0 = readByte(); + b1 = readByte(); + data[curIdx++] = (b0 << 8) | b1; + } +} + +void RleDecoderV2::unrolledUnpack24(int64_t* data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Exhaust the buffer + int64_t bufferNum = (bufferEnd - bufferStart) / 3; + bufferNum = std::min(bufferNum, static_cast<int64_t>(offset + len - curIdx)); + uint32_t b0, b1, b2; + // Avoid updating 'bufferStart' inside the loop. + const auto* buffer = reinterpret_cast<const unsigned char*>(bufferStart); + for (int i = 0; i < bufferNum; ++i) { + b0 = static_cast<uint32_t>(*buffer); + b1 = static_cast<uint32_t>(*(buffer + 1)); + b2 = static_cast<uint32_t>(*(buffer + 2)); + buffer += 3; + data[curIdx++] = static_cast<int64_t>((b0 << 16) | (b1 << 8) | b2); + } + bufferStart += bufferNum * 3; + if (curIdx == offset + len) return; + + // One of the following readByte() will update 'bufferStart' and 'bufferEnd'. + b0 = readByte(); + b1 = readByte(); + b2 = readByte(); + data[curIdx++] = static_cast<int64_t>((b0 << 16) | (b1 << 8) | b2); + } +} + +void RleDecoderV2::unrolledUnpack32(int64_t* data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Exhaust the buffer + int64_t bufferNum = (bufferEnd - bufferStart) / 4; + bufferNum = std::min(bufferNum, static_cast<int64_t>(offset + len - curIdx)); + uint32_t b0, b1, b2, b3; + // Avoid updating 'bufferStart' inside the loop. + const auto* buffer = reinterpret_cast<const unsigned char*>(bufferStart); + for (int i = 0; i < bufferNum; ++i) { + b0 = static_cast<uint32_t>(*buffer); + b1 = static_cast<uint32_t>(*(buffer + 1)); + b2 = static_cast<uint32_t>(*(buffer + 2)); + b3 = static_cast<uint32_t>(*(buffer + 3)); + buffer += 4; + data[curIdx++] = static_cast<int64_t>((b0 << 24) | (b1 << 16) | (b2 << 8) | b3); + } + bufferStart = reinterpret_cast<const char*>(buffer); + if (curIdx == offset + len) return; + + // One of the following readByte() will update 'bufferStart' and 'bufferEnd'. + b0 = readByte(); + b1 = readByte(); + b2 = readByte(); + b3 = readByte(); + data[curIdx++] = static_cast<int64_t>((b0 << 24) | (b1 << 16) | (b2 << 8) | b3); + } +} + +void RleDecoderV2::unrolledUnpack40(int64_t* data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Exhaust the buffer + int64_t bufferNum = (bufferEnd - bufferStart) / 5; + bufferNum = std::min(bufferNum, static_cast<int64_t>(offset + len - curIdx)); + uint64_t b0, b1, b2, b3, b4; + // Avoid updating 'bufferStart' inside the loop. + const auto* buffer = reinterpret_cast<const unsigned char*>(bufferStart); + for (int i = 0; i < bufferNum; ++i) { + b0 = static_cast<uint32_t>(*buffer); + b1 = static_cast<uint32_t>(*(buffer + 1)); + b2 = static_cast<uint32_t>(*(buffer + 2)); + b3 = static_cast<uint32_t>(*(buffer + 3)); + b4 = static_cast<uint32_t>(*(buffer + 4)); + buffer += 5; + data[curIdx++] = static_cast<int64_t>((b0 << 32) | (b1 << 24) | (b2 << 16) | (b3 << 8) | b4); + } + bufferStart = reinterpret_cast<const char*>(buffer); + if (curIdx == offset + len) return; + + // One of the following readByte() will update 'bufferStart' and 'bufferEnd'. + b0 = readByte(); + b1 = readByte(); + b2 = readByte(); + b3 = readByte(); + b4 = readByte(); + data[curIdx++] = static_cast<int64_t>((b0 << 32) | (b1 << 24) | (b2 << 16) | (b3 << 8) | b4); + } +} + +void RleDecoderV2::unrolledUnpack48(int64_t *data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Exhaust the buffer + int64_t bufferNum = (bufferEnd - bufferStart) / 6; + bufferNum = std::min(bufferNum, static_cast<int64_t>(offset + len - curIdx)); + uint64_t b0, b1, b2, b3, b4, b5; + // Avoid updating 'bufferStart' inside the loop. + const auto* buffer = reinterpret_cast<const unsigned char*>(bufferStart); + for (int i = 0; i < bufferNum; ++i) { + b0 = static_cast<uint32_t>(*buffer); + b1 = static_cast<uint32_t>(*(buffer + 1)); + b2 = static_cast<uint32_t>(*(buffer + 2)); + b3 = static_cast<uint32_t>(*(buffer + 3)); + b4 = static_cast<uint32_t>(*(buffer + 4)); + b5 = static_cast<uint32_t>(*(buffer + 5)); + buffer += 6; + data[curIdx++] = static_cast<int64_t>((b0 << 40) | (b1 << 32) | (b2 << 24) | (b3 << 16) | (b4 << 8) | b5); + } + bufferStart = reinterpret_cast<const char*>(buffer); + if (curIdx == offset + len) return; + + // One of the following readByte() will update 'bufferStart' and 'bufferEnd'. + b0 = readByte(); + b1 = readByte(); + b2 = readByte(); + b3 = readByte(); + b4 = readByte(); + b5 = readByte(); + data[curIdx++] = static_cast<int64_t>((b0 << 40) | (b1 << 32) | (b2 << 24) | (b3 << 16) | (b4 << 8) | b5); + } +} + +void RleDecoderV2::unrolledUnpack56(int64_t *data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Exhaust the buffer + int64_t bufferNum = (bufferEnd - bufferStart) / 7; + bufferNum = std::min(bufferNum, static_cast<int64_t>(offset + len - curIdx)); + uint64_t b0, b1, b2, b3, b4, b5, b6; + // Avoid updating 'bufferStart' inside the loop. + const auto* buffer = reinterpret_cast<const unsigned char*>(bufferStart); + for (int i = 0; i < bufferNum; ++i) { + b0 = static_cast<uint32_t>(*buffer); + b1 = static_cast<uint32_t>(*(buffer + 1)); + b2 = static_cast<uint32_t>(*(buffer + 2)); + b3 = static_cast<uint32_t>(*(buffer + 3)); + b4 = static_cast<uint32_t>(*(buffer + 4)); + b5 = static_cast<uint32_t>(*(buffer + 5)); + b6 = static_cast<uint32_t>(*(buffer + 6)); + buffer += 7; + data[curIdx++] = static_cast<int64_t>((b0 << 48) | (b1 << 40) | (b2 << 32) | (b3 << 24) | (b4 << 16) | (b5 << 8) | b6); + } + bufferStart = reinterpret_cast<const char*>(buffer); + if (curIdx == offset + len) return; + + // One of the following readByte() will update 'bufferStart' and 'bufferEnd'. + b0 = readByte(); + b1 = readByte(); + b2 = readByte(); + b3 = readByte(); + b4 = readByte(); + b5 = readByte(); + b6 = readByte(); + data[curIdx++] = static_cast<int64_t>((b0 << 48) | (b1 << 40) | (b2 << 32) | (b3 << 24) | (b4 << 16) | (b5 << 8) | b6); + } +} + +void RleDecoderV2::unrolledUnpack64(int64_t *data, uint64_t offset, uint64_t len) { + uint64_t curIdx = offset; + while (curIdx < offset + len) { + // Exhaust the buffer + int64_t bufferNum = (bufferEnd - bufferStart) / 8; + bufferNum = std::min(bufferNum, static_cast<int64_t>(offset + len - curIdx)); + uint64_t b0, b1, b2, b3, b4, b5, b6, b7; + // Avoid updating 'bufferStart' inside the loop. + const auto* buffer = reinterpret_cast<const unsigned char*>(bufferStart); + for (int i = 0; i < bufferNum; ++i) { + b0 = static_cast<uint32_t>(*buffer); + b1 = static_cast<uint32_t>(*(buffer + 1)); + b2 = static_cast<uint32_t>(*(buffer + 2)); + b3 = static_cast<uint32_t>(*(buffer + 3)); + b4 = static_cast<uint32_t>(*(buffer + 4)); + b5 = static_cast<uint32_t>(*(buffer + 5)); + b6 = static_cast<uint32_t>(*(buffer + 6)); + b7 = static_cast<uint32_t>(*(buffer + 7)); + buffer += 8; + data[curIdx++] = static_cast<int64_t>((b0 << 56) | (b1 << 48) | (b2 << 40) | (b3 << 32) | (b4 << 24) | (b5 << 16) | (b6 << 8) | b7); + } + bufferStart = reinterpret_cast<const char*>(buffer); + if (curIdx == offset + len) return; + + // One of the following readByte() will update 'bufferStart' and 'bufferEnd'. + b0 = readByte(); + b1 = readByte(); + b2 = readByte(); + b3 = readByte(); + b4 = readByte(); + b5 = readByte(); + b6 = readByte(); + b7 = readByte(); + data[curIdx++] = static_cast<int64_t>((b0 << 56) | (b1 << 48) | (b2 << 40) | (b3 << 32) | (b4 << 24) | (b5 << 16) | (b6 << 8) | b7); + } +} + +void RleDecoderV2::plainUnpackLongs(int64_t *data, uint64_t offset, uint64_t len, + uint64_t fbs) { + for (uint64_t i = offset; i < (offset + len); i++) { + uint64_t result = 0; + uint64_t bitsLeftToRead = fbs; + while (bitsLeftToRead > bitsLeft) { + result <<= bitsLeft; + result |= curByte & ((1 << bitsLeft) - 1); + bitsLeftToRead -= bitsLeft; + curByte = readByte(); + bitsLeft = 8; + } + + // handle the left over bits + if (bitsLeftToRead > 0) { + result <<= bitsLeftToRead; + bitsLeft -= static_cast<uint32_t>(bitsLeftToRead); + result |= (curByte >> bitsLeft) & ((1 << bitsLeftToRead) - 1); + } + data[i] = static_cast<int64_t>(result); + } +} + RleDecoderV2::RleDecoderV2(std::unique_ptr<SeekableInputStream> input, bool _isSigned, MemoryPool& pool ): inputStream(std::move(input)), @@ -58,23 +399,10 @@ RleDecoderV2::RleDecoderV2(std::unique_ptr<SeekableInputStream> input, runRead(0), bufferStart(nullptr), bufferEnd(bufferStart), - deltaBase(0), - byteSize(0), - firstValue(0), - prevValue(0), - bitSize(0), bitsLeft(0), curByte(0), - patchBitSize(0), - unpackedIdx(0), - patchIdx(0), - base(0), - curGap(0), - curPatch(0), - patchMask(0), - actualGap(0), - unpacked(pool, 0), - unpackedPatch(pool, 0) { + unpackedPatch(pool, 0), + literals(pool, MAX_LITERAL_SIZE) { // PASS } @@ -148,7 +476,7 @@ uint64_t RleDecoderV2::nextShortRepeats(int64_t* const data, const char* const notNull) { if (runRead == runLength) { // extract the number of fixed bytes - byteSize = (firstByte >> 3) & 0x07; + uint64_t byteSize = (firstByte >> 3) & 0x07; byteSize += 1; runLength = firstByte & 0x07; @@ -157,10 +485,10 @@ uint64_t RleDecoderV2::nextShortRepeats(int64_t* const data, runRead = 0; // read the repeated value which is store using fixed bytes - firstValue = readLongBE(byteSize); + literals[0] = readLongBE(byteSize); if (isSigned) { - firstValue = unZigZag(static_cast<uint64_t>(firstValue)); + literals[0] = unZigZag(static_cast<uint64_t>(literals[0])); } } @@ -169,13 +497,13 @@ uint64_t RleDecoderV2::nextShortRepeats(int64_t* const data, if (notNull) { for(uint64_t pos = offset; pos < offset + nRead; ++pos) { if (notNull[pos]) { - data[pos] = firstValue; + data[pos] = literals[0]; ++runRead; } } } else { for(uint64_t pos = offset; pos < offset + nRead; ++pos) { - data[pos] = firstValue; + data[pos] = literals[0]; ++runRead; } } @@ -190,7 +518,7 @@ uint64_t RleDecoderV2::nextDirect(int64_t* const data, if (runRead == runLength) { // extract the number of fixed bits unsigned char fbo = (firstByte >> 1) & 0x1f; - bitSize = decodeBitWidth(fbo); + uint32_t bitSize = decodeBitWidth(fbo); // extract the run length runLength = static_cast<uint64_t>(firstByte & 0x01) << 8; @@ -198,27 +526,40 @@ uint64_t RleDecoderV2::nextDirect(int64_t* const data, // runs are one off runLength += 1; runRead = 0; - } - - uint64_t nRead = std::min(runLength - runRead, numValues); - - runRead += readLongs(data, offset, nRead, bitSize, notNull); - if (isSigned) { - if (notNull) { - for (uint64_t pos = offset; pos < offset + nRead; ++pos) { - if (notNull[pos]) { - data[pos] = unZigZag(static_cast<uint64_t>(data[pos])); - } - } - } else { - for (uint64_t pos = offset; pos < offset + nRead; ++pos) { - data[pos] = unZigZag(static_cast<uint64_t>(data[pos])); + readLongs(literals.data(), 0, runLength, bitSize); + if (isSigned) { + for (uint64_t i = 0; i < runLength; ++i) { + literals[i] = unZigZag(static_cast<uint64_t>(literals[i])); } } } - return nRead; + return copyDataFromBuffer(data, offset, numValues, notNull); +} + +void RleDecoderV2::adjustGapAndPatch(uint32_t patchBitSize, int64_t patchMask, + int64_t* resGap, int64_t* resPatch, + uint64_t* patchIdx) { + uint64_t idx = *patchIdx; + uint64_t gap = static_cast<uint64_t>(unpackedPatch[idx]) >> patchBitSize; + int64_t patch = unpackedPatch[idx] & patchMask; + int64_t actualGap = 0; + + // special case: gap is >255 then patch value will be 0. + // if gap is <=255 then patch value cannot be 0 + while (gap == 255 && patch == 0) { + actualGap += 255; + ++idx; + gap = static_cast<uint64_t>(unpackedPatch[idx]) >> patchBitSize; + patch = unpackedPatch[idx] & patchMask; + } + // add the left over gap + actualGap += gap; + + *resGap = actualGap; + *resPatch = patch; + *patchIdx = idx; } uint64_t RleDecoderV2::nextPatched(int64_t* const data, @@ -228,7 +569,7 @@ uint64_t RleDecoderV2::nextPatched(int64_t* const data, if (runRead == runLength) { // extract the number of fixed bits unsigned char fbo = (firstByte >> 1) & 0x1f; - bitSize = decodeBitWidth(fbo); + uint32_t bitSize = decodeBitWidth(fbo); // extract the run length runLength = static_cast<uint64_t>(firstByte & 0x01) << 8; @@ -239,13 +580,13 @@ uint64_t RleDecoderV2::nextPatched(int64_t* const data, // extract the number of bytes occupied by base uint64_t thirdByte = readByte(); - byteSize = (thirdByte >> 5) & 0x07; + uint64_t byteSize = (thirdByte >> 5) & 0x07; // base width is one off byteSize += 1; // extract patch width uint32_t pwo = thirdByte & 0x1f; - patchBitSize = decodeBitWidth(pwo); + uint32_t patchBitSize = decodeBitWidth(pwo); // read fourth byte and extract patch gap width uint64_t fourthByte = readByte(); @@ -260,7 +601,7 @@ uint64_t RleDecoderV2::nextPatched(int64_t* const data, } // read the next base width number of bytes to extract base value - base = readLongBE(byteSize); + int64_t base = readLongBE(byteSize); int64_t mask = (static_cast<int64_t>(1) << ((byteSize * 8) - 1)); // if mask of base value is 1 then base is negative value else positive if ((base & mask) != 0) { @@ -268,16 +609,12 @@ uint64_t RleDecoderV2::nextPatched(int64_t* const data, base = -base; } - // TODO: something more efficient than resize - unpacked.resize(runLength); - unpackedIdx = 0; - readLongs(unpacked.data(), 0, runLength, bitSize); + readLongs(literals.data(), 0, runLength, bitSize); // any remaining bits are thrown out resetReadLongs(); // TODO: something more efficient than resize unpackedPatch.resize(pl); - patchIdx = 0; // TODO: Skip corrupt? // if ((patchBitSize + pgw) > 64 && !skipCorrupt) { if ((patchBitSize + pgw) > 64) { @@ -290,44 +627,39 @@ uint64_t RleDecoderV2::nextPatched(int64_t* const data, resetReadLongs(); // apply the patch directly when decoding the packed data - patchMask = ((static_cast<int64_t>(1) << patchBitSize) - 1); + int64_t patchMask = ((static_cast<int64_t>(1) << patchBitSize) - 1); - adjustGapAndPatch(); - } + int64_t gap = 0; + int64_t patch = 0; + uint64_t patchIdx = 0; + adjustGapAndPatch(patchBitSize, patchMask, &gap, &patch, &patchIdx); - uint64_t nRead = std::min(runLength - runRead, numValues); + for (uint64_t i = 0; i < runLength; ++i) { + if (static_cast<int64_t>(i) != gap) { + // no patching required. add base to unpacked value to get final value + literals[i] += base; + } else { + // extract the patch value + int64_t patchedVal = literals[i] | (patch << bitSize); - for(uint64_t pos = offset; pos < offset + nRead; ++pos) { - // skip null positions - if (notNull && !notNull[pos]) { - continue; - } - if (static_cast<int64_t>(unpackedIdx) != actualGap) { - // no patching required. add base to unpacked value to get final value - data[pos] = base + unpacked[unpackedIdx]; - } else { - // extract the patch value - int64_t patchedVal = unpacked[unpackedIdx] | (curPatch << bitSize); - - // add base to patched value - data[pos] = base + patchedVal; + // add base to patched value + literals[i] = base + patchedVal; - // increment the patch to point to next entry in patch list - ++patchIdx; + // increment the patch to point to next entry in patch list + ++patchIdx; - if (patchIdx < unpackedPatch.size()) { - adjustGapAndPatch(); + if (patchIdx < unpackedPatch.size()) { + adjustGapAndPatch(patchBitSize, patchMask, &gap, &patch, + &patchIdx); - // next gap is relative to the current gap - actualGap += unpackedIdx; + // next gap is relative to the current gap + gap += i; + } } } - - ++runRead; - ++unpackedIdx; } - return nRead; + return copyDataFromBuffer(data, offset, numValues, notNull); } uint64_t RleDecoderV2::nextDelta(int64_t* const data, @@ -337,6 +669,7 @@ uint64_t RleDecoderV2::nextDelta(int64_t* const data, if (runRead == runLength) { // extract the number of fixed bits unsigned char fbo = (firstByte >> 1) & 0x1f; + uint32_t bitSize; if (fbo != 0) { bitSize = decodeBitWidth(fbo); } else { @@ -347,79 +680,67 @@ uint64_t RleDecoderV2::nextDelta(int64_t* const data, runLength = static_cast<uint64_t>(firstByte & 0x01) << 8; runLength |= readByte(); ++runLength; // account for first value - runRead = deltaBase = 0; + runRead = 0; + int64_t prevValue; // read the first value stored as vint if (isSigned) { - firstValue = static_cast<int64_t>(readVslong()); + prevValue = readVslong(); } else { - firstValue = static_cast<int64_t>(readVulong()); + prevValue = static_cast<int64_t>(readVulong()); } - prevValue = firstValue; + literals[0] = prevValue; // read the fixed delta value stored as vint (deltas can be negative even // if all number are positive) - deltaBase = static_cast<int64_t>(readVslong()); - } - - uint64_t nRead = std::min(runLength - runRead, numValues); - - uint64_t pos = offset; - for ( ; pos < offset + nRead; ++pos) { - // skip null positions - if (!notNull || notNull[pos]) break; - } - if (runRead == 0 && pos < offset + nRead) { - data[pos++] = firstValue; - ++runRead; - } - - if (bitSize == 0) { - // add fixed deltas to adjacent values - for ( ; pos < offset + nRead; ++pos) { - // skip null positions - if (notNull && !notNull[pos]) { - continue; - } - prevValue = data[pos] = prevValue + deltaBase; - ++runRead; - } - } else { - for ( ; pos < offset + nRead; ++pos) { - // skip null positions - if (!notNull || notNull[pos]) break; - } - if (runRead < 2 && pos < offset + nRead) { - // add delta base and first value - prevValue = data[pos++] = firstValue + deltaBase; - ++runRead; - } + int64_t deltaBase = readVslong(); - // write the unpacked values, add it to previous value and store final - // value to result buffer. if the delta base value is negative then it - // is a decreasing sequence else an increasing sequence - uint64_t remaining = (offset + nRead) - pos; - runRead += readLongs(data, pos, remaining, bitSize, notNull); - - if (deltaBase < 0) { - for ( ; pos < offset + nRead; ++pos) { - // skip null positions - if (notNull && !notNull[pos]) { - continue; - } - prevValue = data[pos] = prevValue - data[pos]; + if (bitSize == 0) { + // add fixed deltas to adjacent values + for (uint64_t i = 1; i < runLength; ++i) { + literals[i] = literals[i - 1] + deltaBase; } } else { - for ( ; pos < offset + nRead; ++pos) { - // skip null positions - if (notNull && !notNull[pos]) { - continue; + prevValue = literals[1] = prevValue + deltaBase; + if (runLength < 2) { + std::stringstream ss; + ss << "Illegal run length for delta encoding: " << runLength; + throw ParseError(ss.str()); + } + // write the unpacked values, add it to previous value and store final + // value to result buffer. if the delta base value is negative then it + // is a decreasing sequence else an increasing sequence. + // read deltas using the literals buffer. + readLongs(literals.data(), 2, runLength - 2, bitSize); + if (deltaBase < 0) { + for (uint64_t i = 2; i < runLength; ++i) { + prevValue = literals[i] = prevValue - literals[i]; + } + } else { + for (uint64_t i = 2; i < runLength; ++i) { + prevValue = literals[i] = prevValue + literals[i]; } - prevValue = data[pos] = prevValue + data[pos]; } } } + + return copyDataFromBuffer(data, offset, numValues, notNull); +} + +uint64_t RleDecoderV2::copyDataFromBuffer(int64_t* data, uint64_t offset, + uint64_t numValues, const char* notNull) { + uint64_t nRead = std::min(runLength - runRead, numValues); + if (notNull) { + for (uint64_t i = offset; i < (offset + nRead); ++i) { + if (notNull[i]) { + data[i] = literals[runRead++]; + } + } + } else { + memcpy(data + offset, literals.data() + runRead, nRead * sizeof(int64_t)); + runRead += nRead; + } return nRead; } diff --git a/contrib/libs/apache/orc/c++/src/RleEncoderV2.cc b/contrib/libs/apache/orc/c++/src/RleEncoderV2.cc index 44e2761b74c..4e7a145a5ac 100644 --- a/contrib/libs/apache/orc/c++/src/RleEncoderV2.cc +++ b/contrib/libs/apache/orc/c++/src/RleEncoderV2.cc @@ -21,7 +21,6 @@ #include "RLEv2.hh" #include "RLEV2Util.hh" -#define MAX_LITERAL_SIZE 512 #define MAX_SHORT_REPEAT_LENGTH 10 namespace orc { @@ -67,7 +66,7 @@ RleEncoderV2::RleEncoderV2(std::unique_ptr<BufferedOutputStream> outStream, prevDelta(0){ literals = new int64_t[MAX_LITERAL_SIZE]; gapVsPatchList = new int64_t[MAX_LITERAL_SIZE]; - zigzagLiterals = new int64_t[MAX_LITERAL_SIZE]; + zigzagLiterals = hasSigned ? new int64_t[MAX_LITERAL_SIZE] : nullptr; baseRedLiterals = new int64_t[MAX_LITERAL_SIZE]; adjDeltas = new int64_t[MAX_LITERAL_SIZE]; } @@ -124,7 +123,8 @@ void RleEncoderV2::write(int64_t val) { } if (fixedRunLength == MAX_LITERAL_SIZE) { - determineEncoding(option); + option.encoding = DELTA; + option.isFixedDelta = true; writeValues(option); } return; @@ -168,14 +168,9 @@ void RleEncoderV2::write(int64_t val) { } void RleEncoderV2::computeZigZagLiterals(EncodingOption &option) { - int64_t zzEncVal = 0; + assert (isSigned); for (size_t i = 0; i < numLiterals; i++) { - if (isSigned) { - zzEncVal = zigZag(literals[i]); - } else { - zzEncVal = literals[i]; - } - zigzagLiterals[option.zigzagLiteralsCount++] = zzEncVal; + zigzagLiterals[option.zigzagLiteralsCount++] = zigZag(literals[i]); } } @@ -281,6 +276,20 @@ void RleEncoderV2::preparePatchedBlob(EncodingOption& option) { } } +/** + * Prepare for Direct or PatchedBase encoding + * compute zigZagLiterals and zzBits100p (Max number of encoding bits required) + * @return zigzagLiterals + */ +int64_t* RleEncoderV2::prepareForDirectOrPatchedBase(EncodingOption& option) { + if (isSigned) { + computeZigZagLiterals(option); + } + int64_t* currentZigzagLiterals = isSigned ? zigzagLiterals : literals; + option.zzBits100p = percentileBits(currentZigzagLiterals, 0, numLiterals, 1.0); + return currentZigzagLiterals; +} + void RleEncoderV2::determineEncoding(EncodingOption& option) { // We need to compute zigzag values for DIRECT and PATCHED_BASE encodings, // but not for SHORT_REPEAT or DELTA. So we only perform the zigzag @@ -290,8 +299,7 @@ void RleEncoderV2::determineEncoding(EncodingOption& option) { if (numLiterals <= MIN_REPEAT) { // we need to compute zigzag values for DIRECT encoding if we decide to // break early for delta overflows or for shorter runs - computeZigZagLiterals(option); - option.zzBits100p = percentileBits(zigzagLiterals, 0, numLiterals, 1.0); + prepareForDirectOrPatchedBase(option); option.encoding = DIRECT; return; } @@ -331,8 +339,7 @@ void RleEncoderV2::determineEncoding(EncodingOption& option) { // PATCHED_BASE condition as encoding using DIRECT is faster and has less // overhead than PATCHED_BASE if (!isSafeSubtract(max, option.min)) { - computeZigZagLiterals(option); - option.zzBits100p = percentileBits(zigzagLiterals, 0, numLiterals, 1.0); + prepareForDirectOrPatchedBase(option); option.encoding = DIRECT; return; } @@ -388,9 +395,8 @@ void RleEncoderV2::determineEncoding(EncodingOption& option) { // beyond a threshold then we need to patch the values. if the variation // is not significant then we can use direct encoding - computeZigZagLiterals(option); - option.zzBits100p = percentileBits(zigzagLiterals, 0, numLiterals, 1.0); - option.zzBits90p = percentileBits(zigzagLiterals, 0, numLiterals, 0.9, true); + int64_t* currentZigzagLiterals = prepareForDirectOrPatchedBase(option); + option.zzBits90p = percentileBits(currentZigzagLiterals, 0, numLiterals, 0.9, true); uint32_t diffBitsLH = option.zzBits100p - option.zzBits90p; // if the difference between 90th percentile and 100th percentile fixed @@ -539,7 +545,8 @@ void RleEncoderV2::writeDirectValues(EncodingOption& option) { writeByte(headerSecondByte); // bit packing the zigzag encoded literals - writeInts(zigzagLiterals, 0, numLiterals, fb); + int64_t* currentZigzagLiterals = isSigned ? zigzagLiterals : literals; + writeInts(currentZigzagLiterals, 0, numLiterals, fb); // reset run length variableRunLength = 0; diff --git a/contrib/libs/apache/orc/c++/src/Statistics.cc b/contrib/libs/apache/orc/c++/src/Statistics.cc index 2401f5e0cb4..ccc54c291cc 100644 --- a/contrib/libs/apache/orc/c++/src/Statistics.cc +++ b/contrib/libs/apache/orc/c++/src/Statistics.cc @@ -30,6 +30,8 @@ namespace orc { return new IntegerColumnStatisticsImpl(s); } else if (s.has_doublestatistics()) { return new DoubleColumnStatisticsImpl(s); + } else if (s.has_collectionstatistics()) { + return new CollectionColumnStatisticsImpl(s); } else if (s.has_stringstatistics()) { return new StringColumnStatisticsImpl(s, statContext); } else if (s.has_bucketstatistics()) { @@ -135,6 +137,10 @@ namespace orc { // PASS } + CollectionColumnStatistics::~CollectionColumnStatistics() { + // PASS + } + MutableColumnStatistics::~MutableColumnStatistics() { // PASS } @@ -167,6 +173,10 @@ namespace orc { // PASS } + CollectionColumnStatisticsImpl::~CollectionColumnStatisticsImpl() { + // PASS + } + StringColumnStatisticsImpl::~StringColumnStatisticsImpl() { // PASS } @@ -305,6 +315,8 @@ namespace orc { _stats.setMaximum(0); _lowerBound = 0; _upperBound = 0; + _minimumNanos = DEFAULT_MIN_NANOS; + _maximumNanos = DEFAULT_MAX_NANOS; }else{ const proto::TimestampStatistics& stats = pb.timestampstatistics(); _stats.setHasMinimum( @@ -315,6 +327,12 @@ namespace orc { (stats.has_maximum() && (statContext.writerTimezone != nullptr))); _hasLowerBound = stats.has_minimumutc() || stats.has_minimum(); _hasUpperBound = stats.has_maximumutc() || stats.has_maximum(); + // to be consistent with java side, non-default minimumnanos and maximumnanos + // are added by one in their serialized form. + _minimumNanos = stats.has_minimumnanos() ? + stats.minimumnanos() - 1 : DEFAULT_MIN_NANOS; + _maximumNanos = stats.has_maximumnanos() ? + stats.maximumnanos() - 1 : DEFAULT_MAX_NANOS; // Timestamp stats are stored in milliseconds if (stats.has_minimumutc()) { @@ -361,6 +379,26 @@ namespace orc { } } + CollectionColumnStatisticsImpl::CollectionColumnStatisticsImpl + (const proto::ColumnStatistics& pb) { + _stats.setNumberOfValues(pb.numberofvalues()); + _stats.setHasNull(pb.hasnull()); + if (!pb.has_collectionstatistics()) { + _stats.setMinimum(0); + _stats.setMaximum(0); + _stats.setSum(0); + } else { + const proto::CollectionStatistics& stats = pb.collectionstatistics(); + _stats.setHasMinimum(stats.has_minchildren()); + _stats.setHasMaximum(stats.has_maxchildren()); + _stats.setHasSum(stats.has_totalchildren()); + + _stats.setMinimum(stats.minchildren()); + _stats.setMaximum(stats.maxchildren()); + _stats.setSum(stats.totalchildren()); + } + } + std::unique_ptr<MutableColumnStatistics> createColumnStatistics( const Type& type) { switch (static_cast<int64_t>(type.getKind())) { @@ -373,9 +411,11 @@ namespace orc { case SHORT: return std::unique_ptr<MutableColumnStatistics>( new IntegerColumnStatisticsImpl()); - case STRUCT: case MAP: case LIST: + return std::unique_ptr<MutableColumnStatistics>( + new CollectionColumnStatisticsImpl()); + case STRUCT: case UNION: return std::unique_ptr<MutableColumnStatistics>( new ColumnStatisticsImpl()); @@ -395,6 +435,7 @@ namespace orc { return std::unique_ptr<MutableColumnStatistics>( new DateColumnStatisticsImpl()); case TIMESTAMP: + case TIMESTAMP_INSTANT: return std::unique_ptr<MutableColumnStatistics>( new TimestampColumnStatisticsImpl()); case DECIMAL: diff --git a/contrib/libs/apache/orc/c++/src/Statistics.hh b/contrib/libs/apache/orc/c++/src/Statistics.hh index ee9db23f867..8cb2283f130 100644 --- a/contrib/libs/apache/orc/c++/src/Statistics.hh +++ b/contrib/libs/apache/orc/c++/src/Statistics.hh @@ -173,6 +173,7 @@ namespace orc { typedef InternalStatisticsImpl<double> InternalDoubleStatistics; typedef InternalStatisticsImpl<Decimal> InternalDecimalStatistics; typedef InternalStatisticsImpl<std::string> InternalStringStatistics; + typedef InternalStatisticsImpl<uint64_t> InternalCollectionStatistics; /** * Mutable column statistics for use by the writer. @@ -665,14 +666,14 @@ namespace orc { proto::DecimalStatistics* decStats = pbStats.mutable_decimalstatistics(); if (_stats.hasMinimum()) { - decStats->set_minimum(TString(_stats.getMinimum().toString())); - decStats->set_maximum(TString(_stats.getMaximum().toString())); + decStats->set_minimum(TString(_stats.getMinimum().toString(true))); + decStats->set_maximum(TString(_stats.getMaximum().toString(true))); } else { decStats->clear_minimum(); decStats->clear_maximum(); } if (_stats.hasSum()) { - decStats->set_sum(TString(_stats.getSum().toString())); + decStats->set_sum(TString(_stats.getSum().toString(true))); } else { decStats->clear_sum(); } @@ -1230,6 +1231,10 @@ namespace orc { bool _hasUpperBound; int64_t _lowerBound; int64_t _upperBound; + int32_t _minimumNanos; // last 6 digits of nanosecond of minimum timestamp + int32_t _maximumNanos; // last 6 digits of nanosecond of maximum timestamp + static constexpr int32_t DEFAULT_MIN_NANOS = 0; + static constexpr int32_t DEFAULT_MAX_NANOS = 999999; public: TimestampColumnStatisticsImpl() { reset(); } @@ -1295,14 +1300,68 @@ namespace orc { _stats.updateMinMax(value); } + void update(int64_t milli, int32_t nano) { + if (!_stats.hasMinimum()) { + _stats.setHasMinimum(true); + _stats.setHasMaximum(true); + _stats.setMinimum(milli); + _stats.setMaximum(milli); + _maximumNanos = _minimumNanos = nano; + } else { + if (milli <= _stats.getMinimum()) { + if (milli < _stats.getMinimum() || nano < _minimumNanos) { + _minimumNanos = nano; + } + _stats.setMinimum(milli); + } + + if (milli >= _stats.getMaximum()) { + if (milli > _stats.getMaximum() || nano > _maximumNanos) { + _maximumNanos = nano; + } + _stats.setMaximum(milli); + } + } + } + void merge(const MutableColumnStatistics& other) override { const TimestampColumnStatisticsImpl& tsStats = dynamic_cast<const TimestampColumnStatisticsImpl&>(other); - _stats.merge(tsStats._stats); + + _stats.setHasNull(_stats.hasNull() || tsStats.hasNull()); + _stats.setNumberOfValues(_stats.getNumberOfValues() + tsStats.getNumberOfValues()); + + if (tsStats.hasMinimum()) { + if (!_stats.hasMinimum()) { + _stats.setHasMinimum(true); + _stats.setHasMaximum(true); + _stats.setMinimum(tsStats.getMinimum()); + _stats.setMaximum(tsStats.getMaximum()); + _minimumNanos = tsStats.getMinimumNanos(); + _maximumNanos = tsStats.getMaximumNanos(); + } else { + if (tsStats.getMaximum() >= _stats.getMaximum()) { + if (tsStats.getMaximum() > _stats.getMaximum() || + tsStats.getMaximumNanos() > _maximumNanos) { + _maximumNanos = tsStats.getMaximumNanos(); + } + _stats.setMaximum(tsStats.getMaximum()); + } + if (tsStats.getMinimum() <= _stats.getMinimum()) { + if (tsStats.getMinimum() < _stats.getMinimum() || + tsStats.getMinimumNanos() < _minimumNanos) { + _minimumNanos = tsStats.getMinimumNanos(); + } + _stats.setMinimum(tsStats.getMinimum()); + } + } + } } void reset() override { _stats.reset(); + _minimumNanos = DEFAULT_MIN_NANOS; + _maximumNanos = DEFAULT_MAX_NANOS; } void toProtoBuf(proto::ColumnStatistics& pbStats) const override { @@ -1314,9 +1373,17 @@ namespace orc { if (_stats.hasMinimum()) { tsStats->set_minimumutc(_stats.getMinimum()); tsStats->set_maximumutc(_stats.getMaximum()); + if (_minimumNanos != DEFAULT_MIN_NANOS) { + tsStats->set_minimumnanos(_minimumNanos + 1); + } + if (_maximumNanos != DEFAULT_MAX_NANOS) { + tsStats->set_maximumnanos(_maximumNanos + 1); + } } else { tsStats->clear_minimumutc(); tsStats->clear_maximumutc(); + tsStats->clear_minimumnanos(); + tsStats->clear_maximumnanos(); } } @@ -1395,6 +1462,186 @@ namespace orc { throw ParseError("UpperBound is not defined."); } } + + int32_t getMinimumNanos() const override { + if (hasMinimum()) { + return _minimumNanos; + } else { + throw ParseError("Minimum is not defined."); + } + } + + int32_t getMaximumNanos() const override { + if (hasMaximum()) { + return _maximumNanos; + } else { + throw ParseError("Maximum is not defined."); + } + } + }; + + class CollectionColumnStatisticsImpl : public CollectionColumnStatistics, + public MutableColumnStatistics { + private: + InternalCollectionStatistics _stats; + + public: + CollectionColumnStatisticsImpl() { reset(); } + CollectionColumnStatisticsImpl(const proto::ColumnStatistics &stats); + virtual ~CollectionColumnStatisticsImpl() override; + + bool hasMinimumChildren() const override { + return _stats.hasMinimum(); + } + + bool hasMaximumChildren() const override { + return _stats.hasMaximum(); + } + + bool hasTotalChildren() const override { + return _stats.hasSum(); + } + + void increase(uint64_t count) override { + _stats.setNumberOfValues(_stats.getNumberOfValues() + count); + } + + uint64_t getNumberOfValues() const override { + return _stats.getNumberOfValues(); + } + + void setNumberOfValues(uint64_t value) override { + _stats.setNumberOfValues(value); + } + + bool hasNull() const override { + return _stats.hasNull(); + } + + void setHasNull(bool hasNull) override { + _stats.setHasNull(hasNull); + } + + uint64_t getMinimumChildren() const override { + if(hasMinimumChildren()) { + return _stats.getMinimum(); + } else { + throw ParseError("MinimumChildren is not defined."); + } + } + + uint64_t getMaximumChildren() const override { + if(hasMaximumChildren()) { + return _stats.getMaximum(); + } else { + throw ParseError("MaximumChildren is not defined."); + } + } + + uint64_t getTotalChildren() const override { + if(hasTotalChildren()) { + return _stats.getSum(); + } else { + throw ParseError("TotalChildren is not defined."); + } + } + + void setMinimumChildren(uint64_t minimum) override { + _stats.setHasMinimum(true); + _stats.setMinimum(minimum); + } + + void setMaximumChildren(uint64_t maximum) override { + _stats.setHasMaximum(true); + _stats.setMaximum(maximum); + } + + void setTotalChildren(uint64_t sum) override { + _stats.setHasSum(true); + _stats.setSum(sum); + } + + void setHasTotalChildren(bool hasSum) override { + _stats.setHasSum(hasSum); + } + + void merge(const MutableColumnStatistics& other) override { + const CollectionColumnStatisticsImpl& collectionStats = + dynamic_cast<const CollectionColumnStatisticsImpl&>(other); + + _stats.merge(collectionStats._stats); + + // hasSumValue here means no overflow + _stats.setHasSum(_stats.hasSum() && collectionStats.hasTotalChildren()); + if (_stats.hasSum()) { + uint64_t oldSum = _stats.getSum(); + _stats.setSum(_stats.getSum() + collectionStats.getTotalChildren()); + if (oldSum > _stats.getSum()) { + _stats.setHasSum(false); + } + } + } + + void reset() override { + _stats.reset(); + setTotalChildren(0); + } + + void update(uint64_t value) { + _stats.updateMinMax(value); + if (_stats.hasSum()) { + uint64_t oldSum = _stats.getSum(); + _stats.setSum(_stats.getSum() + value); + if (oldSum > _stats.getSum()) { + _stats.setHasSum(false); + } + } + } + + void toProtoBuf(proto::ColumnStatistics &pbStats) const override { + pbStats.set_hasnull(_stats.hasNull()); + pbStats.set_numberofvalues(_stats.getNumberOfValues()); + + proto::CollectionStatistics* collectionStats = + pbStats.mutable_collectionstatistics(); + if (_stats.hasMinimum()) { + collectionStats->set_minchildren(_stats.getMinimum()); + collectionStats->set_maxchildren(_stats.getMaximum()); + } else { + collectionStats->clear_minchildren(); + collectionStats->clear_maxchildren(); + } + if (_stats.hasSum()) { + collectionStats->set_totalchildren(_stats.getSum()); + } else { + collectionStats->clear_totalchildren(); + } + } + + std::string toString() const override { + std::ostringstream buffer; + buffer << "Data type: Collection(LIST|MAP)" << std::endl + << "Values: " << getNumberOfValues() << std::endl + << "Has null: " << (hasNull() ? "yes" : "no") << std::endl; + if (hasMinimumChildren()) { + buffer << "MinChildren: " << getMinimumChildren() << std::endl; + } else { + buffer << "MinChildren is not defined" << std::endl; + } + + if (hasMaximumChildren()) { + buffer << "MaxChildren: " << getMaximumChildren() << std::endl; + } else { + buffer << "MaxChildren is not defined" << std::endl; + } + + if (hasTotalChildren()) { + buffer << "TotalChildren: " << getTotalChildren() << std::endl; + } else { + buffer << "TotalChildren is not defined" << std::endl; + } + return buffer.str(); + } }; ColumnStatistics* convertColumnStatistics(const proto::ColumnStatistics& s, diff --git a/contrib/libs/apache/orc/c++/src/StripeStream.cc b/contrib/libs/apache/orc/c++/src/StripeStream.cc index b63f19d28e0..6d6dda8328c 100644 --- a/contrib/libs/apache/orc/c++/src/StripeStream.cc +++ b/contrib/libs/apache/orc/c++/src/StripeStream.cc @@ -30,14 +30,16 @@ namespace orc { const proto::StripeFooter& _footer, uint64_t _stripeStart, InputStream& _input, - const Timezone& _writerTimezone + const Timezone& _writerTimezone, + const Timezone& _readerTimezone ): reader(_reader), stripeInfo(_stripeInfo), footer(_footer), stripeIndex(_index), stripeStart(_stripeStart), input(_input), - writerTimezone(_writerTimezone) { + writerTimezone(_writerTimezone), + readerTimezone(_readerTimezone) { // PASS } @@ -71,6 +73,10 @@ namespace orc { return writerTimezone; } + const Timezone& StripeStreamsImpl::getReaderTimezone() const { + return readerTimezone; + } + std::ostream* StripeStreamsImpl::getErrorStream() const { return reader.getFileContents().errorStream; } @@ -121,6 +127,10 @@ namespace orc { return reader.getThrowOnHive11DecimalOverflow(); } + bool StripeStreamsImpl::isDecimalAsLong() const { + return reader.getIsDecimalAsLong(); + } + int32_t StripeStreamsImpl::getForcedScaleOnHive11Decimal() const { return reader.getForcedScaleOnHive11Decimal(); } diff --git a/contrib/libs/apache/orc/c++/src/StripeStream.hh b/contrib/libs/apache/orc/c++/src/StripeStream.hh index 5cbaf60a69d..8d9fb065273 100644 --- a/contrib/libs/apache/orc/c++/src/StripeStream.hh +++ b/contrib/libs/apache/orc/c++/src/StripeStream.hh @@ -43,6 +43,7 @@ namespace orc { const uint64_t stripeStart; InputStream& input; const Timezone& writerTimezone; + const Timezone& readerTimezone; public: StripeStreamsImpl(const RowReaderImpl& reader, uint64_t index, @@ -50,7 +51,8 @@ namespace orc { const proto::StripeFooter& footer, uint64_t stripeStart, InputStream& input, - const Timezone& writerTimezone); + const Timezone& writerTimezone, + const Timezone& readerTimezone); virtual ~StripeStreamsImpl() override; @@ -68,10 +70,14 @@ namespace orc { const Timezone& getWriterTimezone() const override; + const Timezone& getReaderTimezone() const override; + std::ostream* getErrorStream() const override; bool getThrowOnHive11DecimalOverflow() const override; + bool isDecimalAsLong() const override; + int32_t getForcedScaleOnHive11Decimal() const override; }; diff --git a/contrib/libs/apache/orc/c++/src/Timezone.hh b/contrib/libs/apache/orc/c++/src/Timezone.hh index 136b7a18b76..6c8b8612593 100644 --- a/contrib/libs/apache/orc/c++/src/Timezone.hh +++ b/contrib/libs/apache/orc/c++/src/Timezone.hh @@ -42,6 +42,10 @@ namespace orc { bool isDst; std::string name; + bool hasSameTzRule(const TimezoneVariant& other) const { + return gmtOffset == other.gmtOffset && isDst == other.isDst; + } + std::string toString() const; }; diff --git a/contrib/libs/apache/orc/c++/src/TypeImpl.cc b/contrib/libs/apache/orc/c++/src/TypeImpl.cc index c154f2af04d..14517ce164b 100644 --- a/contrib/libs/apache/orc/c++/src/TypeImpl.cc +++ b/contrib/libs/apache/orc/c++/src/TypeImpl.cc @@ -67,19 +67,12 @@ namespace orc { columnId = static_cast<int64_t>(root); uint64_t current = root + 1; for(uint64_t i=0; i < subtypeCount; ++i) { - current = dynamic_cast<TypeImpl*>(subTypes[i])->assignIds(current); + current = dynamic_cast<TypeImpl*>(subTypes[i].get())->assignIds(current); } maximumColumnId = static_cast<int64_t>(current) - 1; return current; } - TypeImpl::~TypeImpl() { - for (std::vector<Type*>::iterator it = subTypes.begin(); - it != subTypes.end(); it++) { - delete (*it) ; - } - } - void TypeImpl::ensureIdAssigned() const { if (columnId == -1) { const TypeImpl* root = this; @@ -109,7 +102,7 @@ namespace orc { } const Type* TypeImpl::getSubtype(uint64_t i) const { - return subTypes[i]; + return subTypes[i].get(); } const std::string& TypeImpl::getFieldName(uint64_t i) const { @@ -128,14 +121,50 @@ namespace orc { return scale; } + Type& TypeImpl::setAttribute(const std::string& key, + const std::string& value) { + attributes[key] = value; + return *this; + } + + bool TypeImpl::hasAttributeKey(const std::string& key) const { + return attributes.find(key) != attributes.end(); + } + + Type& TypeImpl::removeAttribute(const std::string& key) { + auto it = attributes.find(key); + if (it == attributes.end()) { + throw std::range_error("Key not found: " + key); + } + attributes.erase(it); + return *this; + } + + std::vector<std::string> TypeImpl::getAttributeKeys() const { + std::vector<std::string> ret; + ret.reserve(attributes.size()); + for (auto& attribute : attributes) { + ret.push_back(attribute.first); + } + return ret; + } + + std::string TypeImpl::getAttributeValue(const std::string& key) const { + auto it = attributes.find(key); + if (it == attributes.end()) { + throw std::range_error("Key not found: " + key); + } + return it->second; + } + void TypeImpl::setIds(uint64_t _columnId, uint64_t _maxColumnId) { columnId = static_cast<int64_t>(_columnId); maximumColumnId = static_cast<int64_t>(_maxColumnId); } void TypeImpl::addChildType(std::unique_ptr<Type> childType) { - TypeImpl* child = dynamic_cast<TypeImpl*>(childType.release()); - subTypes.push_back(child); + TypeImpl* child = dynamic_cast<TypeImpl*>(childType.get()); + subTypes.push_back(std::move(childType)); if (child != nullptr) { child->parent = this; } @@ -154,6 +183,15 @@ namespace orc { return this; } + bool isUnquotedFieldName(std::string fieldName) { + for (auto &ch : fieldName) { + if (!isalnum(ch) && ch != '_') { + return false; + } + } + return true; + } + std::string TypeImpl::toString() const { switch (static_cast<int64_t>(kind)) { case BOOLEAN: @@ -176,6 +214,8 @@ namespace orc { return "binary"; case TIMESTAMP: return "timestamp"; + case TIMESTAMP_INSTANT: + return "timestamp with local time zone"; case LIST: return "array<" + (subTypes[0] ? subTypes[0]->toString() : "void") + ">"; case MAP: @@ -187,7 +227,19 @@ namespace orc { if (i != 0) { result += ","; } - result += fieldNames[i]; + if (isUnquotedFieldName(fieldNames[i])) { + result += fieldNames[i]; + } else { + std::string name(fieldNames[i]); + size_t pos = 0; + while ((pos = name.find("`", pos)) != std::string::npos) { + name.replace(pos, 1, "``"); + pos += 2; + } + result += "`"; + result += name; + result += "`"; + } result += ":"; result += subTypes[i]->toString(); } @@ -257,6 +309,7 @@ namespace orc { (new StringVectorBatch(capacity, memoryPool)); case TIMESTAMP: + case TIMESTAMP_INSTANT: return std::unique_ptr<ColumnVectorBatch> (new TimestampVectorBatch(capacity, memoryPool)); @@ -359,6 +412,7 @@ namespace orc { std::string printProtobufMessage(const google::protobuf::Message& message); std::unique_ptr<Type> convertType(const proto::Type& type, const proto::Footer& footer) { + std::unique_ptr<Type> ret; switch (static_cast<int64_t>(type.kind())) { case proto::Type_Kind_BOOLEAN: @@ -371,25 +425,29 @@ namespace orc { case proto::Type_Kind_STRING: case proto::Type_Kind_BINARY: case proto::Type_Kind_TIMESTAMP: + case proto::Type_Kind_TIMESTAMP_INSTANT: case proto::Type_Kind_DATE: - return std::unique_ptr<Type> + ret = std::unique_ptr<Type> (new TypeImpl(static_cast<TypeKind>(type.kind()))); + break; case proto::Type_Kind_CHAR: case proto::Type_Kind_VARCHAR: - return std::unique_ptr<Type> + ret = std::unique_ptr<Type> (new TypeImpl(static_cast<TypeKind>(type.kind()), type.maximumlength())); + break; case proto::Type_Kind_DECIMAL: - return std::unique_ptr<Type> + ret = std::unique_ptr<Type> (new TypeImpl(DECIMAL, type.precision(), type.scale())); + break; case proto::Type_Kind_LIST: case proto::Type_Kind_MAP: case proto::Type_Kind_UNION: { TypeImpl* result = new TypeImpl(static_cast<TypeKind>(type.kind())); - std::unique_ptr<Type> return_value = std::unique_ptr<Type>(result); + ret = std::unique_ptr<Type>(result); if (type.kind() == proto::Type_Kind_LIST && type.subtypes_size() != 1) throw ParseError("Illegal LIST type that doesn't contain one subtype"); if (type.kind() == proto::Type_Kind_MAP && type.subtypes_size() != 2) @@ -401,23 +459,30 @@ namespace orc { (type.subtypes(i))), footer)); } - return return_value; + break; } case proto::Type_Kind_STRUCT: { TypeImpl* result = new TypeImpl(STRUCT); - std::unique_ptr<Type> return_value = std::unique_ptr<Type>(result); + ret = std::unique_ptr<Type>(result); + if (type.subtypes_size() > type.fieldnames_size()) + throw ParseError("Illegal STRUCT type that contains less fieldnames than subtypes"); for(int i=0; i < type.subtypes_size(); ++i) { result->addStructField(type.fieldnames(i), convertType(footer.types(static_cast<int> (type.subtypes(i))), footer)); } - return return_value; + break; } default: throw NotImplementedYet("Unknown type kind"); } + for (int i = 0; i < type.attributes_size(); ++i) { + const auto& attribute = type.attributes(i); + ret->setAttribute(attribute.key(), attribute.value()); + } + return ret; } /** @@ -446,6 +511,7 @@ namespace orc { case STRING: case BINARY: case TIMESTAMP: + case TIMESTAMP_INSTANT: case DATE: result = new TypeImpl(fileType->getKind()); break; @@ -503,16 +569,21 @@ namespace orc { throw NotImplementedYet("Unknown type kind"); } result->setIds(fileType->getColumnId(), fileType->getMaximumColumnId()); + for (auto& key : fileType->getAttributeKeys()) { + const auto& value = fileType->getAttributeValue(key); + result->setAttribute(key, value); + } return std::unique_ptr<Type>(result); } ORC_UNIQUE_PTR<Type> Type::buildTypeFromString(const std::string& input) { - std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > res = - TypeImpl::parseType(input, 0, input.size()); - if (res.size() != 1) { + size_t size = input.size(); + std::pair<ORC_UNIQUE_PTR<Type>, size_t> res = + TypeImpl::parseType(input, 0, size); + if (res.second != size) { throw std::logic_error("Invalid type string."); } - return std::move(res[0].second); + return std::move(res.first); } std::unique_ptr<Type> TypeImpl::parseArrayType(const std::string &input, @@ -520,45 +591,107 @@ namespace orc { size_t end) { TypeImpl* arrayType = new TypeImpl(LIST); std::unique_ptr<Type> return_value = std::unique_ptr<Type>(arrayType); - std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > v = - TypeImpl::parseType(input, start, end); - if (v.size() != 1) { - throw std::logic_error("Array type must contain exactly one sub type."); + if (input[start] != '<') { + throw std::logic_error("Missing < after array."); + } + std::pair<ORC_UNIQUE_PTR<Type>, size_t> res = + TypeImpl::parseType(input, start + 1, end); + if (res.second != end) { + throw std::logic_error( + "Array type must contain exactly one sub type."); } - arrayType->addChildType(std::move(v[0].second)); + arrayType->addChildType(std::move(res.first)); return return_value; } std::unique_ptr<Type> TypeImpl::parseMapType(const std::string &input, size_t start, size_t end) { - TypeImpl * mapType = new TypeImpl(MAP); + TypeImpl* mapType = new TypeImpl(MAP); std::unique_ptr<Type> return_value = std::unique_ptr<Type>(mapType); - std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > v = - TypeImpl::parseType(input, start, end); - if (v.size() != 2) { + if (input[start] != '<') { + throw std::logic_error("Missing < after map."); + } + std::pair<ORC_UNIQUE_PTR<Type>, size_t> key = + TypeImpl::parseType(input, start + 1, end); + if (input[key.second] != ',') { + throw std::logic_error("Missing comma after key."); + } + std::pair<ORC_UNIQUE_PTR<Type>, size_t> val = + TypeImpl::parseType(input, key.second + 1, end); + if (val.second != end) { throw std::logic_error( "Map type must contain exactly two sub types."); } - mapType->addChildType(std::move(v[0].second)); - mapType->addChildType(std::move(v[1].second)); + mapType->addChildType(std::move(key.first)); + mapType->addChildType(std::move(val.first)); return return_value; } + std::pair<std::string, size_t> TypeImpl::parseName(const std::string &input, + const size_t start, + const size_t end) { + size_t pos = start; + if (input[pos] == '`') { + bool closed = false; + std::ostringstream oss; + while (pos < end) { + char ch = input[++pos]; + if (ch == '`') { + if (pos < end && input[pos+1] == '`') { + ++pos; + oss.put('`'); + } else { + closed = true; + break; + } + } else { + oss.put(ch); + } + } + if (!closed) { + throw std::logic_error("Invalid field name. Unmatched quote"); + } + if (oss.tellp() == std::streamoff(0)) { + throw std::logic_error("Empty quoted field name."); + } + return std::make_pair(oss.str(), pos + 1); + } else { + while (pos < end && (isalnum(input[pos]) || input[pos] == '_')) { + ++pos; + } + if (pos == start) { + throw std::logic_error("Missing field name."); + } + return std::make_pair(input.substr(start, pos - start), pos); + } + } + std::unique_ptr<Type> TypeImpl::parseStructType(const std::string &input, size_t start, size_t end) { TypeImpl* structType = new TypeImpl(STRUCT); std::unique_ptr<Type> return_value = std::unique_ptr<Type>(structType); - std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type>> > v = - TypeImpl::parseType(input, start, end); - if (v.size() == 0) { - throw std::logic_error( - "Struct type must contain at least one sub type."); + size_t pos = start + 1; + if (input[start] != '<') { + throw std::logic_error("Missing < after struct."); } - for (size_t i = 0; i < v.size(); ++i) { - structType->addStructField(v[i].first, std::move(v[i].second)); + while (pos < end) { + std::pair<std::string, size_t> nameRes = parseName(input, pos, end); + pos = nameRes.second; + if (input[pos] != ':') { + throw std::logic_error("Invalid struct type. No field name set."); + } + std::pair<ORC_UNIQUE_PTR<Type>, size_t> typeRes = + TypeImpl::parseType(input, ++pos, end); + structType->addStructField(nameRes.first, std::move(typeRes.first)); + pos = typeRes.second; + if (pos != end && input[pos] != ',') { + throw std::logic_error("Missing comma after field."); + } + ++pos; } + return return_value; } @@ -567,55 +700,90 @@ namespace orc { size_t end) { TypeImpl* unionType = new TypeImpl(UNION); std::unique_ptr<Type> return_value = std::unique_ptr<Type>(unionType); - std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > v = - TypeImpl::parseType(input, start, end); - if (v.size() == 0) { - throw std::logic_error("Union type must contain at least one sub type."); + size_t pos = start + 1; + if (input[start] != '<') { + throw std::logic_error("Missing < after uniontype."); } - for (size_t i = 0; i < v.size(); ++i) { - unionType->addChildType(std::move(v[i].second)); + while (pos < end) { + std::pair<ORC_UNIQUE_PTR<Type>, size_t> res = + TypeImpl::parseType(input, pos, end); + unionType->addChildType(std::move(res.first)); + pos = res.second; + if (pos != end && input[pos] != ',') { + throw std::logic_error("Missing comma after union sub type."); + } + ++pos; } + return return_value; } std::unique_ptr<Type> TypeImpl::parseDecimalType(const std::string &input, size_t start, size_t end) { - size_t sep = input.find(',', start); + if (input[start] != '(') { + throw std::logic_error("Missing ( after decimal."); + } + size_t pos = start + 1; + size_t sep = input.find(',', pos); if (sep + 1 >= end || sep == std::string::npos) { throw std::logic_error("Decimal type must specify precision and scale."); } uint64_t precision = - static_cast<uint64_t>(atoi(input.substr(start, sep - start).c_str())); + static_cast<uint64_t>(atoi(input.substr(pos, sep - pos).c_str())); uint64_t scale = static_cast<uint64_t>(atoi(input.substr(sep + 1, end - sep - 1).c_str())); return std::unique_ptr<Type>(new TypeImpl(DECIMAL, precision, scale)); } + void validatePrimitiveType(std::string category, + const std::string &input, + const size_t pos) { + if (input[pos] == '<' || input[pos] == '(') { + std::ostringstream oss; + oss << "Invalid " << input[pos] << " after " + << category << " type."; + throw std::logic_error(oss.str()); + } + } + std::unique_ptr<Type> TypeImpl::parseCategory(std::string category, const std::string &input, size_t start, size_t end) { if (category == "boolean") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(BOOLEAN)); } else if (category == "tinyint") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(BYTE)); } else if (category == "smallint") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(SHORT)); } else if (category == "int") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(INT)); } else if (category == "bigint") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(LONG)); } else if (category == "float") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(FLOAT)); } else if (category == "double") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(DOUBLE)); } else if (category == "string") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(STRING)); } else if (category == "binary") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(BINARY)); } else if (category == "timestamp") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(TIMESTAMP)); + } else if (category == "timestamp with local time zone") { + validatePrimitiveType(category, input, start); + return std::unique_ptr<Type>(new TypeImpl(TIMESTAMP_INSTANT)); } else if (category == "array") { return parseArrayType(input, start, end); } else if (category == "map") { @@ -627,81 +795,63 @@ namespace orc { } else if (category == "decimal") { return parseDecimalType(input, start, end); } else if (category == "date") { + validatePrimitiveType(category, input, start); return std::unique_ptr<Type>(new TypeImpl(DATE)); } else if (category == "varchar") { + if (input[start] != '(') { + throw std::logic_error("Missing ( after varchar."); + } uint64_t maxLength = static_cast<uint64_t>( - atoi(input.substr(start, end - start).c_str())); + atoi(input.substr(start + 1, end - start + 1).c_str())); return std::unique_ptr<Type>(new TypeImpl(VARCHAR, maxLength)); } else if (category == "char") { + if (input[start] != '(') { + throw std::logic_error("Missing ( after char."); + } uint64_t maxLength = static_cast<uint64_t>( - atoi(input.substr(start, end - start).c_str())); + atoi(input.substr(start + 1, end - start + 1).c_str())); return std::unique_ptr<Type>(new TypeImpl(CHAR, maxLength)); } else { throw std::logic_error("Unknown type " + category); } } - std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > TypeImpl::parseType( - const std::string &input, - size_t start, - size_t end) { - std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > res; + std::pair<ORC_UNIQUE_PTR<Type>, size_t> TypeImpl::parseType(const std::string &input, size_t start, size_t end) { size_t pos = start; - - while (pos < end) { - size_t endPos = pos; - while (endPos < end && (isalnum(input[endPos]) || input[endPos] == '_')) { - ++endPos; - } - - std::string fieldName; - if (input[endPos] == ':') { - fieldName = input.substr(pos, endPos - pos); - pos = ++endPos; - while (endPos < end && isalpha(input[endPos])) { - ++endPos; + while (pos < end && (isalpha(input[pos]) || input[pos] == ' ')) { + ++pos; + } + size_t endPos = pos; + size_t nextPos = pos + 1; + if (input[pos] == '<') { + int count = 1; + while (nextPos < end) { + if (input[nextPos] == '<') { + ++count; + } else if (input[nextPos] == '>') { + --count; } - } - - size_t nextPos = endPos + 1; - if (input[endPos] == '<') { - int count = 1; - while (nextPos < end) { - if (input[nextPos] == '<') { - ++count; - } else if (input[nextPos] == '>') { - --count; - } - if (count == 0) { - break; - } - ++nextPos; - } - if (nextPos == end) { - throw std::logic_error("Invalid type string. Cannot find closing >"); - } - } else if (input[endPos] == '(') { - while (nextPos < end && input[nextPos] != ')') { - ++nextPos; + if (count == 0) { + break; } - if (nextPos == end) { - throw std::logic_error("Invalid type string. Cannot find closing )"); - } - } else if (input[endPos] != ',' && endPos != end) { - throw std::logic_error("Unrecognized character."); + ++nextPos; } - - std::string category = input.substr(pos, endPos - pos); - res.push_back(std::make_pair(fieldName, parseCategory(category, input, endPos + 1, nextPos))); - - if (nextPos < end && (input[nextPos] == ')' || input[nextPos] == '>')) { - pos = nextPos + 2; - } else { - pos = nextPos; + if (nextPos == end) { + throw std::logic_error("Invalid type string. Cannot find closing >"); + } + endPos = nextPos + 1; + } else if (input[pos] == '(') { + while (nextPos < end && input[nextPos] != ')') { + ++nextPos; + } + if (nextPos == end) { + throw std::logic_error("Invalid type string. Cannot find closing )"); } + endPos = nextPos + 1; } - return res; + std::string category = input.substr(start, pos - start); + return std::make_pair(parseCategory(category, input, pos, nextPos), endPos); } } diff --git a/contrib/libs/apache/orc/c++/src/TypeImpl.hh b/contrib/libs/apache/orc/c++/src/TypeImpl.hh index 054ceab5dce..88c4737d181 100644 --- a/contrib/libs/apache/orc/c++/src/TypeImpl.hh +++ b/contrib/libs/apache/orc/c++/src/TypeImpl.hh @@ -34,12 +34,13 @@ namespace orc { mutable int64_t columnId; mutable int64_t maximumColumnId; TypeKind kind; - std::vector<Type*> subTypes; + std::vector<std::unique_ptr<Type>> subTypes; std::vector<std::string> fieldNames; uint64_t subtypeCount; uint64_t maxLength; uint64_t precision; uint64_t scale; + std::map<std::string, std::string> attributes; public: /** @@ -58,8 +59,6 @@ namespace orc { TypeImpl(TypeKind kind, uint64_t precision, uint64_t scale); - virtual ~TypeImpl() override; - uint64_t getColumnId() const override; uint64_t getMaximumColumnId() const override; @@ -78,6 +77,17 @@ namespace orc { uint64_t getScale() const override; + Type& setAttribute(const std::string& key, + const std::string& value) override; + + bool hasAttributeKey(const std::string& key) const override; + + Type& removeAttribute(const std::string& key) override; + + std::vector<std::string> getAttributeKeys() const override; + + std::string getAttributeValue(const std::string& key) const override; + std::string toString() const override; Type* addStructField(const std::string& fieldName, @@ -99,7 +109,7 @@ namespace orc { */ void addChildType(std::unique_ptr<Type> childType); - static std::vector<std::pair<std::string, std::unique_ptr<Type> > > parseType( + static std::pair<ORC_UNIQUE_PTR<Type>, size_t> parseType( const std::string &input, size_t start, size_t end); @@ -138,6 +148,16 @@ namespace orc { size_t end); /** + * Parse field name from string + * @param input the input string of a field name + * @param start start position of the input string + * @param end end position of the input string + */ + static std::pair<std::string, size_t> parseName(const std::string &input, + const size_t start, + const size_t end); + + /** * Parse struct type from string * @param input the input string of a struct type * @param start start position of the input string diff --git a/contrib/libs/apache/orc/c++/src/Vector.cc b/contrib/libs/apache/orc/c++/src/Vector.cc index 14c0ded0300..fefaaad4b19 100644 --- a/contrib/libs/apache/orc/c++/src/Vector.cc +++ b/contrib/libs/apache/orc/c++/src/Vector.cc @@ -149,6 +149,13 @@ namespace orc { return buffer.str(); } + void EncodedStringVectorBatch::resize(uint64_t cap) { + if (capacity < cap) { + StringVectorBatch::resize(cap); + index.resize(cap); + } + } + StringVectorBatch::StringVectorBatch(uint64_t _capacity, MemoryPool& pool ): ColumnVectorBatch(_capacity, pool), data(pool, _capacity), @@ -287,8 +294,8 @@ namespace orc { std::string MapVectorBatch::toString() const { std::ostringstream buffer; - buffer << "Map vector <" << keys->toString() << ", " - << elements->toString() << " with " + buffer << "Map vector <" << (keys ? keys->toString(): "key not selected") << ", " + << (elements ? elements->toString(): "value not selected") << " with " << numElements << " of " << capacity << ">"; return buffer.str(); } @@ -309,8 +316,8 @@ namespace orc { uint64_t MapVectorBatch::getMemoryUsage() { return ColumnVectorBatch::getMemoryUsage() + static_cast<uint64_t>(offsets.capacity() * sizeof(int64_t)) - + keys->getMemoryUsage() - + elements->getMemoryUsage(); + + (keys ? keys->getMemoryUsage() : 0) + + (elements ? elements->getMemoryUsage() : 0); } bool MapVectorBatch::hasVariableLength() { @@ -475,8 +482,8 @@ namespace orc { // PASS } - std::string Decimal::toString() const { - return value.toDecimalString(scale); + std::string Decimal::toString(bool trimTrailingZeros) const { + return value.toDecimalString(scale, trimTrailingZeros); } TimestampVectorBatch::TimestampVectorBatch(uint64_t _capacity, diff --git a/contrib/libs/apache/orc/c++/src/Writer.cc b/contrib/libs/apache/orc/c++/src/Writer.cc index b5bd19b3046..8a7d10ba812 100644 --- a/contrib/libs/apache/orc/c++/src/Writer.cc +++ b/contrib/libs/apache/orc/c++/src/Writer.cc @@ -41,6 +41,7 @@ namespace orc { std::set<uint64_t> columnsUseBloomFilter; double bloomFilterFalsePositiveProb; BloomFilterVersion bloomFilterVersion; + std::string timezone; WriterOptionsPrivate() : fileVersion(FileVersion::v_0_12()) { // default to Hive_0_12 @@ -56,6 +57,10 @@ namespace orc { enableIndex = true; bloomFilterFalsePositiveProb = 0.05; bloomFilterVersion = UTF8; + //Writer timezone uses "GMT" by default to get rid of potential issues + //introduced by moving timestamps between different timezones. + //Explictly set the writer timezone if the use case depends on it. + timezone = "GMT"; } }; @@ -73,9 +78,7 @@ namespace orc { WriterOptions::WriterOptions(WriterOptions& rhs) { // swap privateBits with rhs - WriterOptionsPrivate* l = privateBits.release(); - privateBits.reset(rhs.privateBits.release()); - rhs.privateBits.reset(l); + privateBits.swap(rhs.privateBits); } WriterOptions& WriterOptions::operator=(const WriterOptions& rhs) { @@ -140,6 +143,14 @@ namespace orc { privateBits->fileVersion = version; return *this; } + if (version == FileVersion::UNSTABLE_PRE_2_0()) { + *privateBits->errorStream << "Warning: ORC files written in " + << FileVersion::UNSTABLE_PRE_2_0().toString() + << " will not be readable by other versions of the software." + << " It is only for developer testing.\n"; + privateBits->fileVersion = version; + return *this; + } throw std::logic_error("Unsupported file version specified."); } @@ -231,6 +242,19 @@ namespace orc { return privateBits->bloomFilterVersion; } + const Timezone& WriterOptions::getTimezone() const { + return getTimezoneByName(privateBits->timezone); + } + + const std::string& WriterOptions::getTimezoneName() const { + return privateBits->timezone; + } + + WriterOptions& WriterOptions::setTimezoneName(const std::string& zone) { + privateBits->timezone = zone; + return *this; + } + Writer::~Writer() { // PASS } @@ -442,9 +466,7 @@ namespace orc { *stripeFooter.add_columns() = encodings[i]; } - // use GMT to guarantee TimestampVectorBatch from reader can write - // same wall clock time - stripeFooter.set_writertimezone("GMT"); + stripeFooter.set_writertimezone(TString(options.getTimezoneName())); // add stripe statistics to metadata proto::StripeStatistics* stripeStats = metadata.add_stripestats(); @@ -572,6 +594,10 @@ namespace orc { protoType.set_kind(proto::Type_Kind_TIMESTAMP); break; } + case TIMESTAMP_INSTANT: { + protoType.set_kind(proto::Type_Kind_TIMESTAMP_INSTANT); + break; + } case LIST: { protoType.set_kind(proto::Type_Kind_LIST); break; @@ -608,6 +634,13 @@ namespace orc { throw std::logic_error("Unknown type."); } + for (auto& key : t.getAttributeKeys()) { + const auto& value = t.getAttributeValue(key); + auto protoAttr = protoType.add_attributes(); + protoAttr->set_key(TString(key)); + protoAttr->set_value(TString(value)); + } + int pos = static_cast<int>(index); *footer.add_types() = protoType; diff --git a/contrib/libs/apache/orc/c++/src/io/InputStream.cc b/contrib/libs/apache/orc/c++/src/io/InputStream.cc index 6e54b1412fd..ec798d4ed76 100644 --- a/contrib/libs/apache/orc/c++/src/io/InputStream.cc +++ b/contrib/libs/apache/orc/c++/src/io/InputStream.cc @@ -52,6 +52,10 @@ namespace orc { return result; } + uint64_t PositionProvider::current() { + return *position; + } + SeekableInputStream::~SeekableInputStream() { // PASS } diff --git a/contrib/libs/apache/orc/c++/src/io/InputStream.hh b/contrib/libs/apache/orc/c++/src/io/InputStream.hh index d8bd3d4d8ce..ab7ecedb445 100644 --- a/contrib/libs/apache/orc/c++/src/io/InputStream.hh +++ b/contrib/libs/apache/orc/c++/src/io/InputStream.hh @@ -41,6 +41,7 @@ namespace orc { public: PositionProvider(const std::list<uint64_t>& positions); uint64_t next(); + uint64_t current(); }; /** diff --git a/contrib/libs/apache/orc/c++/src/io/OutputStream.cc b/contrib/libs/apache/orc/c++/src/io/OutputStream.cc index 11a21c0bd35..14d5e5e7c4d 100644 --- a/contrib/libs/apache/orc/c++/src/io/OutputStream.cc +++ b/contrib/libs/apache/orc/c++/src/io/OutputStream.cc @@ -97,6 +97,10 @@ namespace orc { return dataSize; } + void BufferedOutputStream::suppress() { + dataBuffer->resize(0); + } + void AppendOnlyBufferedStream::write(const char * data, size_t size) { size_t dataOffset = 0; while (size > 0) { diff --git a/contrib/libs/apache/orc/c++/src/io/OutputStream.hh b/contrib/libs/apache/orc/c++/src/io/OutputStream.hh index 7ce9fafa240..0fb92465e95 100644 --- a/contrib/libs/apache/orc/c++/src/io/OutputStream.hh +++ b/contrib/libs/apache/orc/c++/src/io/OutputStream.hh @@ -62,6 +62,7 @@ namespace orc { virtual std::string getName() const; virtual uint64_t getSize() const; virtual uint64_t flush(); + virtual void suppress(); virtual bool isCompressed() const { return false; } }; diff --git a/contrib/libs/apache/orc/c++/src/sargs/ExpressionTree.cc b/contrib/libs/apache/orc/c++/src/sargs/ExpressionTree.cc new file mode 100644 index 00000000000..e7d87083d8c --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/ExpressionTree.cc @@ -0,0 +1,192 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ExpressionTree.hh" + +#include <cassert> +#include <sstream> + +namespace orc { + + ExpressionTree::ExpressionTree(Operator op) + : mOperator(op) + , mLeaf(UNUSED_LEAF) + , mConstant(TruthValue::YES_NO_NULL) { + } + + + ExpressionTree::ExpressionTree(Operator op, + std::initializer_list<TreeNode> children) + : mOperator(op) + , mChildren(children.begin(), children.end()) + , mLeaf(UNUSED_LEAF) + , mConstant(TruthValue::YES_NO_NULL) { + // PASS + } + + ExpressionTree::ExpressionTree(size_t leaf) + : mOperator(Operator::LEAF) + , mChildren() + , mLeaf(leaf) + , mConstant(TruthValue::YES_NO_NULL) { + // PASS + } + + ExpressionTree::ExpressionTree(TruthValue constant) + : mOperator(Operator::CONSTANT) + , mChildren() + , mLeaf(UNUSED_LEAF) + , mConstant(constant) { + // PASS + } + + ExpressionTree::ExpressionTree(const ExpressionTree& other) + : mOperator(other.mOperator) + , mLeaf(other.mLeaf) + , mConstant(other.mConstant) { + for (TreeNode child : other.mChildren) { + mChildren.emplace_back(std::make_shared<ExpressionTree>(*child)); + } + } + + ExpressionTree::Operator ExpressionTree::getOperator() const { + return mOperator; + } + + const std::vector<TreeNode>& ExpressionTree::getChildren() const { + return mChildren; + } + + std::vector<TreeNode>& ExpressionTree::getChildren() { + return const_cast<std::vector<TreeNode>&>( + const_cast<const ExpressionTree *>(this)->getChildren()); + } + + const TreeNode ExpressionTree::getChild(size_t i) const { + return mChildren.at(i); + } + + TreeNode ExpressionTree::getChild(size_t i) { + return std::const_pointer_cast<ExpressionTree>( + const_cast<const ExpressionTree *>(this)->getChild(i)); + } + + TruthValue ExpressionTree::getConstant() const { + assert(mOperator == Operator::CONSTANT); + return mConstant; + } + + size_t ExpressionTree::getLeaf() const { + assert(mOperator == Operator::LEAF); + return mLeaf; + } + + void ExpressionTree::setLeaf(size_t leaf) { + assert(mOperator == Operator::LEAF); + mLeaf = leaf; + } + + void ExpressionTree::addChild(TreeNode child) { + mChildren.push_back(child); + } + + TruthValue ExpressionTree::evaluate( + const std::vector<TruthValue>& leaves) const { + TruthValue result; + switch (mOperator) { + case Operator::OR: + { + result = mChildren.at(0)->evaluate(leaves); + for (size_t i = 1; i < mChildren.size() && !isNeeded(result); ++i) { + result = mChildren.at(i)->evaluate(leaves) || result; + } + return result; + } + case Operator::AND: + { + result = mChildren.at(0)->evaluate(leaves); + for (size_t i = 1; i < mChildren.size() && isNeeded(result); ++i) { + result = mChildren.at(i)->evaluate(leaves) && result; + } + return result; + } + case Operator::NOT: + return !mChildren.at(0)->evaluate(leaves); + case Operator::LEAF: + return leaves[mLeaf]; + case Operator::CONSTANT: + return mConstant; + default: + throw std::invalid_argument("Unknown operator!"); + } + } + + std::string to_string(TruthValue truthValue) { + switch (truthValue) { + case TruthValue::YES: + return "YES"; + case TruthValue::NO: + return "NO"; + case TruthValue::IS_NULL: + return "IS_NULL"; + case TruthValue::YES_NULL: + return "YES_NULL"; + case TruthValue::NO_NULL: + return "NO_NULL"; + case TruthValue::YES_NO: + return "YES_NO"; + case TruthValue::YES_NO_NULL: + return "YES_NO_NULL"; + default: + throw std::invalid_argument("unknown TruthValue!"); + } + } + + std::string ExpressionTree::toString() const { + std::ostringstream sstream; + switch (mOperator) { + case Operator::OR: + sstream << "(or"; + for (const auto& child : mChildren) { + sstream << ' ' << child->toString(); + } + sstream << ')'; + break; + case Operator::AND: + sstream << "(and"; + for (const auto& child : mChildren) { + sstream << ' ' << child->toString(); + } + sstream << ')'; + break; + case Operator::NOT: + sstream << "(not " << mChildren.at(0)->toString() << ')'; + break; + case Operator::LEAF: + sstream << "leaf-" << mLeaf; + break; + case Operator::CONSTANT: + sstream << to_string(mConstant); + break; + default: + throw std::invalid_argument("unknown operator!"); + } + return sstream.str(); + } + +} // namespace orc diff --git a/contrib/libs/apache/orc/c++/src/sargs/ExpressionTree.hh b/contrib/libs/apache/orc/c++/src/sargs/ExpressionTree.hh new file mode 100644 index 00000000000..bb3d16e9246 --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/ExpressionTree.hh @@ -0,0 +1,85 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_EXPRESSIONTREE_HH +#define ORC_EXPRESSIONTREE_HH + +#include "orc/sargs/TruthValue.hh" + +#include <limits> +#include <memory> +#include <string> +#include <vector> + +static const size_t UNUSED_LEAF = std::numeric_limits<size_t>::max(); + +namespace orc { + + class ExpressionTree; + typedef std::shared_ptr<ExpressionTree> TreeNode; + typedef std::initializer_list<TreeNode> NodeList; + + /** + * The inner representation of the SearchArgument. Most users should not + * need this interface, it is only for file formats that need to translate + * the SearchArgument into an internal form. + */ + class ExpressionTree { + public: + enum class Operator { OR, AND, NOT, LEAF, CONSTANT }; + + ExpressionTree(Operator op); + ExpressionTree(Operator op, std::initializer_list<TreeNode> children); + ExpressionTree(size_t leaf); + ExpressionTree(TruthValue constant); + + ExpressionTree(const ExpressionTree& other); + ExpressionTree& operator=(const ExpressionTree&) = delete; + + Operator getOperator() const; + + const std::vector<TreeNode>& getChildren() const; + + std::vector<TreeNode>& getChildren(); + + const TreeNode getChild(size_t i) const; + + TreeNode getChild(size_t i); + + TruthValue getConstant() const; + + size_t getLeaf() const; + + void setLeaf(size_t leaf); + + void addChild(TreeNode child); + + std::string toString() const; + + TruthValue evaluate(const std::vector<TruthValue>& leaves) const; + + private: + Operator mOperator; + std::vector<TreeNode> mChildren; + size_t mLeaf; + TruthValue mConstant; + }; + +} // namespace orc + +#endif //ORC_EXPRESSIONTREE_HH diff --git a/contrib/libs/apache/orc/c++/src/sargs/Literal.cc b/contrib/libs/apache/orc/c++/src/sargs/Literal.cc new file mode 100644 index 00000000000..da4cdd0d470 --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/Literal.cc @@ -0,0 +1,312 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "orc/sargs/Literal.hh" + +#include <cmath> +#include <functional> +#include <limits> +#include <sstream> + +namespace orc { + + Literal::Literal(PredicateDataType type) { + mType = type; + mValue.DecimalVal = 0; + mSize = 0; + mIsNull = true; + mPrecision = 0; + mScale = 0; + mHashCode = 0; + } + + Literal::Literal(int64_t val) { + mType = PredicateDataType::LONG; + mValue.IntVal = val; + mSize = sizeof(val); + mIsNull = false; + mPrecision = 0; + mScale = 0; + mHashCode = hashCode(); + } + + Literal::Literal(double val) { + mType = PredicateDataType::FLOAT; + mValue.DoubleVal = val; + mSize = sizeof(val); + mIsNull = false; + mPrecision = 0; + mScale = 0; + mHashCode = hashCode(); + } + + Literal::Literal(bool val) { + mType = PredicateDataType::BOOLEAN; + mValue.BooleanVal = val; + mSize = sizeof(val); + mIsNull = false; + mPrecision = 0; + mScale = 0; + mHashCode = hashCode(); + } + + Literal::Literal(PredicateDataType type, int64_t val) { + if (type != PredicateDataType::DATE) { + throw std::invalid_argument("only DATE is supported here!"); + } + mType = type; + mValue.IntVal = val; + mSize = sizeof(val); + mIsNull = false; + mPrecision = 0; + mScale = 0; + mHashCode = hashCode(); + } + + Literal::Literal(const char * str, size_t size) { + mType = PredicateDataType::STRING; + mValue.Buffer = new char[size]; + memcpy(mValue.Buffer, str, size); + mSize = size; + mIsNull = false; + mPrecision = 0; + mScale = 0; + mHashCode = hashCode(); + } + + Literal::Literal(Int128 val, int32_t precision, int32_t scale) { + mType = PredicateDataType::DECIMAL; + mValue.DecimalVal = val; + mPrecision = precision; + mScale = scale; + mSize = sizeof(Int128); + mIsNull = false; + mHashCode = hashCode(); + } + + Literal::Literal(int64_t second, int32_t nanos) { + mType = PredicateDataType::TIMESTAMP; + mValue.TimeStampVal.second = second; + mValue.TimeStampVal.nanos = nanos; + mPrecision = 0; + mScale = 0; + mSize = sizeof(Timestamp); + mIsNull = false; + mHashCode = hashCode(); + } + + Literal::Literal(const Literal& r): mType(r.mType) + , mSize(r.mSize) + , mIsNull(r.mIsNull) + , mHashCode(r.mHashCode) { + if (mType == PredicateDataType::STRING) { + mValue.Buffer = new char[r.mSize]; + memcpy(mValue.Buffer, r.mValue.Buffer, r.mSize); + mPrecision = 0; + mScale = 0; + } else if (mType == PredicateDataType::DECIMAL) { + mPrecision = r.mPrecision; + mScale = r.mScale; + mValue = r.mValue; + } else if (mType == PredicateDataType::TIMESTAMP) { + mValue.TimeStampVal = r.mValue.TimeStampVal; + } else { + mValue = r.mValue; + mPrecision = 0; + mScale = 0; + } + } + + Literal::~Literal() { + if (mType == PredicateDataType::STRING && mValue.Buffer) { + delete [] mValue.Buffer; + mValue.Buffer = nullptr; + } + } + + Literal& Literal::operator=(const Literal& r) { + if (this != &r) { + if (mType == PredicateDataType::STRING && mValue.Buffer) { + delete [] mValue.Buffer; + mValue.Buffer = nullptr; + } + + mType = r.mType; + mSize = r.mSize; + mIsNull = r.mIsNull; + mPrecision = r.mPrecision; + mScale = r.mScale; + if (mType == PredicateDataType::STRING) { + mValue.Buffer = new char[r.mSize]; + memcpy(mValue.Buffer, r.mValue.Buffer, r.mSize); + } else if (mType == PredicateDataType::TIMESTAMP) { + mValue.TimeStampVal = r.mValue.TimeStampVal; + } else { + mValue = r.mValue; + } + mHashCode = r.mHashCode; + } + return *this; + } + + std::string Literal::toString() const { + if (mIsNull) { + return "null"; + } + + std::ostringstream sstream; + switch (mType) { + case PredicateDataType::LONG: + sstream << mValue.IntVal; + break; + case PredicateDataType::DATE: + sstream << mValue.DateVal; + break; + case PredicateDataType::TIMESTAMP: + sstream << mValue.TimeStampVal.second << "." + << mValue.TimeStampVal.nanos; + break; + case PredicateDataType::FLOAT: + sstream << mValue.DoubleVal; + break; + case PredicateDataType::BOOLEAN: + sstream << (mValue.BooleanVal ? "true" : "false"); + break; + case PredicateDataType::STRING: + sstream << std::string(mValue.Buffer, mSize); + break; + case PredicateDataType::DECIMAL: + sstream << mValue.DecimalVal.toDecimalString(mScale); + break; + } + return sstream.str(); + } + + size_t Literal::hashCode() const { + if (mIsNull) { + return 0; + } + + switch (mType) { + case PredicateDataType::LONG: + return std::hash<int64_t>{}(mValue.IntVal); + case PredicateDataType::DATE: + return std::hash<int64_t>{}(mValue.DateVal); + case PredicateDataType::TIMESTAMP: + return std::hash<int64_t>{}(mValue.TimeStampVal.second) * 17 + + std::hash<int32_t>{}(mValue.TimeStampVal.nanos); + case PredicateDataType::FLOAT: + return std::hash<double>{}(mValue.DoubleVal); + case PredicateDataType::BOOLEAN: + return std::hash<bool>{}(mValue.BooleanVal); + case PredicateDataType::STRING: + return std::hash<std::string>{}( + std::string(mValue.Buffer, mSize)); + case PredicateDataType::DECIMAL: + // current glibc does not support hash<int128_t> + return std::hash<int64_t>{}(mValue.IntVal); + default: + return 0; + } + } + + bool Literal::operator==(const Literal& r) const { + if (this == &r) { + return true; + } + if (mHashCode != r.mHashCode || mType != r.mType || mIsNull != r.mIsNull) { + return false; + } + + if (mIsNull) { + return true; + } + + switch (mType) { + case PredicateDataType::LONG: + return mValue.IntVal == r.mValue.IntVal; + case PredicateDataType::DATE: + return mValue.DateVal == r.mValue.DateVal; + case PredicateDataType::TIMESTAMP: + return mValue.TimeStampVal == r.mValue.TimeStampVal; + case PredicateDataType::FLOAT: + return std::fabs(mValue.DoubleVal - r.mValue.DoubleVal) < + std::numeric_limits<double>::epsilon(); + case PredicateDataType::BOOLEAN: + return mValue.BooleanVal == r.mValue.BooleanVal; + case PredicateDataType::STRING: + return mSize == r.mSize && memcmp( + mValue.Buffer, r.mValue.Buffer, mSize) == 0; + case PredicateDataType::DECIMAL: + return mValue.DecimalVal == r.mValue.DecimalVal; + default: + return true; + } + } + + bool Literal::operator!=(const Literal& r) const { + return !(*this == r); + } + + inline void validate(const bool& isNull, + const PredicateDataType& type, + const PredicateDataType& expected) { + if (isNull) { + throw std::logic_error("cannot get value when it is null!"); + } + if (type != expected) { + throw std::logic_error("predicate type mismatch"); + } + } + + int64_t Literal::getLong() const { + validate(mIsNull, mType, PredicateDataType::LONG); + return mValue.IntVal; + } + + int64_t Literal::getDate() const { + validate(mIsNull, mType, PredicateDataType::DATE); + return mValue.DateVal; + } + + Literal::Timestamp Literal::getTimestamp() const { + validate(mIsNull, mType, PredicateDataType::TIMESTAMP); + return mValue.TimeStampVal; + } + + double Literal::getFloat() const { + validate(mIsNull, mType, PredicateDataType::FLOAT); + return mValue.DoubleVal; + } + + std::string Literal::getString() const { + validate(mIsNull, mType, PredicateDataType::STRING); + return std::string(mValue.Buffer, mSize); + } + + bool Literal::getBool() const { + validate(mIsNull, mType, PredicateDataType::BOOLEAN); + return mValue.BooleanVal; + } + + Decimal Literal::getDecimal() const { + validate(mIsNull, mType, PredicateDataType::DECIMAL); + return Decimal(mValue.DecimalVal, mScale); + } + +} diff --git a/contrib/libs/apache/orc/c++/src/sargs/PredicateLeaf.cc b/contrib/libs/apache/orc/c++/src/sargs/PredicateLeaf.cc new file mode 100644 index 00000000000..3b012cece4b --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/PredicateLeaf.cc @@ -0,0 +1,804 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "orc/BloomFilter.hh" +#include "orc/Common.hh" +#include "orc/Type.hh" +#include "PredicateLeaf.hh" + +#include <algorithm> +#include <functional> +#include <sstream> +#include <type_traits> + +namespace orc { + + PredicateLeaf::PredicateLeaf(Operator op, + PredicateDataType type, + const std::string& colName, + Literal literal) + : mOperator(op) + , mType(type) + , mColumnName(colName) + , mHasColumnName(true) + , mColumnId(0) { + mLiterals.emplace_back(literal); + mHashCode = hashCode(); + validate(); + } + + PredicateLeaf::PredicateLeaf(Operator op, + PredicateDataType type, + uint64_t columnId, + Literal literal) + : mOperator(op) + , mType(type) + , mHasColumnName(false) + , mColumnId(columnId) { + mLiterals.emplace_back(literal); + mHashCode = hashCode(); + validate(); + } + + PredicateLeaf::PredicateLeaf(Operator op, + PredicateDataType type, + const std::string& colName, + const std::initializer_list<Literal>& literals) + : mOperator(op) + , mType(type) + , mColumnName(colName) + , mHasColumnName(true) + , mLiterals(literals.begin(), literals.end()) { + mHashCode = hashCode(); + validate(); + } + + PredicateLeaf::PredicateLeaf(Operator op, + PredicateDataType type, + uint64_t columnId, + const std::initializer_list<Literal>& literals) + : mOperator(op) + , mType(type) + , mHasColumnName(false) + , mColumnId(columnId) + , mLiterals(literals.begin(), literals.end()) { + mHashCode = hashCode(); + validate(); + } + + PredicateLeaf::PredicateLeaf(Operator op, + PredicateDataType type, + const std::string& colName, + const std::vector<Literal>& literals) + : mOperator(op) + , mType(type) + , mColumnName(colName) + , mHasColumnName(true) + , mLiterals(literals.begin(), literals.end()) { + mHashCode = hashCode(); + validate(); + } + + PredicateLeaf::PredicateLeaf(Operator op, + PredicateDataType type, + uint64_t columnId, + const std::vector<Literal>& literals) + : mOperator(op) + , mType(type) + , mHasColumnName(false) + , mColumnId(columnId) + , mLiterals(literals.begin(), literals.end()) { + mHashCode = hashCode(); + validate(); + } + + void PredicateLeaf::validateColumn() const { + if (mHasColumnName && mColumnName.empty()) { + throw std::invalid_argument("column name should not be empty"); + } else if (!mHasColumnName && mColumnId == INVALID_COLUMN_ID) { + throw std::invalid_argument("invalid column id"); + } + } + + void PredicateLeaf::validate() const { + switch (mOperator) { + case Operator::IS_NULL: + validateColumn(); + if (!mLiterals.empty()) { + throw std::invalid_argument("No literal is required!"); + } + break; + case Operator::EQUALS: + case Operator::NULL_SAFE_EQUALS: + case Operator::LESS_THAN: + case Operator::LESS_THAN_EQUALS: + validateColumn(); + if (mLiterals.size() != 1) { + throw std::invalid_argument("One literal is required!"); + } + if (static_cast<int>(mLiterals.at(0).getType()) != + static_cast<int>(mType)) { + throw std::invalid_argument("leaf and literal types do not match!"); + } + break; + case Operator::IN: + validateColumn(); + if (mLiterals.size() < 2) { + throw std::invalid_argument("At least two literals are required!"); + } + for (auto literal : mLiterals) { + if (static_cast<int>(literal.getType()) != static_cast<int>(mType)) { + throw std::invalid_argument("leaf and literal types do not match!"); + } + } + break; + case Operator::BETWEEN: + validateColumn(); + for (auto literal : mLiterals) { + if (static_cast<int>(literal.getType()) != static_cast<int>(mType)) { + throw std::invalid_argument("leaf and literal types do not match!"); + } + } + break; + default: + break; + } + } + + PredicateLeaf::Operator PredicateLeaf::getOperator() const { + return mOperator; + } + + PredicateDataType PredicateLeaf::getType() const { + return mType; + } + + bool PredicateLeaf::hasColumnName() const { + return mHasColumnName; + } + + /** + * Get the simple column name. + */ + const std::string& PredicateLeaf::getColumnName() const { + return mColumnName; + } + + uint64_t PredicateLeaf::getColumnId() const { + return mColumnId; + } + + /** + * Get the literal half of the predicate leaf. + */ + Literal PredicateLeaf::getLiteral() const { + return mLiterals.at(0); + } + + /** + * For operators with multiple literals (IN and BETWEEN), get the literals. + */ + const std::vector<Literal>& PredicateLeaf::getLiteralList() const { + return mLiterals; + } + + static std::string getLiteralString(const std::vector<Literal>& literals) { + return literals.at(0).toString(); + } + + static std::string getLiteralsString(const std::vector<Literal>& literals) { + std::ostringstream sstream; + sstream << "["; + for (size_t i = 0; i != literals.size(); ++i) { + sstream << literals[i].toString(); + if (i + 1 != literals.size()) { + sstream << ", "; + } + } + sstream << "]"; + return sstream.str(); + } + + std::string PredicateLeaf::columnDebugString() const { + if (mHasColumnName) return mColumnName; + std::ostringstream sstream; + sstream << "column(id=" << mColumnId << ')'; + return sstream.str(); + } + + std::string PredicateLeaf::toString() const { + std::ostringstream sstream; + sstream << '('; + switch (mOperator) { + case Operator::IS_NULL: + sstream << columnDebugString() << " is null"; + break; + case Operator::EQUALS: + sstream << columnDebugString() << " = " << getLiteralString(mLiterals); + break; + case Operator::NULL_SAFE_EQUALS: + sstream << columnDebugString() << " null_safe_= " + << getLiteralString(mLiterals); + break; + case Operator::LESS_THAN: + sstream << columnDebugString() << " < " << getLiteralString(mLiterals); + break; + case Operator::LESS_THAN_EQUALS: + sstream << columnDebugString() << " <= " << getLiteralString(mLiterals); + break; + case Operator::IN: + sstream << columnDebugString() << " in " << getLiteralsString(mLiterals); + break; + case Operator::BETWEEN: + sstream << columnDebugString() << " between " << getLiteralsString(mLiterals); + break; + default: + sstream << "unknown operator, column: " + << columnDebugString() << ", literals: " + << getLiteralsString(mLiterals); + } + sstream << ')'; + return sstream.str(); + } + + size_t PredicateLeaf::hashCode() const { + size_t value = 0; + std::for_each(mLiterals.cbegin(), mLiterals.cend(), + [&](const Literal& literal) { + value = value * 17 + literal.getHashCode(); + }); + auto colHash = mHasColumnName ? + std::hash<std::string>{}(mColumnName) : + std::hash<uint64_t>{}(mColumnId); + return value * 103 * 101 * 3 * 17 + + std::hash<int>{}(static_cast<int>(mOperator)) + + std::hash<int>{}(static_cast<int>(mType)) * 17 + + colHash * 3 * 17; + } + + bool PredicateLeaf::operator==(const PredicateLeaf& r) const { + if (this == &r) { + return true; + } + if (mHashCode != r.mHashCode || mType != r.mType || mOperator != r.mOperator || + mHasColumnName != r.mHasColumnName || mColumnName != r.mColumnName || + mColumnId != r.mColumnId || mLiterals.size() != r.mLiterals.size()) { + return false; + } + for (size_t i = 0; i != mLiterals.size(); ++i) { + if (mLiterals[i] != r.mLiterals[i]) { + return false; + } + } + return true; + } + + // enum to mark the position of predicate in the range + enum class Location { + BEFORE, MIN, MIDDLE, MAX, AFTER + }; + + DIAGNOSTIC_PUSH + DIAGNOSTIC_IGNORE("-Wfloat-equal") + + /** + * Given a point and min and max, determine if the point is before, at the + * min, in the middle, at the max, or after the range. + * @param point the point to test + * @param min the minimum point + * @param max the maximum point + * @return the location of the point + */ + template <typename T> + Location compareToRange(const T& point, const T& min, const T& max) { + if (point < min) { + return Location::BEFORE; + } else if (point == min) { + return Location::MIN; + } + + if (point > max) { + return Location::AFTER; + } else if (point == max) { + return Location::MAX; + } + + return Location::MIDDLE; + } + + /** + * Evaluate a predicate leaf according to min/max values + * @param op operator of the predicate + * @param values the value to test + * @param minValue the minimum value + * @param maxValue the maximum value + * @param hasNull whether the statistics contain null + * @return the TruthValue result of the test + */ + template <typename T> + TruthValue evaluatePredicateRange(const PredicateLeaf::Operator op, + const std::vector<T>& values, + const T& minValue, + const T& maxValue, + bool hasNull) { + Location loc; + switch (op) { + case PredicateLeaf::Operator::NULL_SAFE_EQUALS: + loc = compareToRange(values.at(0), minValue, maxValue); + if (loc == Location::BEFORE || loc == Location::AFTER) { + return TruthValue::NO; + } else { + return TruthValue::YES_NO; + } + case PredicateLeaf::Operator::EQUALS: + loc = compareToRange(values.at(0), minValue, maxValue); + if (minValue == maxValue && loc == Location::MIN) { + return hasNull ? TruthValue::YES_NULL : TruthValue::YES; + } else if (loc == Location::BEFORE || loc == Location::AFTER) { + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } else { + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + case PredicateLeaf::Operator::LESS_THAN: + loc = compareToRange(values.at(0), minValue, maxValue); + if (loc == Location::AFTER) { + return hasNull ? TruthValue::YES_NULL : TruthValue::YES; + } else if (loc == Location::BEFORE || loc == Location::MIN) { + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } else { + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + case PredicateLeaf::Operator::LESS_THAN_EQUALS: + loc = compareToRange(values.at(0), minValue, maxValue); + if (loc == Location::AFTER || loc == Location::MAX || + (loc == Location::MIN && minValue == maxValue)) { + return hasNull ? TruthValue::YES_NULL : TruthValue::YES; + } else if (loc == Location::BEFORE) { + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } else { + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + case PredicateLeaf::Operator::IN: + if (minValue == maxValue) { + // for a single value, look through to see if that value is in the set + for (auto& value : values) { + loc = compareToRange(value, minValue, maxValue); + if (loc == Location::MIN) { + return hasNull ? TruthValue::YES_NULL : TruthValue::YES; + } + } + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } else { + // are all of the values outside of the range? + for (auto& value : values) { + loc = compareToRange(value, minValue, maxValue); + if (loc == Location::MIN || loc == Location::MIDDLE || + loc == Location::MAX) { + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + } + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } + case PredicateLeaf::Operator::BETWEEN: + if (values.empty()) { + return TruthValue::YES_NO; + } + loc = compareToRange(values.at(0), minValue, maxValue); + if (loc == Location::BEFORE || loc == Location::MIN) { + Location loc2 = compareToRange(values.at(1), minValue, maxValue); + if (loc2 == Location::AFTER || loc2 == Location::MAX) { + return hasNull ? TruthValue::YES_NULL : TruthValue::YES; + } else if (loc2 == Location::BEFORE) { + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } else { + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + } else if (loc == Location::AFTER) { + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } else { + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + case PredicateLeaf::Operator::IS_NULL: + // min = null condition above handles the all-nulls YES case + return hasNull ? TruthValue::YES_NO : TruthValue::NO; + default: + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + } + + DIAGNOSTIC_POP + + static TruthValue evaluateBoolPredicate( + const PredicateLeaf::Operator op, + const std::vector<Literal>& literals, + const proto::ColumnStatistics& stats) { + bool hasNull = stats.hasnull(); + if (!stats.has_bucketstatistics() || + stats.bucketstatistics().count_size() == 0) { + // does not have bool stats + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + + auto trueCount = stats.bucketstatistics().count(0); + auto falseCount = stats.numberofvalues() - trueCount; + switch (op) { + case PredicateLeaf::Operator::IS_NULL: + return hasNull ? TruthValue::YES_NO : TruthValue::NO; + case PredicateLeaf::Operator::NULL_SAFE_EQUALS: { + if (literals.at(0).getBool()) { + if (trueCount == 0) { + return TruthValue::NO; + } else if (falseCount == 0) { + return TruthValue::YES; + } + } else { + if (falseCount == 0) { + return TruthValue::NO; + } else if (trueCount == 0) { + return TruthValue::YES; + } + } + return TruthValue::YES_NO; + } + case PredicateLeaf::Operator::EQUALS: { + if (literals.at(0).getBool()) { + if (trueCount == 0) { + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } else if (falseCount == 0) { + return hasNull ? TruthValue::YES_NULL : TruthValue::YES; + } + } else { + if (falseCount == 0) { + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + } else if (trueCount == 0) { + return hasNull ? TruthValue::YES_NULL : TruthValue::YES; + } + } + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + case PredicateLeaf::Operator::LESS_THAN: + case PredicateLeaf::Operator::LESS_THAN_EQUALS: + case PredicateLeaf::Operator::IN: + case PredicateLeaf::Operator::BETWEEN: + default: + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + } + + static std::vector<google::protobuf::int64> literal2Long(const std::vector<Literal>& values) { + std::vector<google::protobuf::int64> result; + std::for_each(values.cbegin(), values.cend(), [&](const Literal& val) { + if (!val.isNull()) { + result.emplace_back(val.getLong()); + } + }); + return result; + } + + static std::vector<int32_t> literal2Date(const std::vector<Literal>& values) { + std::vector<int32_t> result; + std::for_each(values.cbegin(), values.cend(), [&](const Literal& val) { + if (!val.isNull()) { + result.emplace_back(val.getDate()); + } + }); + return result; + } + + static std::vector<Literal::Timestamp> literal2Timestamp( + const std::vector<Literal>& values) { + std::vector<Literal::Timestamp> result; + std::for_each(values.cbegin(), values.cend(), [&](const Literal& val) { + if (!val.isNull()) { + result.emplace_back(val.getTimestamp()); + } + }); + return result; + } + + static std::vector<Decimal> literal2Decimal( + const std::vector<Literal>& values) { + std::vector<Decimal> result; + std::for_each(values.cbegin(), values.cend(), [&](const Literal& val) { + if (!val.isNull()) { + result.emplace_back(val.getDecimal()); + } + }); + return result; + } + + static std::vector<double> literal2Double( + const std::vector<Literal>& values) { + std::vector<double> result; + std::for_each(values.cbegin(), values.cend(), [&](const Literal& val) { + if (!val.isNull()) { + result.emplace_back(val.getFloat()); + } + }); + return result; + } + + static std::vector<TString> literal2String( + const std::vector<Literal>& values) { + std::vector<TString> result; + std::for_each(values.cbegin(), values.cend(), [&](const Literal& val) { + if (!val.isNull()) { + result.emplace_back(TString(val.getString())); + } + }); + return result; + } + + TruthValue PredicateLeaf::evaluatePredicateMinMax( + const proto::ColumnStatistics& colStats) const { + TruthValue result = TruthValue::YES_NO_NULL; + switch (mType) { + case PredicateDataType::LONG: { + if (colStats.has_intstatistics() && + colStats.intstatistics().has_minimum() && + colStats.intstatistics().has_maximum()) { + const auto& stats = colStats.intstatistics(); + result = evaluatePredicateRange( + mOperator, + literal2Long(mLiterals), + stats.minimum(), + stats.maximum(), + colStats.hasnull()); + } + break; + } + case PredicateDataType::FLOAT: { + if (colStats.has_doublestatistics() && + colStats.doublestatistics().has_minimum() && + colStats.doublestatistics().has_maximum()) { + const auto& stats = colStats.doublestatistics(); + if (!std::isfinite(stats.sum())) { + result = colStats.hasnull() ? + TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } else { + result = evaluatePredicateRange( + mOperator, + literal2Double(mLiterals), + stats.minimum(), + stats.maximum(), + colStats.hasnull()); + } + } + break; + } + case PredicateDataType::STRING: { + ///TODO: check lowerBound and upperBound as well + if (colStats.has_stringstatistics() && + colStats.stringstatistics().has_minimum() && + colStats.stringstatistics().has_maximum()) { + const auto& stats = colStats.stringstatistics(); + result = evaluatePredicateRange( + mOperator, + literal2String(mLiterals), + stats.minimum(), + stats.maximum(), + colStats.hasnull()); + } + break; + } + case PredicateDataType::DATE: { + if (colStats.has_datestatistics() && + colStats.datestatistics().has_minimum() && + colStats.datestatistics().has_maximum()) { + const auto& stats = colStats.datestatistics(); + result = evaluatePredicateRange( + mOperator, + literal2Date(mLiterals), + stats.minimum(), + stats.maximum(), + colStats.hasnull()); + } + break; + } + case PredicateDataType::TIMESTAMP: { + if (colStats.has_timestampstatistics() && + colStats.timestampstatistics().has_minimumutc() && + colStats.timestampstatistics().has_maximumutc()) { + const auto& stats = colStats.timestampstatistics(); + constexpr int32_t DEFAULT_MIN_NANOS = 0; + constexpr int32_t DEFAULT_MAX_NANOS = 999999; + int32_t minNano = stats.has_minimumnanos() ? + stats.minimumnanos() - 1 : DEFAULT_MIN_NANOS; + int32_t maxNano = stats.has_maximumnanos() ? + stats.maximumnanos() - 1 : DEFAULT_MAX_NANOS; + Literal::Timestamp minTimestamp( + stats.minimumutc() / 1000, + static_cast<int32_t>((stats.minimumutc() % 1000) * 1000000) + minNano); + Literal::Timestamp maxTimestamp( + stats.maximumutc() / 1000, + static_cast<int32_t>((stats.maximumutc() % 1000) * 1000000) + maxNano); + result = evaluatePredicateRange( + mOperator, + literal2Timestamp(mLiterals), + minTimestamp, + maxTimestamp, + colStats.hasnull()); + } + break; + } + case PredicateDataType::DECIMAL: { + if (colStats.has_decimalstatistics() && + colStats.decimalstatistics().has_minimum() && + colStats.decimalstatistics().has_maximum()) { + const auto& stats = colStats.decimalstatistics(); + result = evaluatePredicateRange( + mOperator, + literal2Decimal(mLiterals), + Decimal(stats.minimum()), + Decimal(stats.maximum()), + colStats.hasnull()); + } + break; + } + case PredicateDataType::BOOLEAN: { + if (colStats.has_bucketstatistics()) { + result = evaluateBoolPredicate(mOperator, mLiterals, colStats); + } + break; + } + default: + break; + } + + // make sure null literal is respected for IN operator + if (mOperator == Operator::IN && colStats.hasnull()) { + for (const auto& literal : mLiterals) { + if (literal.isNull()) { + result = TruthValue::YES_NO_NULL; + break; + } + } + } + + return result; + } + + static bool shouldEvaluateBloomFilter(PredicateLeaf::Operator op, + TruthValue result, + const BloomFilter * bloomFilter) { + // evaluate bloom filter only when + // 1) Bloom filter is available + // 2) Min/Max evaluation yield YES or MAYBE + // 3) Predicate is EQUALS or IN list + // 4) Decimal type stores its string representation + // but has inconsistency in trailing zeros + if (bloomFilter != nullptr + && result != TruthValue::NO_NULL && result != TruthValue::NO + && (op == PredicateLeaf::Operator::EQUALS + || op == PredicateLeaf::Operator::NULL_SAFE_EQUALS + || op == PredicateLeaf::Operator::IN)) { + return true; + } + return false; + } + + static TruthValue checkInBloomFilter(PredicateLeaf::Operator, + PredicateDataType type, + const Literal& literal, + const BloomFilter * bf, + bool hasNull) { + TruthValue result = hasNull ? TruthValue::NO_NULL : TruthValue::NO; + if (literal.isNull()) { + result = hasNull ? TruthValue::YES_NO_NULL : TruthValue::NO; + } else if (type == PredicateDataType::LONG) { + if (bf->testLong(literal.getLong())) { + result = TruthValue::YES_NO_NULL; + } + } else if (type == PredicateDataType::FLOAT) { + if (bf->testDouble(literal.getFloat())) { + result = TruthValue::YES_NO_NULL; + } + } else if (type == PredicateDataType::STRING) { + std::string str = literal.getString(); + if (bf->testBytes(str.c_str(), static_cast<int64_t>(str.size()))) { + result = TruthValue::YES_NO_NULL; + } + } else if (type == PredicateDataType::DECIMAL) { + std::string decimal = literal.getDecimal().toString(true); + if (bf->testBytes(decimal.c_str(), static_cast<int64_t>(decimal.size()))) { + result = TruthValue::YES_NO_NULL; + } + } else if (type == PredicateDataType::TIMESTAMP) { + if (bf->testLong(literal.getTimestamp().getMillis())) { + result = TruthValue::YES_NO_NULL; + } + } else if (type == PredicateDataType::DATE) { + if (bf->testLong(literal.getDate())) { + result = TruthValue::YES_NO_NULL; + } + } else { + result = TruthValue::YES_NO_NULL; + } + + if (result == TruthValue::YES_NO_NULL && !hasNull) { + result = TruthValue::YES_NO; + } + + return result; + } + + TruthValue PredicateLeaf::evaluatePredicateBloomFiter(const BloomFilter * bf, + bool hasNull) const { + switch (mOperator) { + case Operator::NULL_SAFE_EQUALS: + // null safe equals does not return *_NULL variant. + // So set hasNull to false + return checkInBloomFilter( + mOperator, mType, mLiterals.front(), bf, false); + case Operator::EQUALS: + return checkInBloomFilter( + mOperator, mType, mLiterals.front(), bf, hasNull); + case Operator::IN: + for (const auto &literal : mLiterals) { + // if at least one value in IN list exist in bloom filter, + // qualify the row group/stripe + TruthValue result = checkInBloomFilter( + mOperator, mType, literal, bf, hasNull); + if (result == TruthValue::YES_NO_NULL || + result == TruthValue::YES_NO) { + return result; + } + } + return hasNull ? TruthValue::NO_NULL : TruthValue::NO; + case Operator::LESS_THAN: + case Operator::LESS_THAN_EQUALS: + case Operator::BETWEEN: + case Operator::IS_NULL: + default: + return hasNull ? TruthValue::YES_NO_NULL : TruthValue::YES_NO; + } + } + + TruthValue PredicateLeaf::evaluate(const WriterVersion writerVersion, + const proto::ColumnStatistics& colStats, + const BloomFilter * bloomFilter) const { + // files written before ORC-135 stores timestamp wrt to local timezone + // causing issues with PPD. disable PPD for timestamp for all old files + if (mType == PredicateDataType::TIMESTAMP) { + if (writerVersion < WriterVersion::WriterVersion_ORC_135) { + return TruthValue::YES_NO_NULL; + } + } + + bool allNull = colStats.hasnull() && colStats.numberofvalues() == 0; + if (mOperator == Operator::IS_NULL || (( + mOperator == Operator::EQUALS || + mOperator == Operator::NULL_SAFE_EQUALS) && + mLiterals.at(0).isNull())) { + // IS_NULL operator does not need to check min/max stats and bloom filter + return allNull ? TruthValue::YES : + (colStats.hasnull() ? TruthValue::YES_NO : TruthValue::NO); + } else if (allNull) { + // if we don't have any value, everything must have been null + return TruthValue::IS_NULL; + } + + TruthValue result = evaluatePredicateMinMax(colStats); + if (shouldEvaluateBloomFilter(mOperator, result, bloomFilter)) { + return evaluatePredicateBloomFiter(bloomFilter, colStats.hasnull()); + } else { + return result; + } + } + +} // namespace orc diff --git a/contrib/libs/apache/orc/c++/src/sargs/PredicateLeaf.hh b/contrib/libs/apache/orc/c++/src/sargs/PredicateLeaf.hh new file mode 100644 index 00000000000..99791cf976e --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/PredicateLeaf.hh @@ -0,0 +1,185 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_PREDICATELEAF_HH +#define ORC_PREDICATELEAF_HH + +#include "wrap/orc-proto-wrapper.hh" +#include "orc/Common.hh" +#include "orc/sargs/Literal.hh" +#include "orc/sargs/TruthValue.hh" + +#include <string> +#include <vector> + +namespace orc { + + static constexpr uint64_t INVALID_COLUMN_ID = + std::numeric_limits<uint64_t>::max(); + + class BloomFilter; + + /** + * The primitive predicates that form a SearchArgument. + */ + class PredicateLeaf { + public: + /** + * The possible operators for predicates. To get the opposites, construct + * an expression with a not operator. + */ + enum class Operator { + EQUALS = 0, + NULL_SAFE_EQUALS, + LESS_THAN, + LESS_THAN_EQUALS, + IN, + BETWEEN, + IS_NULL + }; + + // The possible types for sargs. + enum class Type { + LONG = 0, // all of the integer types + FLOAT, // float and double + STRING, // string, char, varchar + DATE, + DECIMAL, + TIMESTAMP, + BOOLEAN + }; + + PredicateLeaf() = default; + + PredicateLeaf(Operator op, + PredicateDataType type, + const std::string& colName, + Literal literal); + + PredicateLeaf(Operator op, + PredicateDataType type, + uint64_t columnId, + Literal literal); + + PredicateLeaf(Operator op, + PredicateDataType type, + const std::string& colName, + const std::initializer_list<Literal>& literalList); + + PredicateLeaf(Operator op, + PredicateDataType type, + uint64_t columnId, + const std::initializer_list<Literal>& literalList); + + PredicateLeaf(Operator op, + PredicateDataType type, + const std::string& colName, + const std::vector<Literal>& literalList); + + PredicateLeaf(Operator op, + PredicateDataType type, + uint64_t columnId, + const std::vector<Literal>& literalList); + + /** + * Get the operator for the leaf. + */ + Operator getOperator() const; + + /** + * Get the type of the column and literal by the file format. + */ + PredicateDataType getType() const; + + /** + * Get whether the predicate is created using column name. + */ + bool hasColumnName() const; + + /** + * Get the simple column name. + */ + const std::string& getColumnName() const; + + /** + * Get the column id. + */ + uint64_t getColumnId() const; + + /** + * Get the literal half of the predicate leaf. + */ + Literal getLiteral() const; + + /** + * For operators with multiple literals (IN and BETWEEN), get the literals. + */ + const std::vector<Literal>& getLiteralList() const; + + /** + * Evaluate current PredicateLeaf based on ColumnStatistics and BloomFilter + */ + TruthValue evaluate(const WriterVersion writerVersion, + const proto::ColumnStatistics& colStats, + const BloomFilter * bloomFilter) const; + + std::string toString() const; + + bool operator==(const PredicateLeaf& r) const; + + size_t getHashCode() const { return mHashCode; } + + private: + size_t hashCode() const; + + void validate() const; + void validateColumn() const; + + std::string columnDebugString() const; + + TruthValue evaluatePredicateMinMax( + const proto::ColumnStatistics& colStats) const; + + TruthValue evaluatePredicateBloomFiter(const BloomFilter * bloomFilter, + bool hasNull) const; + + private: + Operator mOperator; + PredicateDataType mType; + std::string mColumnName; + bool mHasColumnName; + uint64_t mColumnId; + std::vector<Literal> mLiterals; + size_t mHashCode; + }; + + struct PredicateLeafHash { + size_t operator()(const PredicateLeaf& leaf) const { + return leaf.getHashCode(); + } + }; + + struct PredicateLeafComparator { + bool operator()(const PredicateLeaf& lhs, const PredicateLeaf& rhs) const { + return lhs == rhs; + } + }; + +} // namespace orc + +#endif //ORC_PREDICATELEAF_HH diff --git a/contrib/libs/apache/orc/c++/src/sargs/SargsApplier.cc b/contrib/libs/apache/orc/c++/src/sargs/SargsApplier.cc new file mode 100644 index 00000000000..42a554f5cab --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/SargsApplier.cc @@ -0,0 +1,186 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "SargsApplier.hh" +#include <numeric> + +namespace orc { + + // find column id from column name + uint64_t SargsApplier::findColumn(const Type& type, + const std::string& colName) { + for (uint64_t i = 0; i != type.getSubtypeCount(); ++i) { + // Only STRUCT type has field names + if (type.getKind() == STRUCT && type.getFieldName(i) == colName) { + return type.getSubtype(i)->getColumnId(); + } else { + uint64_t ret = findColumn(*type.getSubtype(i), colName); + if (ret != INVALID_COLUMN_ID) { + return ret; + } + } + } + return INVALID_COLUMN_ID; + } + + SargsApplier::SargsApplier(const Type& type, + const SearchArgument * searchArgument, + uint64_t rowIndexStride, + WriterVersion writerVersion) + : mType(type) + , mSearchArgument(searchArgument) + , mRowIndexStride(rowIndexStride) + , mWriterVersion(writerVersion) + , mStats(0, 0) + , mHasEvaluatedFileStats(false) + , mFileStatsEvalResult(true) { + const SearchArgumentImpl * sargs = + dynamic_cast<const SearchArgumentImpl *>(mSearchArgument); + + // find the mapping from predicate leaves to columns + const std::vector<PredicateLeaf>& leaves = sargs->getLeaves(); + mFilterColumns.resize(leaves.size(), INVALID_COLUMN_ID); + for (size_t i = 0; i != mFilterColumns.size(); ++i) { + if (leaves[i].hasColumnName()) { + mFilterColumns[i] = findColumn(type, leaves[i].getColumnName()); + } else { + mFilterColumns[i] = leaves[i].getColumnId(); + } + } + } + + bool SargsApplier::pickRowGroups( + uint64_t rowsInStripe, + const std::unordered_map<uint64_t, proto::RowIndex>& rowIndexes, + const std::map<uint32_t, BloomFilterIndex>& bloomFilters) { + // init state of each row group + uint64_t groupsInStripe = + (rowsInStripe + mRowIndexStride - 1) / mRowIndexStride; + mNextSkippedRows.resize(groupsInStripe); + mTotalRowsInStripe = rowsInStripe; + + // row indexes do not exist, simply read all rows + if (rowIndexes.empty()) { + return true; + } + + const auto& leaves = + dynamic_cast<const SearchArgumentImpl *>(mSearchArgument)->getLeaves(); + std::vector<TruthValue> leafValues( + leaves.size(), TruthValue::YES_NO_NULL); + mHasSelected = false; + mHasSkipped = false; + uint64_t nextSkippedRowGroup = groupsInStripe; + size_t rowGroup = groupsInStripe; + do { + --rowGroup; + for (size_t pred = 0; pred != leaves.size(); ++pred) { + uint64_t columnIdx = mFilterColumns[pred]; + auto rowIndexIter = rowIndexes.find(columnIdx); + if (columnIdx == INVALID_COLUMN_ID || rowIndexIter == rowIndexes.cend()) { + // this column does not exist in current file + leafValues[pred] = TruthValue::YES_NO_NULL; + } else { + // get column statistics + const proto::ColumnStatistics& statistics = + rowIndexIter->second.entry(static_cast<int>(rowGroup)).statistics(); + + // get bloom filter + std::shared_ptr<BloomFilter> bloomFilter; + auto iter = bloomFilters.find(static_cast<uint32_t>(columnIdx)); + if (iter != bloomFilters.cend()) { + bloomFilter = iter->second.entries.at(rowGroup); + } + + leafValues[pred] = leaves[pred].evaluate(mWriterVersion, + statistics, + bloomFilter.get()); + } + } + + bool needed = isNeeded(mSearchArgument->evaluate(leafValues)); + if (!needed) { + mNextSkippedRows[rowGroup] = 0; + nextSkippedRowGroup = rowGroup; + } else { + mNextSkippedRows[rowGroup] = (nextSkippedRowGroup == groupsInStripe) ? + rowsInStripe : (nextSkippedRowGroup * mRowIndexStride); + } + mHasSelected |= needed; + mHasSkipped |= !needed; + } while (rowGroup != 0); + + // update stats + mStats.first = std::accumulate( + mNextSkippedRows.cbegin(), mNextSkippedRows.cend(), mStats.first, + [](bool rg, uint64_t s) { return rg ? 1 : 0 + s; }); + mStats.second += groupsInStripe; + + return mHasSelected; + } + + bool SargsApplier::evaluateColumnStatistics( + const PbColumnStatistics& colStats) const { + const SearchArgumentImpl * sargs = + dynamic_cast<const SearchArgumentImpl *>(mSearchArgument); + if (sargs == nullptr) { + throw InvalidArgument("Failed to cast to SearchArgumentImpl"); + } + + const std::vector<PredicateLeaf>& leaves = sargs->getLeaves(); + std::vector<TruthValue> leafValues( + leaves.size(), TruthValue::YES_NO_NULL); + + for (size_t pred = 0; pred != leaves.size(); ++pred) { + uint64_t columnId = mFilterColumns[pred]; + if (columnId != INVALID_COLUMN_ID && + colStats.size() > static_cast<int>(columnId)) { + leafValues[pred] = leaves[pred].evaluate( + mWriterVersion, colStats.Get(static_cast<int>(columnId)), nullptr); + } + } + + return isNeeded(mSearchArgument->evaluate(leafValues)); + } + + bool SargsApplier::evaluateStripeStatistics( + const proto::StripeStatistics& stripeStats) { + if (stripeStats.colstats_size() == 0) { + return true; + } + + bool ret = evaluateColumnStatistics(stripeStats.colstats()); + if (!ret) { + // reset mNextSkippedRows when the current stripe does not satisfy the PPD + mNextSkippedRows.clear(); + } + return ret; + } + + bool SargsApplier::evaluateFileStatistics(const proto::Footer& footer) { + if (!mHasEvaluatedFileStats) { + if (footer.statistics_size() == 0) { + mFileStatsEvalResult = true; + } else { + mFileStatsEvalResult = evaluateColumnStatistics(footer.statistics()); + } + mHasEvaluatedFileStats = true; + } + return mFileStatsEvalResult; + } +} diff --git a/contrib/libs/apache/orc/c++/src/sargs/SargsApplier.hh b/contrib/libs/apache/orc/c++/src/sargs/SargsApplier.hh new file mode 100644 index 00000000000..d8bdf852d0b --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/SargsApplier.hh @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_SARGSAPPLIER_HH +#define ORC_SARGSAPPLIER_HH + +#include "wrap/orc-proto-wrapper.hh" +#include <orc/Common.hh> +#include "orc/BloomFilter.hh" +#include "orc/Type.hh" + +#include "sargs/SearchArgument.hh" + +#include <unordered_map> + +namespace orc { + + class SargsApplier { + public: + SargsApplier(const Type& type, + const SearchArgument * searchArgument, + uint64_t rowIndexStride, + WriterVersion writerVersion); + + /** + * Evaluate search argument on file statistics + * @return true if file statistics satisfy the sargs + */ + bool evaluateFileStatistics(const proto::Footer& footer); + + /** + * Evaluate search argument on stripe statistics + * @return true if stripe statistics satisfy the sargs + */ + bool evaluateStripeStatistics(const proto::StripeStatistics& stripeStats); + + /** + * TODO: use proto::RowIndex and proto::BloomFilter to do the evaluation + * Pick the row groups that we need to load from the current stripe. + * @return true if any row group is selected + */ + bool pickRowGroups( + uint64_t rowsInStripe, + const std::unordered_map<uint64_t, proto::RowIndex>& rowIndexes, + const std::map<uint32_t, BloomFilterIndex>& bloomFilters); + + /** + * Return a vector of the next skipped row for each RowGroup. Each value is the row id + * in stripe. 0 means the current RowGroup is entirely skipped. + * Only valid after invoking pickRowGroups(). + */ + const std::vector<uint64_t>& getNextSkippedRows() const { return mNextSkippedRows; } + + /** + * Indicate whether any row group is selected in the last evaluation + */ + bool hasSelected() const { return mHasSelected; } + + /** + * Indicate whether any row group is skipped in the last evaluation + */ + bool hasSkipped() const { return mHasSkipped; } + + /** + * Whether any row group from current row in the stripe matches PPD. + */ + bool hasSelectedFrom(uint64_t currentRowInStripe) const { + uint64_t rg = currentRowInStripe / mRowIndexStride; + for (; rg < mNextSkippedRows.size(); ++rg) { + if (mNextSkippedRows[rg]) { + return true; + } + } + return false; + } + + std::pair<uint64_t, uint64_t> getStats() const { + return mStats; + } + + private: + // evaluate column statistics in the form of protobuf::RepeatedPtrField + typedef ::google::protobuf::RepeatedPtrField<proto::ColumnStatistics> + PbColumnStatistics; + bool evaluateColumnStatistics(const PbColumnStatistics& colStats) const; + + friend class TestSargsApplier_findColumnTest_Test; + friend class TestSargsApplier_findArrayColumnTest_Test; + friend class TestSargsApplier_findMapColumnTest_Test; + static uint64_t findColumn(const Type& type, const std::string& colName); + + private: + const Type& mType; + const SearchArgument * mSearchArgument; + uint64_t mRowIndexStride; + WriterVersion mWriterVersion; + // column ids for each predicate leaf in the search argument + std::vector<uint64_t> mFilterColumns; + + // Map from RowGroup index to the next skipped row of the selected range it + // locates. If the RowGroup is not selected, set the value to 0. + // Calculated in pickRowGroups(). + std::vector<uint64_t> mNextSkippedRows; + uint64_t mTotalRowsInStripe; + bool mHasSelected; + bool mHasSkipped; + // keep stats of selected RGs and evaluated RGs + std::pair<uint64_t, uint64_t> mStats; + // store result of file stats evaluation + bool mHasEvaluatedFileStats; + bool mFileStatsEvalResult; + }; + +} + +#endif //ORC_SARGSAPPLIER_HH diff --git a/contrib/libs/apache/orc/c++/src/sargs/SearchArgument.cc b/contrib/libs/apache/orc/c++/src/sargs/SearchArgument.cc new file mode 100644 index 00000000000..f6abb316b5b --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/SearchArgument.cc @@ -0,0 +1,629 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sargs/SearchArgument.hh" + +#include <algorithm> +#include <functional> +#include <sstream> +#include <unordered_set> + +namespace orc { + + SearchArgument::~SearchArgument() { + // PASS + } + + const std::vector<PredicateLeaf>& SearchArgumentImpl::getLeaves() const { + return mLeaves; + } + + const ExpressionTree * SearchArgumentImpl::getExpression() const { + return mExpressionTree.get(); + } + + TruthValue SearchArgumentImpl::evaluate( + const std::vector<TruthValue>& leaves) const { + return mExpressionTree == nullptr ? + TruthValue::YES : mExpressionTree->evaluate(leaves); + } + + std::string SearchArgumentImpl::toString() const { + std::ostringstream sstream; + for (size_t i = 0; i != mLeaves.size(); ++i) { + sstream << "leaf-" << i << " = " << mLeaves.at(i).toString() << ", "; + } + sstream << "expr = " << mExpressionTree->toString(); + return sstream.str(); + } + + SearchArgumentBuilder::~SearchArgumentBuilder() { + // PASS + } + + SearchArgumentBuilderImpl::SearchArgumentBuilderImpl() { + mRoot.reset(new ExpressionTree(ExpressionTree::Operator::AND)); + mCurrTree.push_back(mRoot); + } + + SearchArgumentBuilder& + SearchArgumentBuilderImpl::start(ExpressionTree::Operator op) { + TreeNode node = std::make_shared<ExpressionTree>(op); + mCurrTree.front()->addChild(node); + mCurrTree.push_front(node); + return *this; + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::startOr() { + return start(ExpressionTree::Operator::OR); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::startAnd() { + return start(ExpressionTree::Operator::AND); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::startNot() { + return start(ExpressionTree::Operator::NOT); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::end() { + TreeNode& current = mCurrTree.front(); + if (current->getChildren().empty()) { + throw std::invalid_argument("Cannot create expression " + + mRoot->toString() + " with no children."); + } + if (current->getOperator() == ExpressionTree::Operator::NOT && + current->getChildren().size() != 1) { + throw std::invalid_argument("Can't create NOT expression " + + current->toString() + " with more than 1 child."); + } + mCurrTree.pop_front(); + return *this; + } + + size_t SearchArgumentBuilderImpl::addLeaf(PredicateLeaf leaf) { + size_t id = mLeaves.size(); + const auto& result = mLeaves.insert(std::make_pair(leaf, id)); + return result.first->second; + } + + bool SearchArgumentBuilderImpl::isInvalidColumn(const std::string& column) { + return column.empty(); + } + + bool SearchArgumentBuilderImpl::isInvalidColumn(uint64_t columnId) { + return columnId == INVALID_COLUMN_ID; + } + + template<typename T> + SearchArgumentBuilder& + SearchArgumentBuilderImpl::compareOperator(PredicateLeaf::Operator op, + T column, + PredicateDataType type, + Literal literal) { + TreeNode parent = mCurrTree.front(); + if (isInvalidColumn(column)) { + parent->addChild( + std::make_shared<ExpressionTree>(TruthValue::YES_NO_NULL)); + } else { + PredicateLeaf leaf(op, type, column, literal); + parent->addChild(std::make_shared<ExpressionTree>(addLeaf(leaf))); + } + return *this; + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::lessThan(const std::string& column, + PredicateDataType type, + Literal literal) { + return compareOperator( + PredicateLeaf::Operator::LESS_THAN, column, type, literal); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::lessThan(uint64_t columnId, + PredicateDataType type, + Literal literal) { + return compareOperator( + PredicateLeaf::Operator::LESS_THAN, columnId, type, literal); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::lessThanEquals(const std::string& column, + PredicateDataType type, + Literal literal) { + return compareOperator( + PredicateLeaf::Operator::LESS_THAN_EQUALS, column, type, literal); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::lessThanEquals(uint64_t columnId, + PredicateDataType type, + Literal literal) { + return compareOperator( + PredicateLeaf::Operator::LESS_THAN_EQUALS, columnId, type, literal); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::equals(const std::string& column, + PredicateDataType type, + Literal literal) { + if (literal.isNull()) { + return isNull(column, type); + } else { + return compareOperator( + PredicateLeaf::Operator::EQUALS, column, type, literal); + } + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::equals(uint64_t columnId, + PredicateDataType type, + Literal literal) { + if (literal.isNull()) { + return isNull(columnId, type); + } else { + return compareOperator( + PredicateLeaf::Operator::EQUALS, columnId, type, literal); + } + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::nullSafeEquals(const std::string& column, + PredicateDataType type, + Literal literal) { + return compareOperator( + PredicateLeaf::Operator::NULL_SAFE_EQUALS, column, type, literal); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::nullSafeEquals(uint64_t columnId, + PredicateDataType type, + Literal literal) { + return compareOperator( + PredicateLeaf::Operator::NULL_SAFE_EQUALS, columnId, type, literal); + } + + template<typename T, typename CONTAINER> + SearchArgumentBuilder& SearchArgumentBuilderImpl::addChildForIn(T column, + PredicateDataType type, + const CONTAINER& literals) { + TreeNode &parent = mCurrTree.front(); + if (isInvalidColumn(column)) { + parent->addChild( + std::make_shared<ExpressionTree>((TruthValue::YES_NO_NULL))); + } else { + if (literals.size() == 0) { + throw std::invalid_argument( + "Can't create in expression with no arguments"); + } + PredicateLeaf leaf( + PredicateLeaf::Operator::IN, type, column, literals); + parent->addChild(std::make_shared<ExpressionTree>(addLeaf(leaf))); + } + return *this; + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::in(const std::string& column, + PredicateDataType type, + const std::initializer_list<Literal>& literals) { + return addChildForIn(column, type, literals); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::in(uint64_t columnId, + PredicateDataType type, + const std::initializer_list<Literal>& literals) { + return addChildForIn(columnId, type, literals); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::in(const std::string& column, + PredicateDataType type, + const std::vector<Literal>& literals) { + return addChildForIn(column, type, literals); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::in(uint64_t columnId, + PredicateDataType type, + const std::vector<Literal>& literals) { + return addChildForIn(columnId, type, literals); + } + + template<typename T> + SearchArgumentBuilder& SearchArgumentBuilderImpl::addChildForIsNull(T column, PredicateDataType type) { + TreeNode& parent = mCurrTree.front(); + if (isInvalidColumn(column)) { + parent->addChild( + std::make_shared<ExpressionTree>(TruthValue::YES_NO_NULL)); + } else { + PredicateLeaf leaf(PredicateLeaf::Operator::IS_NULL, + type, + column, + {}); + parent->addChild(std::make_shared<ExpressionTree>(addLeaf(leaf))); + } + return *this; + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::isNull(const std::string& column, + PredicateDataType type) { + return addChildForIsNull(column, type); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::isNull(uint64_t columnId, + PredicateDataType type) { + return addChildForIsNull(columnId, type); + } + + template<typename T> + SearchArgumentBuilder& SearchArgumentBuilderImpl::addChildForBetween(T column, + PredicateDataType type, + Literal lower, Literal upper) { + TreeNode& parent = mCurrTree.front(); + if (isInvalidColumn(column)) { + parent->addChild( + std::make_shared<ExpressionTree>(TruthValue::YES_NO_NULL)); + } else { + PredicateLeaf leaf(PredicateLeaf::Operator::BETWEEN, + type, + column, + { lower, upper }); + parent->addChild(std::make_shared<ExpressionTree>(addLeaf(leaf))); + } + return *this; + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::between(const std::string& column, + PredicateDataType type, + Literal lower, + Literal upper) { + return addChildForBetween(column, type, lower, upper); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::between(uint64_t columnId, + PredicateDataType type, + Literal lower, + Literal upper) { + return addChildForBetween(columnId, type, lower, upper); + } + + SearchArgumentBuilder& SearchArgumentBuilderImpl::literal(TruthValue truth) { + TreeNode& parent = mCurrTree.front(); + parent->addChild(std::make_shared<ExpressionTree>(truth)); + return *this; + } + + /** + * Recursively explore the tree to find the leaves that are still reachable + * after optimizations. + * @param tree the node to check next + * @param next the next available leaf id + * @param leafReorder buffer for leaf reorder + * @return the next available leaf id + */ + static size_t compactLeaves(const TreeNode& tree, + size_t next, + size_t leafReorder[]) { + if (tree->getOperator() == ExpressionTree::Operator::LEAF) { + size_t oldLeaf = tree->getLeaf(); + if (leafReorder[oldLeaf] == UNUSED_LEAF) { + leafReorder[oldLeaf] = next++; + } + } else { + for (const TreeNode& child : tree->getChildren()) { + next = compactLeaves(child, next, leafReorder); + } + } + return next; + } + + /** + * Rewrite expression tree to update the leaves. + * @param root the root of the tree to fix + * @param leafReorder a map from old leaf ids to new leaf ids + * @return the fixed root + */ + static TreeNode rewriteLeaves(TreeNode root, size_t leafReorder[]) { + // The leaves could be shared in the tree. Use Set to remove the duplicates. + std::unordered_set<TreeNode> leaves; + std::deque<TreeNode> nodes; + nodes.push_back(root); + + // Perform BFS + while (!nodes.empty()) { + TreeNode& node = nodes.front(); + nodes.pop_front(); + + if (node->getOperator() == ExpressionTree::Operator::LEAF) { + leaves.insert(node); + } else { + for (auto& child : node->getChildren()) { + nodes.push_back(child); + } + } + } + + // Update the leaf in place + for (auto& leaf : leaves) { + leaf->setLeaf(leafReorder[leaf->getLeaf()]); + } + + return root; + } + + /** + * Push the negations all the way to just before the leaves. Also remove + * double negatives. + * + * @param root the expression to normalize + * @return the normalized expression, which may share some or all of the + * nodes of the original expression. + */ + TreeNode SearchArgumentBuilderImpl::pushDownNot(TreeNode root) { + if (root->getOperator() == ExpressionTree::Operator::NOT) { + TreeNode child = root->getChild(0); + switch (child->getOperator()) { + case ExpressionTree::Operator::NOT: { + return pushDownNot(child->getChild(0)); + } + case ExpressionTree::Operator::CONSTANT: { + return std::make_shared<ExpressionTree>(!child->getConstant()); + } + case ExpressionTree::Operator::AND: { + TreeNode result(new ExpressionTree(ExpressionTree::Operator::OR)); + for (auto& kid : child->getChildren()) { + result->addChild(pushDownNot(std::make_shared<ExpressionTree>( + ExpressionTree::Operator::NOT, NodeList{ kid }) + )); + } + return result; + } + case ExpressionTree::Operator::OR: { + TreeNode result(new ExpressionTree(ExpressionTree::Operator::AND)); + for (auto& kid : child->getChildren()) { + result->addChild(pushDownNot(std::make_shared<ExpressionTree>( + ExpressionTree::Operator::NOT, NodeList{ kid }) + )); + } + return result; + } + // for leaf, we don't do anything + case ExpressionTree::Operator::LEAF: + default: + break; + } + } else { + // iterate through children and push down not for each one + for (size_t i = 0; i != root->getChildren().size(); ++i) { + root->getChildren()[i] = pushDownNot(root->getChild(i)); + } + } + return root; + } + + /** + * Remove MAYBE values from the expression. If they are in an AND operator, + * they are dropped. If they are in an OR operator, they kill their parent. + * This assumes that pushDownNot has already been called. + * + * @param expr The expression to clean up + * @return The cleaned up expression + */ + TreeNode SearchArgumentBuilderImpl::foldMaybe(TreeNode expr) { + if (expr) { + for (size_t i = 0; i != expr->getChildren().size(); ++i) { + TreeNode child = foldMaybe(expr->getChild(i)); + if (child->getOperator() == ExpressionTree::Operator::CONSTANT && + child->getConstant() == TruthValue::YES_NO_NULL) { + switch (expr->getOperator()) { + case ExpressionTree::Operator::AND: + expr->getChildren()[i] = nullptr; + break; + case ExpressionTree::Operator::OR: + // a maybe will kill the or condition + return child; + case ExpressionTree::Operator::NOT: + case ExpressionTree::Operator::LEAF: + case ExpressionTree::Operator::CONSTANT: + default: + throw std::invalid_argument( + "Got a maybe as child of " + expr->toString()); + } + } else { + expr->getChildren()[i] = child; + } + } + + auto& children = expr->getChildren(); + if (!children.empty()) { + // eliminate removed maybe nodes from expr + std::vector<TreeNode> nodes; + std::for_each(children.begin(), children.end(), + [&](const TreeNode& node){ if (node) nodes.emplace_back(node); }); + std::swap(children, nodes); + if (children.empty()) { + return std::make_shared<ExpressionTree>(TruthValue::YES_NO_NULL); + } + } + } + return expr; + } + + /** + * Converts multi-level ands and ors into single level ones. + * + * @param root the expression to flatten + * @return the flattened expression, which will always be root with + * potentially modified children. + */ + TreeNode SearchArgumentBuilderImpl::flatten(TreeNode root) { + if (root) { + std::vector<TreeNode> nodes; + for (size_t i = 0; i != root->getChildren().size(); ++i) { + TreeNode child = flatten(root->getChild(i)); + // do we need to flatten? + if (child->getOperator() == root->getOperator() && + child->getOperator() != ExpressionTree::Operator::NOT) { + for (auto& grandkid : child->getChildren()) { + nodes.emplace_back(grandkid); + } + } else { + nodes.emplace_back(child); + } + } + std::swap(root->getChildren(), nodes); + + // if we have a single AND or OR, just return the child + if ((root->getOperator() == ExpressionTree::Operator::OR || + root->getOperator() == ExpressionTree::Operator::AND) && + root->getChildren().size() == 1) { + return root->getChild(0); + } + } + return root; + } + + /** + * Generate all combinations of items on the andList. For each item on the + * andList, it generates all combinations of one child from each and + * expression. Thus, (and a b) (and c d) will be expanded to: (or a c) + * (or a d) (or b c) (or b d). If there are items on the nonAndList, they + * are added to each or expression. + * @param result a list to put the results onto + * @param andList a list of and expressions + * @param nonAndList a list of non-and expressions + */ + static void generateAllCombinations(std::vector<TreeNode>& result, + const std::vector<TreeNode>& andList, + const std::vector<TreeNode>& nonAndList) { + std::vector<TreeNode>& kids = andList.front()->getChildren(); + if (result.empty()) { + for (TreeNode& kid : kids) { + TreeNode root(new ExpressionTree(ExpressionTree::Operator::OR)); + result.emplace_back(root); + for (const TreeNode& node : nonAndList) { + root->addChild(std::make_shared<ExpressionTree>(*node)); + } + root->addChild(kid); + } + } else { + std::vector<TreeNode> work(result.begin(), result.end()); + result.clear(); + for (TreeNode& kid : kids) { + for (TreeNode node : work) { + TreeNode copy = std::make_shared<ExpressionTree>(*node); + copy->addChild(kid); + result.emplace_back(copy); + } + } + } + if (andList.size() > 1) { + generateAllCombinations( + result, + std::vector<TreeNode>(andList.cbegin() + 1, andList.cend()), + nonAndList); + } + } + + static const size_t CNF_COMBINATIONS_THRESHOLD = 256; + static bool checkCombinationsThreshold(const std::vector<TreeNode>& andList) { + size_t numComb = 1; + for (const TreeNode& tree : andList) { + numComb *= tree->getChildren().size(); + if (numComb > CNF_COMBINATIONS_THRESHOLD) { + return false; + } + } + return true; + } + + /** + * Convert an expression so that the top level operator is AND with OR + * operators under it. This routine assumes that all of the NOT operators + * have been pushed to the leaves via pushdDownNot. + * @param root the expression + * @return the normalized expression + */ + TreeNode SearchArgumentBuilderImpl::convertToCNF(TreeNode root) { + if (root) { + // convert all of the children to CNF + size_t size = root->getChildren().size(); + for (size_t i = 0; i != size; ++i) { + root->getChildren()[i] = convertToCNF(root->getChild(i)); + } + if (root->getOperator() == ExpressionTree::Operator::OR) { + // a list of leaves that weren't under AND expressions + std::vector<TreeNode> nonAndList; + // a list of AND expressions that we need to distribute + std::vector<TreeNode> andList; + for (TreeNode& child : root->getChildren()) { + if (child->getOperator() == ExpressionTree::Operator::AND) { + andList.emplace_back(child); + } else if (child->getOperator() == ExpressionTree::Operator::OR) { + // pull apart the kids of the OR expression + for (TreeNode& grandkid : child->getChildren()) { + nonAndList.emplace_back(grandkid); + } + } else { + nonAndList.emplace_back(child); + } + } + if (!andList.empty()) { + if (checkCombinationsThreshold(andList)) { + root = std::make_shared<ExpressionTree>( + ExpressionTree::Operator::AND); + generateAllCombinations(root->getChildren(), andList, nonAndList); + } else { + root = std::make_shared<ExpressionTree>(TruthValue::YES_NO_NULL); + } + } + } + } + return root; + } + + SearchArgumentImpl::SearchArgumentImpl(TreeNode root, + const std::vector<PredicateLeaf>& leaves) + : mExpressionTree(root) + , mLeaves(leaves) { + // PASS + } + + std::unique_ptr<SearchArgument> SearchArgumentBuilderImpl::build() { + if (mCurrTree.size() != 1) { + throw std::invalid_argument("Failed to end " + + std::to_string(mCurrTree.size()) + " operations."); + } + mRoot = pushDownNot(mRoot); + mRoot = foldMaybe(mRoot); + mRoot = flatten(mRoot); + mRoot = convertToCNF(mRoot); + mRoot = flatten(mRoot); + std::vector<size_t> leafReorder(mLeaves.size(), UNUSED_LEAF); + size_t newLeafCount = compactLeaves(mRoot, 0, leafReorder.data()); + mRoot = rewriteLeaves(mRoot, leafReorder.data()); + + std::vector<PredicateLeaf> leafList(newLeafCount, PredicateLeaf()); + + // build the new list + for (auto & leaf : mLeaves) { + size_t newLoc = leafReorder[leaf.second]; + if (newLoc != UNUSED_LEAF) { + leafList[newLoc] = leaf.first; + } + } + return std::unique_ptr<SearchArgument>( + new SearchArgumentImpl(mRoot, leafList)); + } + + std::unique_ptr<SearchArgumentBuilder> SearchArgumentFactory::newBuilder() { + return std::unique_ptr<SearchArgumentBuilder>(new SearchArgumentBuilderImpl()); + } + +} // namespace orc diff --git a/contrib/libs/apache/orc/c++/src/sargs/SearchArgument.hh b/contrib/libs/apache/orc/c++/src/sargs/SearchArgument.hh new file mode 100644 index 00000000000..57d765e1df1 --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/SearchArgument.hh @@ -0,0 +1,341 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ORC_SRC_SEARCHARGUMENT_HH +#define ORC_SRC_SEARCHARGUMENT_HH + +#include "wrap/orc-proto-wrapper.hh" +#include "ExpressionTree.hh" +#include "orc/sargs/SearchArgument.hh" +#include "sargs/PredicateLeaf.hh" + +#include <deque> +#include <stdexcept> +#include <unordered_map> + +namespace orc { + + /** + * Primary interface for a search argument, which are the subset of predicates + * that can be pushed down to the RowReader. Each SearchArgument consists + * of a series of search clauses that must each be true for the row to be + * accepted by the filter. + * + * This requires that the filter be normalized into conjunctive normal form + * (<a href="http://en.wikipedia.org/wiki/Conjunctive_normal_form">CNF</a>). + */ + class SearchArgumentImpl : public SearchArgument { + public: + SearchArgumentImpl(TreeNode root, const std::vector<PredicateLeaf>& leaves); + + /** + * Get the leaf predicates that are required to evaluate the predicate. The + * list will have the duplicates removed. + * @return the list of leaf predicates + */ + const std::vector<PredicateLeaf>& getLeaves() const; + + /** + * Get the expression tree. This should only needed for file formats that + * need to translate the expression to an internal form. + */ + const ExpressionTree * getExpression() const; + + /** + * Evaluate the entire predicate based on the values for the leaf predicates. + * @param leaves the value of each leaf predicate + * @return the value of the entire predicate + */ + TruthValue evaluate(const std::vector<TruthValue>& leaves) const override; + + std::string toString() const override; + + private: + std::shared_ptr<ExpressionTree> mExpressionTree; + std::vector<PredicateLeaf> mLeaves; + }; + + /** + * A builder object to create a SearchArgument from expressions. The user + * must call startOr, startAnd, or startNot before adding any leaves. + */ + class SearchArgumentBuilderImpl : public SearchArgumentBuilder { + public: + SearchArgumentBuilderImpl(); + + /** + * Start building an or operation and push it on the stack. + * @return this + */ + SearchArgumentBuilder& startOr() override; + + /** + * Start building an and operation and push it on the stack. + * @return this + */ + SearchArgumentBuilder& startAnd() override; + + /** + * Start building a not operation and push it on the stack. + * @return this + */ + SearchArgumentBuilder& startNot() override; + + /** + * Finish the current operation and pop it off of the stack. Each start + * call must have a matching end. + * @return this + */ + SearchArgumentBuilder& end() override; + + /** + * Add a less than leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + SearchArgumentBuilder& lessThan(const std::string& column, + PredicateDataType type, + Literal literal) override; + + /** + * Add a less than leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + SearchArgumentBuilder& lessThan(uint64_t columnId, + PredicateDataType type, + Literal literal) override; + + /** + * Add a less than equals leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + SearchArgumentBuilder& lessThanEquals(const std::string& column, + PredicateDataType type, + Literal literal) override; + + /** + * Add a less than equals leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + SearchArgumentBuilder& lessThanEquals(uint64_t columnId, + PredicateDataType type, + Literal literal) override; + + /** + * Add an equals leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + SearchArgumentBuilder& equals(const std::string& column, + PredicateDataType type, + Literal literal) override; + + /** + * Add an equals leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + SearchArgumentBuilder& equals(uint64_t columnId, + PredicateDataType type, + Literal literal) override; + + /** + * Add a null safe equals leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + SearchArgumentBuilder& nullSafeEquals(const std::string& column, + PredicateDataType type, + Literal literal) override; + + /** + * Add a null safe equals leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literal the literal + * @return this + */ + SearchArgumentBuilder& nullSafeEquals(uint64_t columnId, + PredicateDataType type, + Literal literal) override; + + /** + * Add an in leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literals the literals + * @return this + */ + SearchArgumentBuilder& in(const std::string& column, + PredicateDataType type, + const std::initializer_list<Literal>& literals) override; + + /** + * Add an in leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literals the literals + * @return this + */ + SearchArgumentBuilder& in(uint64_t columnId, + PredicateDataType type, + const std::initializer_list<Literal>& literals) override; + + /** + * Add an in leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param literals the literals + * @return this + */ + SearchArgumentBuilder& in(const std::string& column, + PredicateDataType type, + const std::vector<Literal>& literals) override; + + /** + * Add an in leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param literals the literals + * @return this + */ + SearchArgumentBuilder& in(uint64_t columnId, + PredicateDataType type, + const std::vector<Literal>& literals) override; + + /** + * Add an is null leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @return this + */ + SearchArgumentBuilder& isNull(const std::string& column, + PredicateDataType type) override; + + /** + * Add an is null leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @return this + */ + SearchArgumentBuilder& isNull(uint64_t columnId, + PredicateDataType type) override; + + /** + * Add a between leaf to the current item on the stack. + * @param column the field name of the column + * @param type the type of the expression + * @param lower the literal + * @param upper the literal + * @return this + */ + SearchArgumentBuilder& between(const std::string& column, + PredicateDataType type, + Literal lower, + Literal upper) override; + + /** + * Add a between leaf to the current item on the stack. + * @param columnId the column id of the column + * @param type the type of the expression + * @param lower the literal + * @param upper the literal + * @return this + */ + SearchArgumentBuilder& between(uint64_t columnId, + PredicateDataType type, + Literal lower, + Literal upper) override; + + /** + * Add a truth value to the expression. + * @param truth truth value + * @return this + */ + SearchArgumentBuilder& literal(TruthValue truth) override; + + /** + * Build and return the SearchArgument that has been defined. All of the + * starts must have been ended before this call. + * @return the new SearchArgument + */ + std::unique_ptr<SearchArgument> build() override; + + private: + SearchArgumentBuilder& start(ExpressionTree::Operator op); + size_t addLeaf(PredicateLeaf leaf); + + static bool isInvalidColumn(const std::string& column); + static bool isInvalidColumn(uint64_t columnId); + + template<typename T> + SearchArgumentBuilder& compareOperator(PredicateLeaf::Operator op, + T column, + PredicateDataType type, + Literal literal); + + template<typename T, typename CONTAINER> + SearchArgumentBuilder& addChildForIn(T column, + PredicateDataType type, + const CONTAINER& literals); + + template<typename T> + SearchArgumentBuilder& addChildForIsNull(T column, + PredicateDataType type); + + template<typename T> + SearchArgumentBuilder& addChildForBetween(T column, + PredicateDataType type, + Literal lower, + Literal upper); + + public: + static TreeNode pushDownNot(TreeNode root); + static TreeNode foldMaybe(TreeNode expr); + static TreeNode flatten(TreeNode root); + static TreeNode convertToCNF(TreeNode root); + + private: + std::deque<TreeNode> mCurrTree; + std::unordered_map<PredicateLeaf, + size_t, + PredicateLeafHash, + PredicateLeafComparator> mLeaves; + std::shared_ptr<ExpressionTree> mRoot; + }; + +} // namespace orc + +#endif //ORC_SRC_SEARCHARGUMENT_HH diff --git a/contrib/libs/apache/orc/c++/src/sargs/TruthValue.cc b/contrib/libs/apache/orc/c++/src/sargs/TruthValue.cc new file mode 100644 index 00000000000..fe00ed94724 --- /dev/null +++ b/contrib/libs/apache/orc/c++/src/sargs/TruthValue.cc @@ -0,0 +1,125 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "orc/sargs/TruthValue.hh" + +#include <stdexcept> + +namespace orc { + + TruthValue operator||(TruthValue left, TruthValue right) { + if (right == TruthValue::YES || left == TruthValue::YES) { + return TruthValue::YES; + } + if (right == TruthValue::YES_NULL || left == TruthValue::YES_NULL) { + return TruthValue::YES_NULL; + } + if (right == TruthValue::NO) { + return left; + } + if (left == TruthValue::NO) { + return right; + } + if (left == TruthValue::IS_NULL) { + if (right == TruthValue::NO_NULL || right == TruthValue::IS_NULL) { + return TruthValue::IS_NULL; + } else { + return TruthValue::YES_NULL; + } + } + if (right == TruthValue::IS_NULL) { + if (left == TruthValue::NO_NULL) { + return TruthValue::IS_NULL; + } else { + return TruthValue::YES_NULL; + } + } + if (left == TruthValue::NO_NULL && right == TruthValue::NO_NULL) { + return TruthValue::NO_NULL; + } + return TruthValue::YES_NO_NULL; + } + + TruthValue operator&&(TruthValue left, TruthValue right) { + if (right == TruthValue::NO || left == TruthValue::NO) { + return TruthValue::NO; + } + if (right == TruthValue::NO_NULL || left == TruthValue::NO_NULL) { + return TruthValue::NO_NULL; + } + if (right == TruthValue::YES) { + return left; + } + if (left == TruthValue::YES) { + return right; + } + if (left == TruthValue::IS_NULL) { + if (right == TruthValue::YES_NULL || right == TruthValue::IS_NULL) { + return TruthValue::IS_NULL; + } else { + return TruthValue::NO_NULL; + } + } + if (right == TruthValue::IS_NULL) { + if (left == TruthValue::YES_NULL) { + return TruthValue::IS_NULL; + } else { + return TruthValue::NO_NULL; + } + } + if (left == TruthValue::YES_NULL && right == TruthValue::YES_NULL) { + return TruthValue::YES_NULL; + } + return TruthValue::YES_NO_NULL; + } + + TruthValue operator!(TruthValue val) { + switch (val) { + case TruthValue::NO: + return TruthValue::YES; + case TruthValue::YES: + return TruthValue::NO; + case TruthValue::IS_NULL: + case TruthValue::YES_NO: + case TruthValue::YES_NO_NULL: + return val; + case TruthValue::NO_NULL: + return TruthValue::YES_NULL; + case TruthValue::YES_NULL: + return TruthValue::NO_NULL; + default: + throw std::invalid_argument("Unknown TruthValue"); + } + } + + bool isNeeded(TruthValue val) { + switch (val) { + case TruthValue::NO: + case TruthValue::IS_NULL: + case TruthValue::NO_NULL: + return false; + case TruthValue::YES: + case TruthValue::YES_NO: + case TruthValue::YES_NULL: + case TruthValue::YES_NO_NULL: + default: + return true; + } + } + +} diff --git a/contrib/libs/apache/orc/proto/orc_proto.proto b/contrib/libs/apache/orc/proto/orc_proto.proto index e8b84dbecde..ff05657a547 100644 --- a/contrib/libs/apache/orc/proto/orc_proto.proto +++ b/contrib/libs/apache/orc/proto/orc_proto.proto @@ -366,6 +366,7 @@ message Footer { // 1 = ORC C++ // 2 = Presto // 3 = Scritchley Go from https://github.com/scritchley/orc + // 4 = Trino optional uint32 writer = 9; // information about the encryption in this file |