diff options
author | vitalyisaev <vitalyisaev@yandex-team.com> | 2023-06-29 10:00:50 +0300 |
---|---|---|
committer | vitalyisaev <vitalyisaev@yandex-team.com> | 2023-06-29 10:00:50 +0300 |
commit | 6ffe9e53658409f212834330e13564e4952558f6 (patch) | |
tree | 85b1e00183517648b228aafa7c8fb07f5276f419 /contrib/libs/llvm16/lib/Analysis/TrainingLogger.cpp | |
parent | 726057070f9c5a91fc10fde0d5024913d10f1ab9 (diff) | |
download | ydb-6ffe9e53658409f212834330e13564e4952558f6.tar.gz |
YQ Connector: support managed ClickHouse
Со стороны dqrun можно обратиться к инстансу коннектора, который работает на streaming стенде, и извлечь данные из облачного CH.
Diffstat (limited to 'contrib/libs/llvm16/lib/Analysis/TrainingLogger.cpp')
-rw-r--r-- | contrib/libs/llvm16/lib/Analysis/TrainingLogger.cpp | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/contrib/libs/llvm16/lib/Analysis/TrainingLogger.cpp b/contrib/libs/llvm16/lib/Analysis/TrainingLogger.cpp new file mode 100644 index 0000000000..dcee8d40c5 --- /dev/null +++ b/contrib/libs/llvm16/lib/Analysis/TrainingLogger.cpp @@ -0,0 +1,88 @@ +//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logging infrastructure for extracting features and +// rewards for mlgo policy training. +// +//===----------------------------------------------------------------------===// +#include "llvm/Analysis/TensorSpec.h" +#include "llvm/Config/config.h" + +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/Utils/TrainingLogger.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" + +#include <cassert> +#include <numeric> + +using namespace llvm; + +// FIXME(mtrofin): remove the flag altogether +static cl::opt<bool> + UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden, + cl::desc("Output simple (non-protobuf) log.")); + +void Logger::writeHeader() { + json::OStream JOS(*OS); + JOS.object([&]() { + JOS.attributeArray("features", [&]() { + for (const auto &TS : FeatureSpecs) + TS.toJSON(JOS); + }); + if (IncludeReward) { + JOS.attributeBegin("score"); + RewardSpec.toJSON(JOS); + JOS.attributeEnd(); + } + }); + *OS << "\n"; +} + +void Logger::switchContext(StringRef Name) { + CurrentContext = Name.str(); + json::OStream JOS(*OS); + JOS.object([&]() { JOS.attribute("context", Name); }); + *OS << "\n"; +} + +void Logger::startObservation() { + auto I = ObservationIDs.insert({CurrentContext, 0}); + size_t NewObservationID = I.second ? 0 : ++I.first->second; + json::OStream JOS(*OS); + JOS.object([&]() { + JOS.attribute("observation", static_cast<int64_t>(NewObservationID)); + }); + *OS << "\n"; +} + +void Logger::endObservation() { *OS << "\n"; } + +void Logger::logRewardImpl(const char *RawData) { + assert(IncludeReward); + json::OStream JOS(*OS); + JOS.object([&]() { + JOS.attribute("outcome", static_cast<int64_t>( + ObservationIDs.find(CurrentContext)->second)); + }); + *OS << "\n"; + writeTensor(RewardSpec, RawData); + *OS << "\n"; +} + +Logger::Logger(std::unique_ptr<raw_ostream> OS, + const std::vector<TensorSpec> &FeatureSpecs, + const TensorSpec &RewardSpec, bool IncludeReward) + : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), + IncludeReward(IncludeReward) { + writeHeader(); +} |