diff options
| author | robot-piglet <[email protected]> | 2025-08-14 11:26:15 +0300 |
|---|---|---|
| committer | robot-piglet <[email protected]> | 2025-08-14 12:06:36 +0300 |
| commit | dc2bf727ea4698fa382f0f8623a8854c4900e212 (patch) | |
| tree | a621e92060fd7560066f33a323b4b8aca34f1e36 /contrib/libs/apache/arrow_next/cpp/src/arrow/compute/expression.cc | |
| parent | 322ee7d149464c6f18d6a330d937227cb022b9f3 (diff) | |
Intermediate changes
commit_hash:746e9b78ab4c78ba4f30511f1fa9330c0d56a406
Diffstat (limited to 'contrib/libs/apache/arrow_next/cpp/src/arrow/compute/expression.cc')
| -rw-r--r-- | contrib/libs/apache/arrow_next/cpp/src/arrow/compute/expression.cc | 1755 |
1 files changed, 1755 insertions, 0 deletions
diff --git a/contrib/libs/apache/arrow_next/cpp/src/arrow/compute/expression.cc b/contrib/libs/apache/arrow_next/cpp/src/arrow/compute/expression.cc new file mode 100644 index 00000000000..29512c0dd4a --- /dev/null +++ b/contrib/libs/apache/arrow_next/cpp/src/arrow/compute/expression.cc @@ -0,0 +1,1755 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/util/config.h" + +#include "arrow/compute/expression.h" + +#include <algorithm> +#include <optional> +#include <unordered_map> +#include <unordered_set> + +#include "arrow/chunked_array.h" +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec_internal.h" +#include "arrow/compute/expression_internal.h" +#include "arrow/compute/function_internal.h" +#include "arrow/compute/util.h" +#include "arrow/io/memory.h" +#ifdef ARROW_IPC +# include "arrow/ipc/reader.h" +# include "arrow/ipc/writer.h" +#endif +#include "arrow/util/hash_util.h" +#include "arrow/util/key_value_metadata.h" +#include "arrow/util/logging.h" +#include "arrow/util/string.h" +#include "arrow/util/value_parsing.h" +#include "arrow/util/vector.h" + +namespace arrow20 { + +using internal::checked_cast; +using internal::checked_pointer_cast; +using internal::EndsWith; +using internal::ToChars; + +namespace compute { + +void Expression::Call::ComputeHash() { + hash = std::hash<std::string>{}(function_name); + for (const auto& arg : arguments) { + arrow20::internal::hash_combine(hash, arg.hash()); + } +} + +Expression::Expression(Call call) { + call.ComputeHash(); + impl_ = std::make_shared<Impl>(std::move(call)); +} + +Expression::Expression(Datum literal) + : impl_(std::make_shared<Impl>(std::move(literal))) {} + +Expression::Expression(Parameter parameter) + : impl_(std::make_shared<Impl>(std::move(parameter))) {} + +Expression literal(Datum lit) { return Expression(std::move(lit)); } + +Expression field_ref(FieldRef ref) { + return Expression(Expression::Parameter{std::move(ref), TypeHolder{}, {-1}}); +} + +Expression call(std::string function, std::vector<Expression> arguments, + std::shared_ptr<compute::FunctionOptions> options) { + Expression::Call call; + call.function_name = std::move(function); + call.arguments = std::move(arguments); + call.options = std::move(options); + return Expression(std::move(call)); +} + +const 
Datum* Expression::literal() const { + if (impl_ == nullptr) return nullptr; + + return std::get_if<Datum>(impl_.get()); +} + +const Expression::Parameter* Expression::parameter() const { + if (impl_ == nullptr) return nullptr; + + return std::get_if<Parameter>(impl_.get()); +} + +const FieldRef* Expression::field_ref() const { + if (auto parameter = this->parameter()) { + return ¶meter->ref; + } + return nullptr; +} + +const Expression::Call* Expression::call() const { + if (impl_ == nullptr) return nullptr; + + return std::get_if<Call>(impl_.get()); +} + +const DataType* Expression::type() const { + if (impl_ == nullptr) return nullptr; + + if (const Datum* lit = literal()) { + return lit->type().get(); + } + + if (const Parameter* parameter = this->parameter()) { + return parameter->type.type; + } + + return CallNotNull(*this)->type.type; +} + +namespace { + +std::string PrintDatum(const Datum& datum) { + if (datum.is_scalar()) { + if (!datum.scalar()->is_valid) return "null[" + datum.type()->ToString() + "]"; + + switch (datum.type()->id()) { + case Type::STRING: + case Type::LARGE_STRING: + return '"' + + Escape(std::string_view(*datum.scalar_as<BaseBinaryScalar>().value)) + '"'; + + case Type::BINARY: + case Type::FIXED_SIZE_BINARY: + case Type::LARGE_BINARY: + return '"' + datum.scalar_as<BaseBinaryScalar>().value->ToHexString() + '"'; + + default: + break; + } + + return datum.scalar()->ToString(); + } else if (datum.is_array()) { + return "Array[" + datum.type()->ToString() + "]"; + } + return datum.ToString(); +} + +} // namespace + +std::string Expression::ToString() const { + if (auto lit = literal()) { + return PrintDatum(*lit); + } + + if (auto ref = field_ref()) { + if (auto name = ref->name()) { + return *name; + } + if (auto path = ref->field_path()) { + return path->ToString(); + } + return ref->ToString(); + } + + auto call = CallNotNull(*this); + auto binary = [&](std::string op) { + return "(" + call->arguments[0].ToString() + " " + op + " " + 
+ call->arguments[1].ToString() + ")"; + }; + + if (auto cmp = Comparison::Get(call->function_name)) { + return binary(Comparison::GetOp(*cmp)); + } + + constexpr std::string_view kleene = "_kleene"; + if (EndsWith(call->function_name, kleene)) { + auto op = call->function_name.substr(0, call->function_name.size() - kleene.size()); + return binary(std::move(op)); + } + + if (auto options = GetMakeStructOptions(*call)) { + std::string out = "{"; + auto argument = call->arguments.begin(); + for (const auto& field_name : options->field_names) { + out += field_name + "=" + argument++->ToString() + ", "; + } + out.resize(out.size() - 1); + out.back() = '}'; + return out; + } + + std::string out = call->function_name + "("; + for (const auto& arg : call->arguments) { + out += arg.ToString() + ", "; + } + + if (call->options) { + out += call->options->ToString(); + } else if (call->arguments.size()) { + out.resize(out.size() - 2); + } + + out += ')'; + return out; +} + +void PrintTo(const Expression& expr, std::ostream* os) { + *os << expr.ToString(); + if (expr.IsBound()) { + *os << "[bound]"; + } +} + +bool Expression::Equals(const Expression& other) const { + if (Identical(*this, other)) return true; + + if (impl_ == nullptr || other.impl_ == nullptr) return false; + + if (impl_->index() != other.impl_->index()) { + return false; + } + + if (auto lit = literal()) { + // The scalar NaN is not equal to the scalar NaN but the literal NaN + // is equal to the literal NaN (e.g. 
the expressions are equal even if + // the values are not) + EqualOptions equal_options = EqualOptions::Defaults().nans_equal(true); + return lit->scalar()->Equals(*other.literal()->scalar(), equal_options); + } + + if (auto ref = field_ref()) { + return ref->Equals(*other.field_ref()); + } + + auto call = CallNotNull(*this); + auto other_call = CallNotNull(other); + + if (call->function_name != other_call->function_name || + call->kernel != other_call->kernel) { + return false; + } + + for (size_t i = 0; i < call->arguments.size(); ++i) { + if (!call->arguments[i].Equals(other_call->arguments[i])) { + return false; + } + } + + if (call->options == other_call->options) return true; + if (call->options && other_call->options) { + return call->options->Equals(*other_call->options); + } + return false; +} + +bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; } + +size_t Expression::hash() const { + if (auto lit = literal()) { + if (lit->is_scalar()) { + return lit->scalar()->hash(); + } + return 0; + } + + if (auto ref = field_ref()) { + return ref->hash(); + } + + return CallNotNull(*this)->hash; +} + +bool Expression::IsBound() const { + if (type() == nullptr) return false; + + if (const Call* call = this->call()) { + if (call->kernel == nullptr) return false; + + for (const Expression& arg : call->arguments) { + if (!arg.IsBound()) return false; + } + } + + return true; +} + +bool Expression::IsScalarExpression() const { + if (auto lit = literal()) { + return lit->is_scalar(); + } + + if (field_ref()) return true; + + auto call = CallNotNull(*this); + + for (const Expression& arg : call->arguments) { + if (!arg.IsScalarExpression()) return false; + } + + if (call->function) { + return call->function->kind() == compute::Function::SCALAR; + } + + // this expression is not bound; make a best guess based on + // the default function registry + if (auto function = compute::GetFunctionRegistry() + ->GetFunction(call->function_name) + 
.ValueOr(nullptr)) { + return function->kind() == compute::Function::SCALAR; + } + + // unknown function or other error; conservatively return false + return false; +} + +bool Expression::IsNullLiteral() const { + if (auto lit = literal()) { + if (lit->null_count() == lit->length()) { + return true; + } + } + + return false; +} + +namespace { +std::optional<compute::NullHandling::type> GetNullHandling(const Expression::Call& call) { + DCHECK_NE(call.function, nullptr); + if (call.function->kind() == compute::Function::SCALAR) { + return static_cast<const compute::ScalarKernel*>(call.kernel)->null_handling; + } + return std::nullopt; +} +} // namespace + +bool Expression::IsSatisfiable() const { + if (type() == nullptr) return true; + if (type()->id() != Type::BOOL) return true; + + if (auto lit = literal()) { + if (lit->null_count() == lit->length()) { + return false; + } + + if (lit->is_scalar()) { + return lit->scalar_as<BooleanScalar>().value; + } + + return true; + } + + if (field_ref()) return true; + + auto call = CallNotNull(*this); + + // invert(true_unless_null(x)) is always false or null by definition + // true_unless_null arises in simplification of inequalities below + if (call->function_name == "invert") { + if (auto nested_call = call->arguments[0].call()) { + if (nested_call->function_name == "true_unless_null") return false; + } + } + + if (call->function_name == "and_kleene" || call->function_name == "and") { + return std::all_of(call->arguments.begin(), call->arguments.end(), + [](const Expression& arg) { return arg.IsSatisfiable(); }); + } + if (call->function_name == "or_kleene" || call->function_name == "or") { + return std::any_of(call->arguments.begin(), call->arguments.end(), + [](const Expression& arg) { return arg.IsSatisfiable(); }); + } + + return true; +} + +namespace { + +TypeHolder SmallestTypeFor(const arrow20::Datum& value) { + switch (value.type()->id()) { + case Type::INT8: + return int8(); + case Type::UINT8: + return uint8(); + 
case Type::INT16: { + int16_t i16 = value.scalar_as<Int16Scalar>().value; + if (i16 <= std::numeric_limits<int8_t>::max() && + i16 >= std::numeric_limits<int8_t>::min()) { + return int8(); + } + return int16(); + } + case Type::UINT16: { + uint16_t ui16 = value.scalar_as<UInt16Scalar>().value; + if (ui16 <= std::numeric_limits<uint8_t>::max()) { + return uint8(); + } + return uint16(); + } + case Type::INT32: { + int32_t i32 = value.scalar_as<Int32Scalar>().value; + if (i32 <= std::numeric_limits<int8_t>::max() && + i32 >= std::numeric_limits<int8_t>::min()) { + return int8(); + } + if (i32 <= std::numeric_limits<int16_t>::max() && + i32 >= std::numeric_limits<int16_t>::min()) { + return int16(); + } + return int32(); + } + case Type::UINT32: { + uint32_t ui32 = value.scalar_as<UInt32Scalar>().value; + if (ui32 <= std::numeric_limits<uint8_t>::max()) { + return uint8(); + } + if (ui32 <= std::numeric_limits<uint16_t>::max()) { + return uint16(); + } + return uint32(); + } + case Type::INT64: { + int64_t i64 = value.scalar_as<Int64Scalar>().value; + if (i64 <= std::numeric_limits<int8_t>::max() && + i64 >= std::numeric_limits<int8_t>::min()) { + return int8(); + } + if (i64 <= std::numeric_limits<int16_t>::max() && + i64 >= std::numeric_limits<int16_t>::min()) { + return int16(); + } + if (i64 <= std::numeric_limits<int32_t>::max() && + i64 >= std::numeric_limits<int32_t>::min()) { + return int32(); + } + return int64(); + } + case Type::UINT64: { + uint64_t ui64 = value.scalar_as<UInt64Scalar>().value; + if (ui64 <= std::numeric_limits<uint8_t>::max()) { + return uint8(); + } + if (ui64 <= std::numeric_limits<uint16_t>::max()) { + return uint16(); + } + if (ui64 <= std::numeric_limits<uint32_t>::max()) { + return uint32(); + } + return uint64(); + } + case Type::DOUBLE: { + double doub = value.scalar_as<DoubleScalar>().value; + if (!std::isfinite(doub)) { + // Special values can be float + return float32(); + } + // Test if float representation is the same + if 
(static_cast<double>(static_cast<float>(doub)) == doub) { + return float32(); + } + return float64(); + } + case Type::LARGE_STRING: { + if (value.scalar_as<LargeStringScalar>().value->size() <= + std::numeric_limits<int32_t>::max()) { + return utf8(); + } + return large_utf8(); + } + case Type::LARGE_BINARY: + if (value.scalar_as<LargeBinaryScalar>().value->size() <= + std::numeric_limits<int32_t>::max()) { + return binary(); + } + return large_binary(); + case Type::TIMESTAMP: { + const auto& ts_type = checked_pointer_cast<TimestampType>(value.type()); + uint64_t ts = value.scalar_as<TimestampScalar>().value; + switch (ts_type->unit()) { + case TimeUnit::SECOND: + return value.type(); + case TimeUnit::MILLI: + if (ts % 1000 == 0) { + return timestamp(TimeUnit::SECOND, ts_type->timezone()); + } + return value.type(); + case TimeUnit::MICRO: + if (ts % 1000000 == 0) { + return timestamp(TimeUnit::SECOND, ts_type->timezone()); + } + if (ts % 1000 == 0) { + return timestamp(TimeUnit::MILLI, ts_type->timezone()); + } + return value.type(); + case TimeUnit::NANO: + if (ts % 1000000000 == 0) { + return timestamp(TimeUnit::SECOND, ts_type->timezone()); + } + if (ts % 1000000 == 0) { + return timestamp(TimeUnit::MILLI, ts_type->timezone()); + } + if (ts % 1000 == 0) { + return timestamp(TimeUnit::MICRO, ts_type->timezone()); + } + return value.type(); + default: + return value.type(); + } + } + default: + return value.type(); + } +} + +inline std::vector<TypeHolder> GetTypesWithSmallestLiteralRepresentation( + const std::vector<Expression>& exprs) { + std::vector<TypeHolder> types(exprs.size()); + for (size_t i = 0; i < exprs.size(); ++i) { + DCHECK(exprs[i].IsBound()); + if (const Datum* literal = exprs[i].literal()) { + if (literal->is_scalar()) { + types[i] = SmallestTypeFor(*literal); + } + } else { + types[i] = exprs[i].type(); + } + } + return types; +} + +// Produce a bound Expression from unbound Call and bound arguments. 
+Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_casts, + compute::ExecContext* exec_context) { + DCHECK(std::all_of(call.arguments.begin(), call.arguments.end(), + [](const Expression& argument) { return argument.IsBound(); })); + + std::vector<TypeHolder> types = GetTypes(call.arguments); + ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context)); + + auto FinishBind = [&] { + compute::KernelContext kernel_context(exec_context, call.kernel); + if (call.kernel->init) { + const FunctionOptions* options = + call.options ? call.options.get() : call.function->default_options(); + ARROW_ASSIGN_OR_RAISE( + call.kernel_state, + call.kernel->init(&kernel_context, {call.kernel, types, options})); + + kernel_context.SetState(call.kernel_state.get()); + } + + ARROW_ASSIGN_OR_RAISE( + call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types)); + return Status::OK(); + }; + + // First try and bind exactly + Result<const Kernel*> maybe_exact_match = call.function->DispatchExact(types); + if (maybe_exact_match.ok()) { + call.kernel = *maybe_exact_match; + if (FinishBind().ok()) { + return Expression(std::move(call)); + } + } + + if (!insert_implicit_casts) { + return maybe_exact_match.status(); + } + + // If exact binding fails, and we are allowed to cast, then prefer casting literals + // first. 
Since DispatchBest generally prefers up-casting the best way to do this is + // first down-cast the literals as much as possible + types = GetTypesWithSmallestLiteralRepresentation(call.arguments); + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types)); + + for (size_t i = 0; i < types.size(); ++i) { + if (types[i] == call.arguments[i].type()) continue; + + if (const Datum* lit = call.arguments[i].literal()) { + ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, types[i].GetSharedPtr())); + call.arguments[i] = literal(std::move(new_lit)); + continue; + } + + // construct an implicit cast Expression with which to replace this argument + Expression::Call implicit_cast; + implicit_cast.function_name = "cast"; + implicit_cast.arguments = {std::move(call.arguments[i])}; + + // TODO(wesm): Use TypeHolder in options + implicit_cast.options = std::make_shared<compute::CastOptions>( + compute::CastOptions::Safe(types[i].GetSharedPtr())); + + ARROW_ASSIGN_OR_RAISE( + call.arguments[i], + BindNonRecursive(std::move(implicit_cast), + /*insert_implicit_casts=*/false, exec_context)); + } + + RETURN_NOT_OK(FinishBind()); + return Expression(std::move(call)); +} + +template <typename TypeOrSchema> +Result<Expression> BindImpl(Expression expr, const TypeOrSchema& in, + compute::ExecContext* exec_context) { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return BindImpl(std::move(expr), in, &exec_context); + } + + if (expr.literal()) return expr; + + if (const FieldRef* ref = expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(FieldPath path, ref->FindOne(in)); + + Expression::Parameter param = *expr.parameter(); + param.indices.resize(path.indices().size()); + std::copy(path.indices().begin(), path.indices().end(), param.indices.begin()); + ARROW_ASSIGN_OR_RAISE(auto field, path.Get(in)); + param.type = field->type(); + return Expression{std::move(param)}; + } + + auto call = *CallNotNull(expr); + for (auto& argument : call.arguments) { + 
ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context)); + } + return BindNonRecursive(std::move(call), + /*insert_implicit_casts=*/true, exec_context); +} + +} // namespace + +Result<Expression> Expression::Bind(const TypeHolder& in, + compute::ExecContext* exec_context) const { + return BindImpl(*this, *in.type, exec_context); +} + +Result<Expression> Expression::Bind(const Schema& in_schema, + compute::ExecContext* exec_context) const { + return BindImpl(*this, in_schema, exec_context); +} + +Result<ExecBatch> MakeExecBatch(const Schema& full_schema, const Datum& partial, + Expression guarantee) { + ExecBatch out; + + if (partial.kind() == Datum::RECORD_BATCH) { + const auto& partial_batch = *partial.record_batch(); + out.guarantee = std::move(guarantee); + out.length = partial_batch.num_rows(); + + ARROW_ASSIGN_OR_RAISE(auto known_field_values, + ExtractKnownFieldValues(out.guarantee)); + + for (const auto& field : full_schema.fields()) { + auto field_ref = FieldRef(field->name()); + + // If we know what the value must be from the guarantee, prefer to use that value + // than the data from the record batch (if it exists at all -- probably it doesn't), + // because this way it will be a scalar. + auto known_field_value = known_field_values.map.find(field_ref); + if (known_field_value != known_field_values.map.end()) { + out.values.emplace_back(known_field_value->second); + continue; + } + + ARROW_ASSIGN_OR_RAISE(auto column, field_ref.GetOneOrNone(partial_batch)); + if (column) { + if (!column->type()->Equals(field->type())) { + // Referenced field was present but didn't have the expected type. + // This *should* be handled by readers, and will just be an error in the future. 
+ ARROW_ASSIGN_OR_RAISE( + auto converted, + compute::Cast(column, field->type(), compute::CastOptions::Safe())); + column = converted.make_array(); + } + out.values.emplace_back(std::move(column)); + } else { + out.values.emplace_back(MakeNullScalar(field->type())); + } + } + return out; + } + + // wasteful but useful for testing: + if (partial.type()->id() == Type::STRUCT) { + if (partial.is_array()) { + ARROW_ASSIGN_OR_RAISE(auto partial_batch, + RecordBatch::FromStructArray(partial.make_array())); + + return MakeExecBatch(full_schema, partial_batch, std::move(guarantee)); + } + + if (partial.is_scalar()) { + ARROW_ASSIGN_OR_RAISE(auto partial_array, + MakeArrayFromScalar(*partial.scalar(), 1)); + ARROW_ASSIGN_OR_RAISE( + auto out, MakeExecBatch(full_schema, partial_array, std::move(guarantee))); + + for (Datum& value : out.values) { + if (value.is_scalar()) continue; + ARROW_ASSIGN_OR_RAISE(value, value.make_array()->GetScalar(0)); + } + return out; + } + } + + return Status::NotImplemented("MakeExecBatch from ", PrintDatum(partial)); +} + +Result<Datum> ExecuteScalarExpression(const Expression& expr, const Schema& full_schema, + const Datum& partial_input, + compute::ExecContext* exec_context) { + ARROW_ASSIGN_OR_RAISE(auto input, MakeExecBatch(full_schema, partial_input)); + return ExecuteScalarExpression(expr, input, exec_context); +} + +Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& input, + compute::ExecContext* exec_context) { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return ExecuteScalarExpression(expr, input, &exec_context); + } + + if (!expr.IsBound()) { + return Status::Invalid("Cannot Execute unbound expression."); + } + + if (!expr.IsScalarExpression()) { + return Status::Invalid( + "ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString()); + } + + if (auto lit = expr.literal()) return *lit; + + if (auto param = expr.parameter()) { + if (param->type.id() == 
Type::NA) { + return MakeNullScalar(null()); + } + + Datum field = input[param->indices[0]]; + if (param->indices.size() > 1) { + std::vector<int> indices(param->indices.begin() + 1, param->indices.end()); + compute::StructFieldOptions options(std::move(indices)); + ARROW_ASSIGN_OR_RAISE( + field, compute::CallFunction("struct_field", {std::move(field)}, &options)); + } + if (!field.type()->Equals(*param->type.type)) { + return Status::Invalid("Referenced field ", expr.ToString(), " was ", + field.type()->ToString(), " but should have been ", + param->type.ToString()); + } + + return field; + } + + auto call = CallNotNull(expr); + + std::vector<Datum> arguments(call->arguments.size()); + + bool all_scalar = true; + for (size_t i = 0; i < arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE( + arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context)); + all_scalar &= arguments[i].is_scalar(); + } + + int64_t input_length; + if (!arguments.empty() && all_scalar) { + // all inputs are scalar, so use a 1-long batch to avoid + // computing input.length equivalent outputs + input_length = 1; + } else { + input_length = input.length; + } + + auto executor = compute::detail::KernelExecutor::MakeScalar(); + + compute::KernelContext kernel_context(exec_context, call->kernel); + kernel_context.SetState(call->kernel_state.get()); + + const Kernel* kernel = call->kernel; + std::vector<TypeHolder> types = GetTypes(arguments); + auto options = call->options.get(); + RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, types, options})); + + compute::detail::DatumAccumulator listener; + RETURN_NOT_OK( + executor->Execute(ExecBatch(std::move(arguments), input_length), &listener)); + const auto out = executor->WrapResults(arguments, listener.values()); +#ifndef NDEBUG + DCHECK_OK(executor->CheckResultType(out, call->function_name.c_str())); +#endif + return out; +} + +namespace { + +std::array<std::pair<const Expression&, const Expression&>, 2> 
+ArgumentsAndFlippedArguments(const Expression::Call& call) { + DCHECK_EQ(call.arguments.size(), 2); + return {std::pair<const Expression&, const Expression&>{call.arguments[0], + call.arguments[1]}, + std::pair<const Expression&, const Expression&>{call.arguments[1], + call.arguments[0]}}; +} + +} // namespace + +std::vector<FieldRef> FieldsInExpression(const Expression& expr) { + if (expr.literal()) return {}; + + if (auto ref = expr.field_ref()) { + return {*ref}; + } + + std::vector<FieldRef> fields; + for (const Expression& arg : CallNotNull(expr)->arguments) { + auto argument_fields = FieldsInExpression(arg); + std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields)); + } + return fields; +} + +bool ExpressionHasFieldRefs(const Expression& expr) { + if (expr.literal()) return false; + + if (expr.field_ref()) return true; + + for (const Expression& arg : CallNotNull(expr)->arguments) { + if (ExpressionHasFieldRefs(arg)) return true; + } + return false; +} + +Result<Expression> FoldConstants(Expression expr) { + if (!expr.IsBound()) { + return Status::Invalid("Cannot fold constants in unbound expression."); + } + + return ModifyExpression( + std::move(expr), [](Expression expr) { return expr; }, + [](Expression expr, ...) 
-> Result<Expression> { + auto call = CallNotNull(expr); + if (!call->function->is_pure()) return expr; + + if (std::all_of(call->arguments.begin(), call->arguments.end(), + [](const Expression& argument) { return argument.literal(); })) { + // all arguments are literal; we can evaluate this subexpression *now* + static const ExecBatch ignored_input = ExecBatch({}, 1); + ARROW_ASSIGN_OR_RAISE(Datum constant, + ExecuteScalarExpression(expr, ignored_input)); + + return literal(std::move(constant)); + } + + // XXX the following should probably be in a registry of passes instead + // of inline + + if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { + // kernels which always produce intersected validity can be resolved + // to null *now* if any of their inputs is a null literal + for (const Expression& argument : call->arguments) { + if (argument.IsNullLiteral()) { + if (argument.type()->Equals(*call->type.type)) { + return argument; + } else { + return literal(MakeNullScalar(call->type.GetSharedPtr())); + } + } + } + } + + if (call->function_name == "and_kleene") { + for (auto args : ArgumentsAndFlippedArguments(*call)) { + // true and x == x + if (args.first == literal(true)) return args.second; + + // false and x == false + if (args.first == literal(false)) return args.first; + + // x and x == x + if (args.first == args.second) return args.first; + } + return expr; + } + + if (call->function_name == "or_kleene") { + for (auto args : ArgumentsAndFlippedArguments(*call)) { + // false or x == x + if (args.first == literal(false)) return args.second; + + // true or x == true + if (args.first == literal(true)) return args.first; + + // x or x == x + if (args.first == args.second) return args.first; + } + return expr; + } + + return expr; + }); +} + +namespace { + +std::vector<Expression> GuaranteeConjunctionMembers( + const Expression& guaranteed_true_predicate) { + auto guarantee = guaranteed_true_predicate.call(); + if (!guarantee || 
guarantee->function_name != "and_kleene") { + return {guaranteed_true_predicate}; + } + return FlattenedAssociativeChain(guaranteed_true_predicate).fringe; +} + +/// \brief Extract an equality from an expression. +/// +/// Recognizes expressions of the form: +/// equal(a, 2) +/// is_null(a) +std::optional<std::pair<FieldRef, Datum>> ExtractOneFieldValue( + const Expression& guarantee) { + auto call = guarantee.call(); + if (!call) return std::nullopt; + + // search for an equality conditions between a field and a literal + if (call->function_name == "equal") { + auto ref = call->arguments[0].field_ref(); + if (!ref) return std::nullopt; + + auto lit = call->arguments[1].literal(); + if (!lit) return std::nullopt; + + return std::make_pair(*ref, *lit); + } + + // ... or a known null field + if (call->function_name == "is_null") { + auto ref = call->arguments[0].field_ref(); + if (!ref) return std::nullopt; + + return std::make_pair(*ref, Datum(std::make_shared<NullScalar>())); + } + + return std::nullopt; +} + +// Conjunction members which are represented in known_values are erased from +// conjunction_members +Status ExtractKnownFieldValues(std::vector<Expression>* conjunction_members, + KnownFieldValues* known_values) { + // filter out consumed conjunction members, leaving only unconsumed + *conjunction_members = arrow20::internal::FilterVector( + std::move(*conjunction_members), + [known_values](const Expression& guarantee) -> bool { + if (auto known_value = ExtractOneFieldValue(guarantee)) { + known_values->map.insert(std::move(*known_value)); + return false; + } + return true; + }); + + return Status::OK(); +} + +} // namespace + +Result<KnownFieldValues> ExtractKnownFieldValues( + const Expression& guaranteed_true_predicate) { + KnownFieldValues known_values; + auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); + RETURN_NOT_OK(ExtractKnownFieldValues(&conjunction_members, &known_values)); + return known_values; +} + 
/// Replace any field references in `expr` whose values are pinned by
/// `known_values` (for example, values extracted from a partition guarantee)
/// with equivalent literals, casting the known value to the referenced
/// field's bound type when necessary. Requires a bound expression since the
/// replacement consults each reference's bound type.
Result<Expression> ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values,
                                                Expression expr) {
  if (!expr.IsBound()) {
    return Status::Invalid(
        "ReplaceFieldsWithKnownValues called on an unbound Expression");
  }

  return ModifyExpression(
      std::move(expr),
      [&known_values](Expression expr) -> Result<Expression> {
        if (auto ref = expr.field_ref()) {
          auto it = known_values.map.find(*ref);
          if (it != known_values.map.end()) {
            Datum lit = it->second;
            if (lit.type()->Equals(*expr.type())) return literal(std::move(lit));
            // type mismatch, try casting the known value to the correct type

            if (expr.type()->id() == Type::DICTIONARY &&
                lit.type()->id() != Type::DICTIONARY) {
              // the known value must be dictionary encoded

              const auto& dict_type = checked_cast<const DictionaryType&>(*expr.type());
              // First align the known value with the dictionary's value type...
              if (!lit.type()->Equals(dict_type.value_type())) {
                ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, dict_type.value_type()));
              }

              // ...then wrap a scalar into a one-element dictionary whose single
              // entry is the known value (index 0 points at it).
              if (lit.is_scalar()) {
                ARROW_ASSIGN_OR_RAISE(auto dictionary,
                                      MakeArrayFromScalar(*lit.scalar(), 1));

                lit = Datum{DictionaryScalar::Make(MakeScalar<int32_t>(0),
                                                   std::move(dictionary))};
              }
            }

            // Final cast to the exact bound type of the field reference.
            ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type()->GetSharedPtr()));
            return literal(std::move(lit));
          }
        }
        return expr;
      },
      [](Expression expr, ...) { return expr; });
}

namespace {

// True if `call` names a binary function that is both associative and
// commutative, i.e. whose argument chains may be freely flattened and
// reordered during canonicalization.
bool IsBinaryAssociativeCommutative(const Expression::Call& call) {
  static std::unordered_set<std::string> binary_associative_commutative{
      "and",      "or",  "and_kleene",       "or_kleene",  "xor",
      "multiply", "add", "multiply_checked", "add_checked"};

  auto it = binary_associative_commutative.find(call.function_name);
  return it != binary_associative_commutative.end();
}

// Rebind `call` if argument reordering has made its bound kernel's declared
// input types inconsistent with the (reordered) argument types.
Result<Expression> HandleInconsistentTypes(Expression::Call call,
                                           compute::ExecContext* exec_context) {
  // ARROW-18334: due to reordering of arguments, the call may have
  // inconsistent argument types. For example, the call's kernel may
  // correspond to `timestamp + duration` but the arguments happen to
  // be `duration, timestamp`. The addition itself is still commutative,
  // but the mismatch in declared argument types is potentially problematic
  // if we ever start using the Expression::Call::kernel field more than
  // we do currently. Check and rebind if necessary.
  //
  // The more correct fix for this problem is to ensure that all kernels of
  // functions which are commutative be commutative as well, which would
  // obviate rebinding like this. In the context of ARROW-18334, this
  // would require rewriting KernelSignature so that a single kernel can
  // handle both `timestamp + duration` and `duration + timestamp`.
  if (call.kernel->signature->MatchesInputs(GetTypes(call.arguments))) {
    return Expression(std::move(call));
  }
  return BindNonRecursive(std::move(call), /*insert_implicit_casts=*/false, exec_context);
}

}  // namespace

/// Rewrite a bound expression into a canonical form: associative/commutative
/// chains are flattened and left-folded with literals ordered first (nulls
/// before other literals), and comparisons are flipped so literals land on
/// the RHS. Canonicalization makes structurally-equivalent expressions
/// compare equal, which downstream simplification relies on.
Result<Expression> Canonicalize(Expression expr, compute::ExecContext* exec_context) {
  if (!expr.IsBound()) {
    return Status::Invalid("Cannot canonicalize an unbound expression.");
  }

  // No context supplied: recurse with a default-constructed one (the local
  // deliberately shadows the null parameter).
  if (exec_context == nullptr) {
    compute::ExecContext exec_context;
    return Canonicalize(std::move(expr), &exec_context);
  }

  // If potentially reconstructing more deeply than a call's immediate arguments
  // (for example, when reorganizing an associative chain), add expressions to this set to
  // avoid unnecessary work
  struct {
    std::unordered_set<Expression, Expression::Hash> set_;

    bool operator()(const Expression& expr) const {
      return set_.find(expr) != set_.end();
    }

    void Add(std::vector<Expression> exprs) {
      std::move(exprs.begin(), exprs.end(), std::inserter(set_, set_.end()));
    }
  } AlreadyCanonicalized;

  return ModifyExpression(
      std::move(expr),
      [&AlreadyCanonicalized, exec_context](Expression expr) -> Result<Expression> {
        auto call = expr.call();
        if (!call) return expr;
        // Impure functions must not be reordered/deduplicated.
        if (!call->function->is_pure()) return expr;

        if (AlreadyCanonicalized(expr)) return expr;

        if (IsBinaryAssociativeCommutative(*call)) {
          struct {
            int Priority(const Expression& operand) const {
              // order literals first, starting with nulls
              if (operand.IsNullLiteral()) return 0;
              if (operand.literal()) return 1;
              return 2;
            }
            bool operator()(const Expression& l, const Expression& r) const {
              return Priority(l) < Priority(r);
            }
          } CanonicalOrdering;

          FlattenedAssociativeChain chain(expr);

          if (chain.was_left_folded &&
              std::is_sorted(chain.fringe.begin(), chain.fringe.end(),
                             CanonicalOrdering)) {
            // fast path for expressions which happen to have arrived in an
            // already-canonical form
            AlreadyCanonicalized.Add(std::move(chain.exprs));
            return expr;
          }

          // stable_sort preserves the relative order of equal-priority
          // operands, keeping canonicalization deterministic.
          std::stable_sort(chain.fringe.begin(), chain.fringe.end(), CanonicalOrdering);

          // fold the chain back up
          Expression folded = std::move(chain.fringe.front());

          for (auto it = chain.fringe.begin() + 1; it != chain.fringe.end(); ++it) {
            auto canonicalized_call = *call;
            canonicalized_call.arguments = {std::move(folded), std::move(*it)};
            ARROW_ASSIGN_OR_RAISE(
                folded,
                HandleInconsistentTypes(std::move(canonicalized_call), exec_context));
            AlreadyCanonicalized.Add({expr});
          }
          return folded;
        }

        if (auto cmp = Comparison::Get(call->function_name)) {
          if (call->arguments[0].literal() && !call->arguments[1].literal()) {
            // ensure that literals are on comparisons' RHS
            auto flipped_call = *call;

            std::swap(flipped_call.arguments[0], flipped_call.arguments[1]);
            flipped_call.function_name =
                Comparison::GetName(Comparison::GetFlipped(*cmp));

            return BindNonRecursive(flipped_call,
                                    /*insert_implicit_casts=*/false, exec_context);
          }
        }

        return expr;
      },
      [](Expression expr, ...) { return expr; });
}

namespace {

// An inequality comparison which a target Expression is known to satisfy.
If nullable, +// the target may evaluate to null in addition to values satisfying the comparison. +struct Inequality { + // The inequality type + Comparison::type cmp; + // The LHS of the inequality + const FieldRef& target; + // The RHS of the inequality + const Datum& bound; + // Whether target can be null + bool nullable; + + // Extract an Inequality if possible, derived from "less", + // "greater", "less_equal", and "greater_equal" expressions, + // possibly disjuncted with an "is_null" Expression. + // cmp(a, 2) + // cmp(a, 2) or is_null(a) + static std::optional<Inequality> ExtractOne(const Expression& guarantee) { + auto call = guarantee.call(); + if (!call) return std::nullopt; + + if (call->function_name == "or_kleene") { + // expect the LHS to be a usable field inequality + auto out = ExtractOneFromComparison(call->arguments[0]); + if (!out) return std::nullopt; + + // expect the RHS to be an is_null expression + auto call_rhs = call->arguments[1].call(); + if (!call_rhs) return std::nullopt; + if (call_rhs->function_name != "is_null") return std::nullopt; + + // ... 
and that it references the same target + auto target = call_rhs->arguments[0].field_ref(); + if (!target) return std::nullopt; + if (*target != out->target) return std::nullopt; + + out->nullable = true; + return out; + } + + // fall back to a simple comparison with no "is_null" + return ExtractOneFromComparison(guarantee); + } + + static std::optional<Inequality> ExtractOneFromComparison(const Expression& guarantee) { + auto call = guarantee.call(); + if (!call) return std::nullopt; + + if (auto cmp = Comparison::Get(call->function_name)) { + // not_equal comparisons are not very usable as guarantees + if (*cmp == Comparison::NOT_EQUAL) return std::nullopt; + + auto target = call->arguments[0].field_ref(); + if (!target) return std::nullopt; + + auto bound = call->arguments[1].literal(); + if (!bound) return std::nullopt; + if (!bound->is_scalar()) return std::nullopt; + + return Inequality{*cmp, /*target=*/*target, *bound, /*nullable=*/false}; + } + + return std::nullopt; + } + + /// The given expression simplifies to `value` if the inequality + /// target is not nullable. Otherwise, it simplifies to either a + /// call to true_unless_null or !true_unless_null. + Result<Expression> simplified_to(const Expression& bound_target, bool value) const { + if (!nullable) return literal(value); + + ExecContext exec_context; + + // Data may be null, so comparison will yield `value` - or null IFF the data was null + // + // true_unless_null is cheap; it purely reuses the validity bitmap for the values + // buffer. Inversion is less cheap but we expect that term never to be evaluated + // since invert(true_unless_null(x)) is not satisfiable. 
+ Expression::Call call; + call.function_name = "true_unless_null"; + call.arguments = {bound_target}; + ARROW_ASSIGN_OR_RAISE( + auto true_unless_null, + BindNonRecursive(std::move(call), + /*insert_implicit_casts=*/false, &exec_context)); + if (value) return true_unless_null; + + Expression::Call invert; + invert.function_name = "invert"; + invert.arguments = {std::move(true_unless_null)}; + return BindNonRecursive(std::move(invert), + /*insert_implicit_casts=*/false, &exec_context); + } + + /// Simplify an `is_in` call against an inequality guarantee. + /// + /// We avoid the complexity of fully simplifying EQUAL comparisons to true + /// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to + /// potential complications with null matching behavior. This is ok for the + /// predicate pushdown use case because the overall aim is to simplify to an + /// unsatisfiable expression. + /// + /// \pre `is_in_call` is a call to the `is_in` function + /// \return a simplified expression, or nullopt if no simplification occurred + static Result<std::optional<Expression>> SimplifyIsIn( + const Inequality& guarantee, const Expression::Call* is_in_call) { + DCHECK_EQ(is_in_call->function_name, "is_in"); + + auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options); + + const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]); + if (!lhs.field_ref()) return std::nullopt; + if (*lhs.field_ref() != guarantee.target) return std::nullopt; + + FilterOptions::NullSelectionBehavior null_selection{}; + switch (options->null_matching_behavior) { + case SetLookupOptions::MATCH: + null_selection = + guarantee.nullable ? 
FilterOptions::EMIT_NULL : FilterOptions::DROP; + break; + case SetLookupOptions::SKIP: + null_selection = FilterOptions::DROP; + break; + case SetLookupOptions::EMIT_NULL: + if (guarantee.nullable) return std::nullopt; + null_selection = FilterOptions::DROP; + break; + case SetLookupOptions::INCONCLUSIVE: + if (guarantee.nullable) return std::nullopt; + ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set)); + ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null)); + if (any_null.scalar_as<BooleanScalar>().value) return std::nullopt; + null_selection = FilterOptions::DROP; + break; + } + + std::string func_name = Comparison::GetName(guarantee.cmp); + DCHECK_NE(func_name, "na"); + std::vector<Datum> args{options->value_set, guarantee.bound}; + ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args)); + FilterOptions filter_options(null_selection); + ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set, + Filter(options->value_set, filter_mask, filter_options)); + + if (simplified_value_set.length() == 0) return literal(false); + if (simplified_value_set.length() == options->value_set.length()) return std::nullopt; + + ExecContext exec_context; + Expression::Call simplified_call; + simplified_call.function_name = "is_in"; + simplified_call.arguments = is_in_call->arguments; + simplified_call.options = std::make_shared<SetLookupOptions>( + simplified_value_set, options->null_matching_behavior); + ARROW_ASSIGN_OR_RAISE( + Expression simplified_expr, + BindNonRecursive(std::move(simplified_call), + /*insert_implicit_casts=*/false, &exec_context)); + return simplified_expr; + } + + /// \brief Simplify the given expression given this inequality as a guarantee. 
+ Result<Expression> Simplify(Expression expr) { + const auto& guarantee = *this; + + auto call = expr.call(); + if (!call) return expr; + + if (call->function_name == "is_valid" || call->function_name == "is_null") { + if (guarantee.nullable) return expr; + const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); + if (!lhs.field_ref()) return expr; + if (*lhs.field_ref() != guarantee.target) return expr; + + return call->function_name == "is_valid" ? literal(true) : literal(false); + } + + if (call->function_name == "is_in") { + ARROW_ASSIGN_OR_RAISE(std::optional<Expression> result, + SimplifyIsIn(guarantee, call)); + return result.value_or(expr); + } + + auto cmp = Comparison::Get(expr); + if (!cmp) return expr; + + auto rhs = call->arguments[1].literal(); + if (!rhs) return expr; + if (!rhs->is_scalar()) return expr; + + const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); + if (!lhs.field_ref()) return expr; + if (*lhs.field_ref() != guarantee.target) return expr; + + // Whether the RHS of the expression is EQUAL, LESS, or GREATER than the + // RHS of the guarantee. N.B. Comparison::type is a bitmask + ARROW_ASSIGN_OR_RAISE(const Comparison::type cmp_rhs_bound, + Comparison::Execute(*rhs, guarantee.bound)); + DCHECK_NE(cmp_rhs_bound, Comparison::NA); + + if (cmp_rhs_bound == Comparison::EQUAL) { + // RHS of filter is equal to RHS of guarantee + + if ((*cmp & guarantee.cmp) == guarantee.cmp) { + // guarantee is a subset of filter, so all data will be included + // x > 1, x >= 1, x != 1 guaranteed by x > 1 + return simplified_to(lhs, true); + } + + if ((*cmp & guarantee.cmp) == 0) { + // guarantee disjoint with filter, so all data will be excluded + // x > 1, x >= 1 unsatisfiable if x == 1 + return simplified_to(lhs, false); + } + + return expr; + } + + if (guarantee.cmp & cmp_rhs_bound) { + // We guarantee (x (?) N) and are trying to simplify (x (?) M). We know + // either M < N or M > N (i.e. 
cmp_rhs_bound is either LESS or GREATER). + + // If M > N, then if the guarantee is (x > N), (x >= N), or (x != N) + // (i.e. guarantee.cmp & cmp_rhs_bound), we cannot do anything with the + // guarantee, and bail out here. + + // For example, take M = 5, N = 3. Then cmp_rhs_bound = GREATER. + // x > 3, x >= 3, x != 3 implies nothing about x < 5, x <= 5, x > 5, + // x >= 5, x != 5 and we bail out here. + // x < 3, x <= 3 could simplify (some of) those expressions. + return expr; + } + + if (*cmp & Comparison::GetFlipped(cmp_rhs_bound)) { + // x > 1, x >= 1, x != 1 guaranteed by x >= 3 + // (where `guarantee.cmp` is GREATER_EQUAL, `cmp_rhs_bound` is LESS) + return simplified_to(lhs, true); + } else { + // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3 + return simplified_to(lhs, false); + } + } +}; + +/// \brief Simplify an expression given a guarantee, if the guarantee +/// is is_valid(). +Result<Expression> SimplifyIsValidGuarantee(Expression expr, + const Expression::Call& guarantee) { + if (guarantee.function_name != "is_valid") return expr; + + return ModifyExpression( + std::move(expr), [](Expression expr) { return expr; }, + [&](Expression expr, ...) 
-> Result<Expression> { + auto call = expr.call(); + if (!call) return expr; + + if (call->arguments[0] != guarantee.arguments[0]) return expr; + + if (call->function_name == "is_valid") return literal(true); + + if (call->function_name == "true_unless_null") return literal(true); + + if (call->function_name == "is_null") return literal(false); + + return expr; + }); +} + +} // namespace + +Result<Expression> SimplifyWithGuarantee(Expression expr, + const Expression& guaranteed_true_predicate) { + KnownFieldValues known_values; + auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); + + RETURN_NOT_OK(ExtractKnownFieldValues(&conjunction_members, &known_values)); + + ARROW_ASSIGN_OR_RAISE(expr, + ReplaceFieldsWithKnownValues(known_values, std::move(expr))); + + auto CanonicalizeAndFoldConstants = [&expr] { + ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); + ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr))); + return Status::OK(); + }; + RETURN_NOT_OK(CanonicalizeAndFoldConstants()); + + for (const auto& guarantee : conjunction_members) { + if (!guarantee.call()) continue; + + if (auto inequality = Inequality::ExtractOne(guarantee)) { + ARROW_ASSIGN_OR_RAISE(auto simplified, + ModifyExpression( + std::move(expr), [](Expression expr) { return expr; }, + [&](Expression expr, ...) 
-> Result<Expression> { + return inequality->Simplify(std::move(expr)); + })); + + if (Identical(simplified, expr)) continue; + + expr = std::move(simplified); + RETURN_NOT_OK(CanonicalizeAndFoldConstants()); + } + + if (guarantee.call()->function_name == "is_valid") { + ARROW_ASSIGN_OR_RAISE( + auto simplified, + SimplifyIsValidGuarantee(std::move(expr), *CallNotNull(guarantee))); + + if (Identical(simplified, expr)) continue; + + expr = std::move(simplified); + RETURN_NOT_OK(CanonicalizeAndFoldConstants()); + } + } + + return expr; +} + +Result<Expression> RemoveNamedRefs(Expression src) { + if (!src.IsBound()) { + return Status::Invalid("RemoveNamedRefs called on unbound expression"); + } + return ModifyExpression( + std::move(src), + /*pre=*/ + [](Expression expr) { + const Expression::Parameter* param = expr.parameter(); + if (param && !param->ref.IsFieldPath()) { + FieldPath ref_as_path( + std::vector<int>(param->indices.begin(), param->indices.end())); + return Expression( + Expression::Parameter{std::move(ref_as_path), param->type, param->indices}); + } + + return expr; + }, + /*post_call=*/[](Expression expr, ...) { return expr; }); +} + +// Serialization is accomplished by converting expressions to KeyValueMetadata and storing +// this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its +// columns. Finally, the RecordBatch is written to an IPC file. 
+Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) { +#ifdef ARROW_IPC + struct { + std::shared_ptr<KeyValueMetadata> metadata_ = std::make_shared<KeyValueMetadata>(); + ArrayVector columns_; + + Result<std::string> AddScalar(const Scalar& scalar) { + auto ret = columns_.size(); + ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(scalar, 1)); + columns_.push_back(std::move(array)); + return ToChars(ret); + } + + Status VisitFieldRef(const FieldRef& ref) { + if (ref.nested_refs()) { + metadata_->Append("nested_field_ref", ToChars(ref.nested_refs()->size())); + for (const auto& child : *ref.nested_refs()) { + RETURN_NOT_OK(VisitFieldRef(child)); + } + return Status::OK(); + } + if (!ref.name()) { + return Status::NotImplemented("Serialization of non-name field_refs"); + } + metadata_->Append("field_ref", *ref.name()); + return Status::OK(); + } + + Status Visit(const Expression& expr) { + if (auto lit = expr.literal()) { + if (!lit->is_scalar()) { + return Status::NotImplemented("Serialization of non-scalar literals"); + } + ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*lit->scalar())); + metadata_->Append("literal", std::move(value)); + return Status::OK(); + } + + if (auto ref = expr.field_ref()) { + return VisitFieldRef(*ref); + } + + auto call = CallNotNull(expr); + metadata_->Append("call", call->function_name); + + for (const auto& argument : call->arguments) { + RETURN_NOT_OK(Visit(argument)); + } + + if (call->options) { + ARROW_ASSIGN_OR_RAISE(auto options_scalar, + internal::FunctionOptionsToStructScalar(*call->options)); + ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*options_scalar)); + metadata_->Append("options", std::move(value)); + } + + metadata_->Append("end", call->function_name); + return Status::OK(); + } + + Result<std::shared_ptr<RecordBatch>> operator()(const Expression& expr) { + RETURN_NOT_OK(Visit(expr)); + FieldVector fields(columns_.size()); + for (size_t i = 0; i < fields.size(); ++i) { + fields[i] = field("", 
columns_[i]->type()); + } + return RecordBatch::Make(schema(std::move(fields), std::move(metadata_)), 1, + std::move(columns_)); + } + } ToRecordBatch; + + ARROW_ASSIGN_OR_RAISE(auto batch, ToRecordBatch(expr)); + ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create()); + ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema())); + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + RETURN_NOT_OK(writer->Close()); + return stream->Finish(); +#else + return Status::NotImplemented("IPC feature isn't enabled"); +#endif +} + +Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) { +#ifdef ARROW_IPC + io::BufferReader stream(std::move(buffer)); + ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream)); + ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0)); + if (batch->schema()->metadata() == nullptr) { + return Status::Invalid("serialized Expression's batch repr had null metadata"); + } + if (batch->num_rows() != 1) { + return Status::Invalid( + "serialized Expression's batch repr was not a single row - had ", + batch->num_rows()); + } + + struct FromRecordBatch { + const RecordBatch& batch_; + int index_; + + const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); } + + bool ParseInteger(const std::string& s, int32_t* value) { + return ::arrow20::internal::ParseValue<Int32Type>(s.data(), s.length(), value); + } + + Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) { + int32_t column_index; + if (!ParseInteger(i, &column_index)) { + return Status::Invalid("Couldn't parse column_index"); + } + if (column_index >= batch_.num_columns()) { + return Status::Invalid("column_index out of bounds"); + } + return batch_.column(column_index)->GetScalar(0); + } + + Result<Expression> GetOne() { + if (index_ >= metadata().size()) { + return Status::Invalid("unterminated serialized Expression"); + } + + const std::string& key = metadata().key(index_); + const std::string& 
value = metadata().value(index_); + ++index_; + + if (key == "literal") { + ARROW_ASSIGN_OR_RAISE(auto scalar, GetScalar(value)); + return literal(std::move(scalar)); + } + + if (key == "nested_field_ref") { + int32_t size; + if (!ParseInteger(value, &size)) { + return Status::Invalid("Couldn't parse nested field ref length"); + } + if (size <= 0) { + return Status::Invalid("nested field ref length must be > 0"); + } + std::vector<FieldRef> nested; + nested.reserve(size); + while (size-- > 0) { + ARROW_ASSIGN_OR_RAISE(auto ref, GetOne()); + if (!ref.field_ref()) { + return Status::Invalid("invalid nested field ref"); + } + nested.push_back(*ref.field_ref()); + } + return field_ref(FieldRef(std::move(nested))); + } + + if (key == "field_ref") { + return field_ref(value); + } + + if (key != "call") { + return Status::Invalid("Unrecognized serialized Expression key ", key); + } + + std::vector<Expression> arguments; + while (metadata().key(index_) != "end") { + if (metadata().key(index_) == "options") { + ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_))); + std::shared_ptr<compute::FunctionOptions> options; + if (options_scalar) { + ARROW_ASSIGN_OR_RAISE( + options, internal::FunctionOptionsFromStructScalar( + checked_cast<const StructScalar&>(*options_scalar))); + } + auto expr = call(value, std::move(arguments), std::move(options)); + index_ += 2; + return expr; + } + + ARROW_ASSIGN_OR_RAISE(auto argument, GetOne()); + arguments.push_back(std::move(argument)); + } + + ++index_; + return call(value, std::move(arguments)); + } + }; + + return FromRecordBatch{*batch, 0}.GetOne(); +#else + return Status::NotImplemented("IPC feature isn't enabled"); +#endif +} + +Expression project(std::vector<Expression> values, std::vector<std::string> names) { + return call("make_struct", std::move(values), + compute::MakeStructOptions{std::move(names)}); +} + +Expression equal(Expression lhs, Expression rhs) { + return call("equal", {std::move(lhs), 
std::move(rhs)}); +} + +Expression not_equal(Expression lhs, Expression rhs) { + return call("not_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression less(Expression lhs, Expression rhs) { + return call("less", {std::move(lhs), std::move(rhs)}); +} + +Expression less_equal(Expression lhs, Expression rhs) { + return call("less_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression greater(Expression lhs, Expression rhs) { + return call("greater", {std::move(lhs), std::move(rhs)}); +} + +Expression greater_equal(Expression lhs, Expression rhs) { + return call("greater_equal", {std::move(lhs), std::move(rhs)}); +} + +Expression is_null(Expression lhs, bool nan_is_null) { + return call("is_null", {std::move(lhs)}, compute::NullOptions(std::move(nan_is_null))); +} + +Expression is_valid(Expression lhs) { return call("is_valid", {std::move(lhs)}); } + +Expression and_(Expression lhs, Expression rhs) { + return call("and_kleene", {std::move(lhs), std::move(rhs)}); +} + +Expression and_(const std::vector<Expression>& operands) { + if (operands.empty()) return literal(true); + + Expression folded = operands.front(); + for (auto it = operands.begin() + 1; it != operands.end(); ++it) { + folded = and_(std::move(folded), *it); + } + return folded; +} + +Expression or_(Expression lhs, Expression rhs) { + return call("or_kleene", {std::move(lhs), std::move(rhs)}); +} + +Expression or_(const std::vector<Expression>& operands) { + if (operands.empty()) return literal(false); + + Expression folded = operands.front(); + for (auto it = operands.begin() + 1; it != operands.end(); ++it) { + folded = or_(std::move(folded), *it); + } + return folded; +} + +Expression not_(Expression operand) { return call("invert", {std::move(operand)}); } + +} // namespace compute +} // namespace arrow20 |
