aboutsummaryrefslogblamecommitdiffstats
path: root/yql/essentials/minikql/comp_nodes/mkql_random.cpp
blob: 3fb55dff94e2efca2a974b7a201f396670aa2dc7 (plain) (tree)






























































































































































                                                                                                                                          
#include "mkql_random.h"
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_program_builder.h>
#include <yql/essentials/minikql/mkql_string_util.h>
#include <util/random/mersenne.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {

// Boxed resource value that owns a TMersenne<ui64> generator seeded at construction.
// The generator is exposed to consumers as an opaque pointer via GetResource().
class TRandomMTResource : public TComputationValue<TRandomMTResource> {
    using TBaseValue = TComputationValue<TRandomMTResource>;

public:
    // memInfo is forwarded to the base value; seed initialises the generator.
    TRandomMTResource(TMemoryUsageInfo* memInfo, ui64 seed)
        : TBaseValue(memInfo)
        , Generator(seed)
    {}

private:
    // Tag identifying this resource; built from the RandomMTResource constant.
    NUdf::TStringRef GetResourceTag() const override {
        return NUdf::TStringRef(RandomMTResource);
    }

    // Address of the owned generator (callers cast it back to TMersenne<ui64>*).
    void* GetResource() override {
        return &Generator;
    }

    TMersenne<ui64> Generator;
};

// Computation node that creates a fresh TRandomMTResource from a ui64 seed
// produced by its single input node.
class TNewMTRandWrapper : public TMutableComputationNode<TNewMTRandWrapper> {
    using TBaseComputation = TMutableComputationNode<TNewMTRandWrapper>;

public:
    TNewMTRandWrapper(TComputationMutables& mutables, IComputationNode* seed)
        : TBaseComputation(mutables)
        , Seed(seed)
    {}

    // Evaluates the seed node, reads it as ui64 and wraps a newly seeded
    // generator into a resource value allocated through the holder factory.
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
        const auto seedHolder = Seed->GetValue(compCtx);
        return compCtx.HolderFactory.Create<TRandomMTResource>(seedHolder.Get<ui64>());
    }

private:
    // The only dependency is the node that yields the seed.
    void RegisterDependencies() const final {
        DependsOn(Seed);
    }

    IComputationNode* const Seed;
};

// Computation node that draws the next number from a generator resource and
// returns a two-element array: [generated value, the same resource].
class TNextMTRandWrapper : public TMutableComputationNode<TNextMTRandWrapper> {
    using TBaseComputation = TMutableComputationNode<TNextMTRandWrapper>;

public:
    TNextMTRandWrapper(TComputationMutables& mutables, IComputationNode* rand)
        : TBaseComputation(mutables)
        , Rand(rand)
        , ResPair(mutables)
    {}

    NUdf::TUnboxedValue DoCalculate(TComputationContext& compCtx) const {
        auto rand = Rand->GetValue(compCtx);
        // Debug-only sanity check that the input really is the expected resource.
        Y_DEBUG_ABORT_UNLESS(rand.GetResourceTag() == NUdf::TStringRef(RandomMTResource));

        // Obtain the output array first, then fill its two slots.
        NUdf::TUnboxedValue* slots = nullptr;
        const auto pair = ResPair.NewArray(compCtx, 2, slots);

        auto* const generator = static_cast<TMersenne<ui64>*>(rand.GetResource());
        slots[0] = NUdf::TUnboxedValuePod(generator->GenRand());
        // The resource itself is handed back so the caller can keep drawing from it.
        slots[1] = std::move(rand);
        return pair;
    }

private:
    // The only dependency is the node that yields the generator resource.
    void RegisterDependencies() const final {
        DependsOn(Rand);
    }

    IComputationNode* const Rand;
    const TContainerCacheOnContext ResPair;
};

// Computation node producing a random value from the context's random provider.
// The kind of value is selected at compile time by Rnd:
//   ERandom::Double - result of GenRandReal2()
//   ERandom::Number - result of GenRand64()
//   ERandom::Uuid   - raw bytes of GenUuid4() packed into a string value
// The dependent nodes are only registered as dependencies; DoCalculate never
// evaluates them.
template <ERandom Rnd>
class TRandomWrapper : public TMutableComputationNode<TRandomWrapper<Rnd>> {
    typedef TMutableComputationNode<TRandomWrapper<Rnd>> TBaseComputation;
public:
    // Takes ownership of the dependent node list. A named rvalue reference is an
    // lvalue, so std::move is required here to actually move instead of copy.
    TRandomWrapper(TComputationMutables& mutables, TComputationNodePtrVector&& dependentNodes)
        : TBaseComputation(mutables)
        , DependentNodes(std::move(dependentNodes))
    {
    }

    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        switch (Rnd) {
        case ERandom::Double:
            return NUdf::TUnboxedValuePod(ctx.RandomProvider.GenRandReal2());
        case ERandom::Number:
            return NUdf::TUnboxedValuePod(ctx.RandomProvider.GenRand64());
        case ERandom::Uuid: {
            auto uuid = ctx.RandomProvider.GenUuid4();
            // The UUID is emitted as its in-memory byte representation.
            return MakeString(NUdf::TStringRef(reinterpret_cast<const char*>(&uuid), sizeof(uuid)));
        }
        }

        // Reached only if Rnd is not one of the enumerators handled above.
        Y_ABORT("Unexpected");
    }

private:
    // Registers every dependent node as a dependency of this node.
    void RegisterDependencies() const final {
        for (IComputationNode* const node : DependentNodes) {
            // this-> is required: DependsOn lives in a dependent base class.
            this->DependsOn(node);
        }
    }

    const TComputationNodePtrVector DependentNodes;
};

}

// Factory for the NewMTRand callable: validates that there is exactly one
// input of data type ui64 and builds a TNewMTRandWrapper over it.
IComputationNode* WrapNewMTRand(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");

    // The seed must be a data value of scheme type ui64.
    const auto dataType = AS_TYPE(TDataType, callable.GetInput(0));
    MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");

    return new TNewMTRandWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
}

// Factory for the NextMTRand callable: validates that there is exactly one
// input of resource type and builds a TNextMTRandWrapper over it.
IComputationNode* WrapNextMTRand(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");

    // Used purely as a type check; the resulting type pointer is not needed.
    AS_TYPE(TResourceType, callable.GetInput(0));

    return new TNextMTRandWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
}

// Factory for the Random/RandomNumber/RandomUuid callables: resolves every
// input of the callable to its computation node and hands the list to a
// TRandomWrapper<Rnd> as its dependencies.
template <ERandom Rnd>
IComputationNode* WrapRandom(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    const ui32 inputsCount = callable.GetInputsCount();

    TComputationNodePtrVector dependentNodes;
    dependentNodes.reserve(inputsCount);
    for (ui32 index = 0; index != inputsCount; ++index) {
        dependentNodes.push_back(LocateNode(ctx.NodeLocator, callable, index));
    }

    return new TRandomWrapper<Rnd>(ctx.Mutables, std::move(dependentNodes));
}

// Explicit instantiations of WrapRandom for the three ERandom kinds handled by
// TRandomWrapper, so that the template definition can stay in this source file.
template
IComputationNode* WrapRandom<ERandom::Double>(TCallable& callable, const TComputationNodeFactoryContext& ctx);

template
IComputationNode* WrapRandom<ERandom::Number>(TCallable& callable, const TComputationNodeFactoryContext& ctx);

template
IComputationNode* WrapRandom<ERandom::Uuid>(TCallable& callable, const TComputationNodeFactoryContext& ctx);

}
}