aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm16/lib/Analysis/TensorSpec.cpp
blob: 4f7428ded85e0fc4fb1d1d625ac4445cbd4fe61a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation file for the abstraction of a tensor type, and JSON loading
// utils.
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
#include <array>
#include <cassert>
#include <numeric>

using namespace llvm;

namespace llvm {

// Emit one explicit specialization of TensorSpec::getDataType<CppType>() per
// (C++ element type, TensorType enumerator) pair listed by
// SUPPORTED_TENSOR_TYPES; each specialization just returns that enumerator.
#define TFUTILS_GETDATATYPE_IMPL(CppType, EnumMember)                          \
  template <> TensorType TensorSpec::getDataType<CppType>() {                  \
    return TensorType::EnumMember;                                             \
  }

SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)

#undef TFUTILS_GETDATATYPE_IMPL

static std::array<std::string, static_cast<size_t>(TensorType::Total)>
    TensorTypeNames{"INVALID",
#define TFUTILS_GETNAME_IMPL(T, _) #T,
                    SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
#undef TFUTILS_GETNAME_IMPL
    };

/// Return the printable name of \p TT by looking it up in TensorTypeNames,
/// using the enumerator's underlying value as the index. No range check is
/// done here, so \p TT is expected to be a value below TensorType::Total.
StringRef toString(TensorType TT) {
  const size_t Slot = static_cast<size_t>(TT);
  return TensorTypeNames[Slot];
}

/// Write this spec to \p OS as one JSON object whose attributes are, in order:
/// "name", "type" (the toString() spelling of the element type), "port", and
/// "shape" (an array holding each dimension as a 64-bit integer).
void TensorSpec::toJSON(json::OStream &OS) const {
  // Emits the elements of the "shape" array.
  auto WriteDims = [&]() {
    for (const auto Dim : shape())
      OS.value(static_cast<int64_t>(Dim));
  };
  // Emits the attributes of the enclosing object.
  auto WriteFields = [&]() {
    OS.attribute("name", name());
    OS.attribute("type", toString(type()));
    OS.attribute("port", port());
    OS.attributeArray("shape", WriteDims);
  };
  OS.object(WriteFields);
}

TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
                       size_t ElementSize, const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), Type(Type), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
                                   std::multiplies<int64_t>())),
      ElementSize(ElementSize) {}

/// Parse \p Value into a TensorSpec. The value must be a JSON object with a
/// string "name", a string "type", an int "port" and an int-array "shape".
/// Each property is checked in that order. The first missing or mistyped
/// property is reported through \p Ctx.emitError, together with a dump of
/// \p Value, and std::nullopt is returned.
/// If "type" does not spell one of SUPPORTED_TENSOR_TYPES, std::nullopt is
/// returned without a diagnostic.
std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                                const json::Value &Value) {
  // Reports Message plus the textual form of the whole input value, then yields
  // an empty optional so each failure site can simply `return EmitError(...)`.
  auto EmitError =
      [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
    std::string Dump;
    llvm::raw_string_ostream DumpStream(Dump);
    DumpStream << Value;
    Ctx.emitError("Unable to parse JSON Value as spec (" + Message +
                  "): " + Dump);
    return std::nullopt;
  };

  std::string SpecName;
  std::string TypeName;
  int SpecPort = -1;
  std::vector<int64_t> SpecShape;

  // FIXME: accept a Path as a parameter, and use it for error reporting.
  json::Path::Root RootPath("tensor_spec");
  json::ObjectMapper Fields(Value, RootPath);
  if (!Fields)
    return EmitError("Value is not a dict");

  if (!Fields.map<std::string>("name", SpecName))
    return EmitError("'name' property not present or not a string");
  if (!Fields.map<std::string>("type", TypeName))
    return EmitError("'type' property not present or not a string");
  if (!Fields.map<int>("port", SpecPort))
    return EmitError("'port' property not present or not an int");
  if (!Fields.map<std::vector<int64_t>>("shape", SpecShape))
    return EmitError("'shape' property not present or not an int array");

  // Compare "type" against the stringified C++ spelling (#T) of every
  // supported element type and build the spec for the first match.
#define PARSE_TYPE(T, E)                                                       \
  if (TypeName == #T)                                                          \
    return TensorSpec::createSpec<T>(SpecName, SpecShape, SpecPort);
  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
#undef PARSE_TYPE

  // NOTE(review): unlike the checks above, an unrecognized "type" string is
  // rejected without emitting a diagnostic. Confirm whether callers rely on
  // this before adding an EmitError here.
  return std::nullopt;
}

} // namespace llvm