diff options
author | nadya02 <nadya02@yandex-team.com> | 2023-09-02 19:32:52 +0300 |
---|---|---|
committer | nadya02 <nadya02@yandex-team.com> | 2023-09-02 19:47:15 +0300 |
commit | f3351138d25ba9b9c86194ca826e0cb1257aff89 (patch) | |
tree | 54f85f476a52259687b2f6f24cd66facbf014216 | |
parent | c9811db61e454037546a71b3d1b6e2dc4dbb1a9d (diff) | |
download | ydb-f3351138d25ba9b9c86194ca826e0cb1257aff89.tar.gz |
YT-19430: Add arrow writer
Add arrow writer
283 files changed, 47077 insertions, 0 deletions
diff --git a/contrib/libs/backtrace/macho.c b/contrib/libs/backtrace/macho.c new file mode 100644 index 0000000000..d00aea9bc8 --- /dev/null +++ b/contrib/libs/backtrace/macho.c @@ -0,0 +1,1355 @@ +/* macho.c -- Get debug data from a Mach-O file for backtraces. + Copyright (C) 2020-2021 Free Software Foundation, Inc. + Written by Ian Lance Taylor, Google. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + (1) Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + (2) Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + + (3) The name of the author may not be used to + endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, +INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
*/ + +#include "config.h" + +#include <sys/types.h> +#include <dirent.h> +#include <stdlib.h> +#include <string.h> + +#ifdef HAVE_MACH_O_DYLD_H +#include <mach-o/dyld.h> +#endif + +#include "backtrace.h" +#include "internal.h" + +/* Mach-O file header for a 32-bit executable. */ + +struct macho_header_32 +{ + uint32_t magic; /* Magic number (MACH_O_MH_MAGIC_32) */ + uint32_t cputype; /* CPU type */ + uint32_t cpusubtype; /* CPU subtype */ + uint32_t filetype; /* Type of file (object, executable) */ + uint32_t ncmds; /* Number of load commands */ + uint32_t sizeofcmds; /* Total size of load commands */ + uint32_t flags; /* Flags for special features */ +}; + +/* Mach-O file header for a 64-bit executable. The leading fields are + the same as in macho_header_32. */ + +struct macho_header_64 +{ + uint32_t magic; /* Magic number (MACH_O_MH_MAGIC_64) */ + uint32_t cputype; /* CPU type */ + uint32_t cpusubtype; /* CPU subtype */ + uint32_t filetype; /* Type of file (object, executable) */ + uint32_t ncmds; /* Number of load commands */ + uint32_t sizeofcmds; /* Total size of load commands */ + uint32_t flags; /* Flags for special features */ + uint32_t reserved; /* Reserved */ +}; + +/* Mach-O file header for a fat executable. */ + +struct macho_header_fat +{ + uint32_t magic; /* Magic number (MACH_O_MH_(MAGIC|CIGAM)_FAT(_64)?) */ + uint32_t nfat_arch; /* Number of components */ +}; + +/* Values for the header magic field. */ + +#define MACH_O_MH_MAGIC_32 0xfeedface +#define MACH_O_MH_MAGIC_64 0xfeedfacf +#define MACH_O_MH_MAGIC_FAT 0xcafebabe +#define MACH_O_MH_CIGAM_FAT 0xbebafeca +#define MACH_O_MH_MAGIC_FAT_64 0xcafebabf +#define MACH_O_MH_CIGAM_FAT_64 0xbfbafeca + +/* Values for the header filetype field. */ + +#define MACH_O_MH_EXECUTE 0x02 +#define MACH_O_MH_DYLIB 0x06 +#define MACH_O_MH_DSYM 0x0a + +/* A component of a fat file. A fat file starts with a + macho_header_fat followed by nfat_arch instances of this + struct. 
*/ + +struct macho_fat_arch +{ + uint32_t cputype; /* CPU type */ + uint32_t cpusubtype; /* CPU subtype */ + uint32_t offset; /* File offset of this entry */ + uint32_t size; /* Size of this entry */ + uint32_t align; /* Alignment of this entry */ +}; + +/* A component of a 64-bit fat file. This is used if the magic field + is MACH_O_MH_MAGIC_FAT_64 (or MACH_O_MH_CIGAM_FAT_64). This is only + used when some file size or file + offset is too large to represent in the 32-bit format. */ + +struct macho_fat_arch_64 +{ + uint32_t cputype; /* CPU type */ + uint32_t cpusubtype; /* CPU subtype */ + uint64_t offset; /* File offset of this entry */ + uint64_t size; /* Size of this entry */ + uint32_t align; /* Alignment of this entry */ + uint32_t reserved; /* Reserved */ +}; + +/* Values for the fat_arch cputype field (and the header cputype + field). */ + +#define MACH_O_CPU_ARCH_ABI64 0x01000000 + +#define MACH_O_CPU_TYPE_X86 7 +#define MACH_O_CPU_TYPE_ARM 12 +#define MACH_O_CPU_TYPE_PPC 18 + +#define MACH_O_CPU_TYPE_X86_64 (MACH_O_CPU_TYPE_X86 | MACH_O_CPU_ARCH_ABI64) +#define MACH_O_CPU_TYPE_ARM64 (MACH_O_CPU_TYPE_ARM | MACH_O_CPU_ARCH_ABI64) +#define MACH_O_CPU_TYPE_PPC64 (MACH_O_CPU_TYPE_PPC | MACH_O_CPU_ARCH_ABI64) + +/* The header of a load command. */ + +struct macho_load_command +{ + uint32_t cmd; /* The type of load command */ + uint32_t cmdsize; /* Size in bytes of the entire command */ +}; + +/* Values for the load_command cmd field. */ + +#define MACH_O_LC_SEGMENT 0x01 +#define MACH_O_LC_SYMTAB 0x02 +#define MACH_O_LC_SEGMENT_64 0x19 +#define MACH_O_LC_UUID 0x1b + +/* The length of a section or segment name. */ + +#define MACH_O_NAMELEN (16) + +/* LC_SEGMENT load command. 
*/ + +struct macho_segment_command +{ + uint32_t cmd; /* The type of load command (LC_SEGMENT) */ + uint32_t cmdsize; /* Size in bytes of the entire command */ + char segname[MACH_O_NAMELEN]; /* Segment name */ + uint32_t vmaddr; /* Virtual memory address */ + uint32_t vmsize; /* Virtual memory size */ + uint32_t fileoff; /* Offset of data to be mapped */ + uint32_t filesize; /* Size of data in file */ + uint32_t maxprot; /* Maximum permitted virtual protection */ + uint32_t initprot; /* Initial virtual memory protection */ + uint32_t nsects; /* Number of sections in this segment */ + uint32_t flags; /* Flags */ +}; + +/* LC_SEGMENT_64 load command. */ + +struct macho_segment_64_command +{ + uint32_t cmd; /* The type of load command (LC_SEGMENT_64) */ + uint32_t cmdsize; /* Size in bytes of the entire command */ + char segname[MACH_O_NAMELEN]; /* Segment name */ + uint64_t vmaddr; /* Virtual memory address */ + uint64_t vmsize; /* Virtual memory size */ + uint64_t fileoff; /* Offset of data to be mapped */ + uint64_t filesize; /* Size of data in file */ + uint32_t maxprot; /* Maximum permitted virtual protection */ + uint32_t initprot; /* Initial virtual memory protection */ + uint32_t nsects; /* Number of sections in this segment */ + uint32_t flags; /* Flags */ +}; + +/* LC_SYMTAB load command. */ + +struct macho_symtab_command +{ + uint32_t cmd; /* The type of load command (LC_SYMTAB) */ + uint32_t cmdsize; /* Size in bytes of the entire command */ + uint32_t symoff; /* File offset of symbol table */ + uint32_t nsyms; /* Number of symbols */ + uint32_t stroff; /* File offset of string table */ + uint32_t strsize; /* String table size */ +}; + +/* The length of a Mach-O uuid. */ + +#define MACH_O_UUID_LEN (16) + +/* LC_UUID load command. 
*/ + +struct macho_uuid_command +{ + uint32_t cmd; /* Type of load command (LC_UUID) */ + uint32_t cmdsize; /* Size in bytes of command */ + unsigned char uuid[MACH_O_UUID_LEN]; /* UUID */ +}; + +/* 32-bit section header within a LC_SEGMENT segment. */ + +struct macho_section +{ + char sectname[MACH_O_NAMELEN]; /* Section name */ + char segment[MACH_O_NAMELEN]; /* Segment of this section */ + uint32_t addr; /* Address in memory */ + uint32_t size; /* Section size */ + uint32_t offset; /* File offset */ + uint32_t align; /* Log2 of section alignment */ + uint32_t reloff; /* File offset of relocations */ + uint32_t nreloc; /* Number of relocs for this section */ + uint32_t flags; /* Flags */ + uint32_t reserved1; /* Reserved */ + uint32_t reserved2; /* Reserved */ +}; + +/* 64-bit section header within a LC_SEGMENT_64 segment. */ + +struct macho_section_64 +{ + char sectname[MACH_O_NAMELEN]; /* Section name */ + char segment[MACH_O_NAMELEN]; /* Segment of this section */ + uint64_t addr; /* Address in memory */ + uint64_t size; /* Section size */ + uint32_t offset; /* File offset */ + uint32_t align; /* Log2 of section alignment */ + uint32_t reloff; /* File offset of section relocations */ + uint32_t nreloc; /* Number of relocs for this section */ + uint32_t flags; /* Flags */ + uint32_t reserved1; /* Reserved */ + uint32_t reserved2; /* Reserved */ + uint32_t reserved3; /* Reserved */ +}; + +/* 32-bit symbol data. */ + +struct macho_nlist +{ + uint32_t n_strx; /* Byte offset of name in string table */ + uint8_t n_type; /* Type flag */ + uint8_t n_sect; /* Section number */ + uint16_t n_desc; /* Stabs description field */ + uint32_t n_value; /* Value */ +}; + +/* 64-bit symbol data. */ + +struct macho_nlist_64 +{ + uint32_t n_strx; /* Byte offset of name in string table */ + uint8_t n_type; /* Type flag */ + uint8_t n_sect; /* Section number */ + uint16_t n_desc; /* Stabs description field */ + uint64_t n_value; /* Value */ +}; + +/* Values found in the nlist n_type field. 
*/ + +#define MACH_O_N_EXT 0x01 /* Extern symbol */ +#define MACH_O_N_ABS 0x02 /* Absolute symbol */ +#define MACH_O_N_SECT 0x0e /* Defined in section */ + +#define MACH_O_N_TYPE 0x0e /* Mask for type bits */ +#define MACH_O_N_STAB 0xe0 /* Stabs debugging symbol */ + +/* Information we keep for a Mach-O symbol. */ + +struct macho_symbol +{ + const char *name; /* Symbol name */ + uintptr_t address; /* Symbol address (n_value plus base address) */ +}; + +/* Information to pass to macho_syminfo. One entry is appended to + state->syminfo_data for each symbol table that is read. */ + +struct macho_syminfo_data +{ + struct macho_syminfo_data *next; /* Next module */ + struct macho_symbol *symbols; /* Symbols sorted by address */ + size_t count; /* Number of symbols, not counting the sentinel */ +}; + +/* Names of sections, indexed by enum dwarf_section in internal.h. + An empty name marks a section that macho_add_dwarf_section never + looks up. */ + +static const char * const dwarf_section_names[DEBUG_MAX] = +{ + "__debug_info", + "__debug_line", + "__debug_abbrev", + "__debug_ranges", + "__debug_str", + "", /* DEBUG_ADDR */ + "__debug_str_offs", + "", /* DEBUG_LINE_STR */ + "__debug_rnglists" +}; + +/* Forward declaration: macho_add is called by macho_add_fat and + macho_add_dsym, which are defined before it. */ + +static int macho_add (struct backtrace_state *, const char *, int, off_t, + const unsigned char *, uintptr_t, int, + backtrace_error_callback, void *, fileline *, int *); + +/* A dummy callback function used when we can't find any debug info. + It reports the problem through ERROR_CALLBACK and returns 0. */ + +static int +macho_nodebug (struct backtrace_state *state ATTRIBUTE_UNUSED, + uintptr_t pc ATTRIBUTE_UNUSED, + backtrace_full_callback callback ATTRIBUTE_UNUSED, + backtrace_error_callback error_callback, void *data) +{ + error_callback (data, "no debug info in Mach-O executable", -1); + return 0; +} + +/* A dummy callback function used when we can't find a symbol + table. 
*/ + +static void +macho_nosyms (struct backtrace_state *state ATTRIBUTE_UNUSED, + uintptr_t addr ATTRIBUTE_UNUSED, + backtrace_syminfo_callback callback ATTRIBUTE_UNUSED, + backtrace_error_callback error_callback, void *data) +{ + error_callback (data, "no symbol table in Mach-O executable", -1); +} + +/* Add a single DWARF section to DWARF_SECTIONS, if we need the + section. Returns 1 on success, 0 on failure. */ + +static int +macho_add_dwarf_section (struct backtrace_state *state, int descriptor, + const char *sectname, uint32_t offset, uint64_t size, + backtrace_error_callback error_callback, void *data, + struct dwarf_sections *dwarf_sections) +{ + int i; + + for (i = 0; i < (int) DEBUG_MAX; ++i) + { + if (dwarf_section_names[i][0] != '\0' + && strncmp (sectname, dwarf_section_names[i], MACH_O_NAMELEN) == 0) + { + struct backtrace_view section_view; + + /* FIXME: Perhaps it would be better to try to use a single + view to read all the DWARF data, as we try to do for + ELF. */ + + if (!backtrace_get_view (state, descriptor, offset, size, + error_callback, data, §ion_view)) + return 0; + dwarf_sections->data[i] = (const unsigned char *) section_view.data; + dwarf_sections->size[i] = size; + break; + } + } + return 1; +} + +/* Collect DWARF sections from a DWARF segment. Returns 1 on success, + 0 on failure. 
*/ + +static int +macho_add_dwarf_segment (struct backtrace_state *state, int descriptor, + off_t offset, unsigned int cmd, const char *psecs, + size_t sizesecs, unsigned int nsects, + backtrace_error_callback error_callback, void *data, + struct dwarf_sections *dwarf_sections) +{ + size_t sec_header_size; + size_t secoffset; + unsigned int i; + + switch (cmd) + { + case MACH_O_LC_SEGMENT: + sec_header_size = sizeof (struct macho_section); + break; + case MACH_O_LC_SEGMENT_64: + sec_header_size = sizeof (struct macho_section_64); + break; + default: + abort (); + } + + secoffset = 0; + for (i = 0; i < nsects; ++i) + { + if (secoffset + sec_header_size > sizesecs) + { + error_callback (data, "section overflow withing segment", 0); + return 0; + } + + switch (cmd) + { + case MACH_O_LC_SEGMENT: + { + struct macho_section section; + + memcpy (§ion, psecs + secoffset, sizeof section); + macho_add_dwarf_section (state, descriptor, section.sectname, + offset + section.offset, section.size, + error_callback, data, dwarf_sections); + } + break; + + case MACH_O_LC_SEGMENT_64: + { + struct macho_section_64 section; + + memcpy (§ion, psecs + secoffset, sizeof section); + macho_add_dwarf_section (state, descriptor, section.sectname, + offset + section.offset, section.size, + error_callback, data, dwarf_sections); + } + break; + + default: + abort (); + } + + secoffset += sec_header_size; + } + + return 1; +} + +/* Compare struct macho_symbol for qsort. */ + +static int +macho_symbol_compare (const void *v1, const void *v2) +{ + const struct macho_symbol *m1 = (const struct macho_symbol *) v1; + const struct macho_symbol *m2 = (const struct macho_symbol *) v2; + + if (m1->address < m2->address) + return -1; + else if (m1->address > m2->address) + return 1; + else + return 0; +} + +/* Compare an address against a macho_symbol for bsearch. We allocate + one extra entry in the array so that this can safely look at the + next entry. 
*/ + +static int +macho_symbol_search (const void *vkey, const void *ventry) +{ + const uintptr_t *key = (const uintptr_t *) vkey; + const struct macho_symbol *entry = (const struct macho_symbol *) ventry; + uintptr_t addr; + + addr = *key; + if (addr < entry->address) + return -1; + else if (entry->name[0] == '\0' + && entry->address == ~(uintptr_t) 0) + return -1; + else if ((entry + 1)->name[0] == '\0' + && (entry + 1)->address == ~(uintptr_t) 0) + return -1; + else if (addr >= (entry + 1)->address) + return 1; + else + return 0; +} + +/* Return whether the symbol type field indicates a symbol table entry + that we care about: a function or data symbol. */ + +static int +macho_defined_symbol (uint8_t type) +{ + if ((type & MACH_O_N_STAB) != 0) + return 0; + if ((type & MACH_O_N_EXT) != 0) + return 0; + switch (type & MACH_O_N_TYPE) + { + case MACH_O_N_ABS: + return 1; + case MACH_O_N_SECT: + return 1; + default: + return 0; + } +} + +/* Add symbol table information for a Mach-O file. 
*/ + +static int +macho_add_symtab (struct backtrace_state *state, int descriptor, + uintptr_t base_address, int is_64, + off_t symoff, unsigned int nsyms, off_t stroff, + unsigned int strsize, + backtrace_error_callback error_callback, void *data) +{ + size_t symsize; + struct backtrace_view sym_view; + int sym_view_valid; + struct backtrace_view str_view; + int str_view_valid; + size_t ndefs; + size_t symtaboff; + unsigned int i; + size_t macho_symbol_size; + struct macho_symbol *macho_symbols; + unsigned int j; + struct macho_syminfo_data *sdata; + + sym_view_valid = 0; + str_view_valid = 0; + macho_symbol_size = 0; + macho_symbols = NULL; + + if (is_64) + symsize = sizeof (struct macho_nlist_64); + else + symsize = sizeof (struct macho_nlist); + + if (!backtrace_get_view (state, descriptor, symoff, nsyms * symsize, + error_callback, data, &sym_view)) + goto fail; + sym_view_valid = 1; + + if (!backtrace_get_view (state, descriptor, stroff, strsize, + error_callback, data, &str_view)) + return 0; + str_view_valid = 1; + + ndefs = 0; + symtaboff = 0; + for (i = 0; i < nsyms; ++i, symtaboff += symsize) + { + if (is_64) + { + struct macho_nlist_64 nlist; + + memcpy (&nlist, (const char *) sym_view.data + symtaboff, + sizeof nlist); + if (macho_defined_symbol (nlist.n_type)) + ++ndefs; + } + else + { + struct macho_nlist nlist; + + memcpy (&nlist, (const char *) sym_view.data + symtaboff, + sizeof nlist); + if (macho_defined_symbol (nlist.n_type)) + ++ndefs; + } + } + + /* Add 1 to ndefs to make room for a sentinel. 
*/ + macho_symbol_size = (ndefs + 1) * sizeof (struct macho_symbol); + macho_symbols = ((struct macho_symbol *) + backtrace_alloc (state, macho_symbol_size, error_callback, + data)); + if (macho_symbols == NULL) + goto fail; + + j = 0; + symtaboff = 0; + for (i = 0; i < nsyms; ++i, symtaboff += symsize) + { + uint32_t strx; + uint64_t value; + const char *name; + + strx = 0; + value = 0; + if (is_64) + { + struct macho_nlist_64 nlist; + + memcpy (&nlist, (const char *) sym_view.data + symtaboff, + sizeof nlist); + if (!macho_defined_symbol (nlist.n_type)) + continue; + + strx = nlist.n_strx; + value = nlist.n_value; + } + else + { + struct macho_nlist nlist; + + memcpy (&nlist, (const char *) sym_view.data + symtaboff, + sizeof nlist); + if (!macho_defined_symbol (nlist.n_type)) + continue; + + strx = nlist.n_strx; + value = nlist.n_value; + } + + if (strx >= strsize) + { + error_callback (data, "symbol string index out of range", 0); + goto fail; + } + + name = (const char *) str_view.data + strx; + if (name[0] == '_') + ++name; + macho_symbols[j].name = name; + macho_symbols[j].address = value + base_address; + ++j; + } + + sdata = ((struct macho_syminfo_data *) + backtrace_alloc (state, sizeof *sdata, error_callback, data)); + if (sdata == NULL) + goto fail; + + /* We need to keep the string table since it holds the names, but we + can release the symbol table. */ + + backtrace_release_view (state, &sym_view, error_callback, data); + sym_view_valid = 0; + str_view_valid = 0; + + /* Add a trailing sentinel symbol. 
*/ + macho_symbols[j].name = ""; + macho_symbols[j].address = ~(uintptr_t) 0; + + backtrace_qsort (macho_symbols, ndefs + 1, sizeof (struct macho_symbol), + macho_symbol_compare); + + sdata->next = NULL; + sdata->symbols = macho_symbols; + sdata->count = ndefs; + + if (!state->threaded) + { + struct macho_syminfo_data **pp; + + for (pp = (struct macho_syminfo_data **) (void *) &state->syminfo_data; + *pp != NULL; + pp = &(*pp)->next) + ; + *pp = sdata; + } + else + { + while (1) + { + struct macho_syminfo_data **pp; + + pp = (struct macho_syminfo_data **) (void *) &state->syminfo_data; + + while (1) + { + struct macho_syminfo_data *p; + + p = backtrace_atomic_load_pointer (pp); + + if (p == NULL) + break; + + pp = &p->next; + } + + if (__sync_bool_compare_and_swap (pp, NULL, sdata)) + break; + } + } + + return 1; + + fail: + if (macho_symbols != NULL) + backtrace_free (state, macho_symbols, macho_symbol_size, + error_callback, data); + if (sym_view_valid) + backtrace_release_view (state, &sym_view, error_callback, data); + if (str_view_valid) + backtrace_release_view (state, &str_view, error_callback, data); + return 0; +} + +/* Return the symbol name and value for an ADDR. 
*/ + +static void +macho_syminfo (struct backtrace_state *state, uintptr_t addr, + backtrace_syminfo_callback callback, + backtrace_error_callback error_callback ATTRIBUTE_UNUSED, + void *data) +{ + struct macho_syminfo_data *sdata; + struct macho_symbol *sym; + + sym = NULL; + if (!state->threaded) + { + for (sdata = (struct macho_syminfo_data *) state->syminfo_data; + sdata != NULL; + sdata = sdata->next) + { + sym = ((struct macho_symbol *) + bsearch (&addr, sdata->symbols, sdata->count, + sizeof (struct macho_symbol), macho_symbol_search)); + if (sym != NULL) + break; + } + } + else + { + struct macho_syminfo_data **pp; + + pp = (struct macho_syminfo_data **) (void *) &state->syminfo_data; + while (1) + { + sdata = backtrace_atomic_load_pointer (pp); + if (sdata == NULL) + break; + + sym = ((struct macho_symbol *) + bsearch (&addr, sdata->symbols, sdata->count, + sizeof (struct macho_symbol), macho_symbol_search)); + if (sym != NULL) + break; + + pp = &sdata->next; + } + } + + if (sym == NULL) + callback (data, addr, NULL, 0, 0); + else + callback (data, addr, sym->name, sym->address, 0); +} + +/* Look through a fat file to find the relevant executable. Returns 1 + on success, 0 on failure (in both cases descriptor is closed). 
*/ + +static int +macho_add_fat (struct backtrace_state *state, const char *filename, + int descriptor, int swapped, off_t offset, + const unsigned char *match_uuid, uintptr_t base_address, + int skip_symtab, uint32_t nfat_arch, int is_64, + backtrace_error_callback error_callback, void *data, + fileline *fileline_fn, int *found_sym) +{ + int arch_view_valid; + unsigned int cputype; + size_t arch_size; + struct backtrace_view arch_view; + unsigned int i; + + arch_view_valid = 0; + +#if defined (__x86_64__) + cputype = MACH_O_CPU_TYPE_X86_64; +#elif defined (__i386__) + cputype = MACH_O_CPU_TYPE_X86; +#elif defined (__aarch64__) + cputype = MACH_O_CPU_TYPE_ARM64; +#elif defined (__arm__) + cputype = MACH_O_CPU_TYPE_ARM; +#elif defined (__ppc__) + cputype = MACH_O_CPU_TYPE_PPC; +#elif defined (__ppc64__) + cputype = MACH_O_CPU_TYPE_PPC64; +#else + error_callback (data, "unknown Mach-O architecture", 0); + goto fail; +#endif + + if (is_64) + arch_size = sizeof (struct macho_fat_arch_64); + else + arch_size = sizeof (struct macho_fat_arch); + + if (!backtrace_get_view (state, descriptor, offset, + nfat_arch * arch_size, + error_callback, data, &arch_view)) + goto fail; + + for (i = 0; i < nfat_arch; ++i) + { + uint32_t fcputype; + uint64_t foffset; + + if (is_64) + { + struct macho_fat_arch_64 fat_arch_64; + + memcpy (&fat_arch_64, + (const char *) arch_view.data + i * arch_size, + arch_size); + fcputype = fat_arch_64.cputype; + foffset = fat_arch_64.offset; + if (swapped) + { + fcputype = __builtin_bswap32 (fcputype); + foffset = __builtin_bswap64 (foffset); + } + } + else + { + struct macho_fat_arch fat_arch_32; + + memcpy (&fat_arch_32, + (const char *) arch_view.data + i * arch_size, + arch_size); + fcputype = fat_arch_32.cputype; + foffset = (uint64_t) fat_arch_32.offset; + if (swapped) + { + fcputype = __builtin_bswap32 (fcputype); + foffset = (uint64_t) __builtin_bswap32 ((uint32_t) foffset); + } + } + + if (fcputype == cputype) + { + /* FIXME: What about 
cpusubtype? */ + backtrace_release_view (state, &arch_view, error_callback, data); + return macho_add (state, filename, descriptor, foffset, match_uuid, + base_address, skip_symtab, error_callback, data, + fileline_fn, found_sym); + } + } + + error_callback (data, "could not find executable in fat file", 0); + + fail: + if (arch_view_valid) + backtrace_release_view (state, &arch_view, error_callback, data); + if (descriptor != -1) + backtrace_close (descriptor, error_callback, data); + return 0; +} + +/* Look for the dsym file for FILENAME. This is called if FILENAME + does not have debug info or a symbol table. Returns 1 on success, + 0 on failure. */ + +static int +macho_add_dsym (struct backtrace_state *state, const char *filename, + uintptr_t base_address, const unsigned char *uuid, + backtrace_error_callback error_callback, void *data, + fileline* fileline_fn) +{ + const char *p; + const char *dirname; + char *diralc; + size_t dirnamelen; + const char *basename; + size_t basenamelen; + const char *dsymsuffixdir; + size_t dsymsuffixdirlen; + size_t dsymlen; + char *dsym; + char *ps; + int d; + int does_not_exist; + int dummy_found_sym; + + diralc = NULL; + dirnamelen = 0; + dsym = NULL; + dsymlen = 0; + + p = strrchr (filename, '/'); + if (p == NULL) + { + dirname = "."; + dirnamelen = 1; + basename = filename; + basenamelen = strlen (basename); + diralc = NULL; + } + else + { + dirnamelen = p - filename; + diralc = backtrace_alloc (state, dirnamelen + 1, error_callback, data); + if (diralc == NULL) + goto fail; + memcpy (diralc, filename, dirnamelen); + diralc[dirnamelen] = '\0'; + dirname = diralc; + basename = p + 1; + basenamelen = strlen (basename); + } + + dsymsuffixdir = ".dSYM/Contents/Resources/DWARF/"; + dsymsuffixdirlen = strlen (dsymsuffixdir); + + dsymlen = (dirnamelen + + 1 + + basenamelen + + dsymsuffixdirlen + + basenamelen + + 1); + dsym = backtrace_alloc (state, dsymlen, error_callback, data); + if (dsym == NULL) + goto fail; + + ps = dsym; + 
memcpy (ps, dirname, dirnamelen); + ps += dirnamelen; + *ps++ = '/'; + memcpy (ps, basename, basenamelen); + ps += basenamelen; + memcpy (ps, dsymsuffixdir, dsymsuffixdirlen); + ps += dsymsuffixdirlen; + memcpy (ps, basename, basenamelen); + ps += basenamelen; + *ps = '\0'; + + if (diralc != NULL) + { + backtrace_free (state, diralc, dirnamelen + 1, error_callback, data); + diralc = NULL; + } + + d = backtrace_open (dsym, error_callback, data, &does_not_exist); + if (d < 0) + { + /* The file does not exist, so we can't read the debug info. + Just return success. */ + backtrace_free (state, dsym, dsymlen, error_callback, data); + return 1; + } + + if (!macho_add (state, dsym, d, 0, uuid, base_address, 1, + error_callback, data, fileline_fn, &dummy_found_sym)) + goto fail; + + backtrace_free (state, dsym, dsymlen, error_callback, data); + + return 1; + + fail: + if (dsym != NULL) + backtrace_free (state, dsym, dsymlen, error_callback, data); + if (diralc != NULL) + backtrace_free (state, diralc, dirnamelen, error_callback, data); + return 0; +} + +/* Add the backtrace data for a Macho-O file. Returns 1 on success, 0 + on failure (in both cases descriptor is closed). + + FILENAME: the name of the executable. + DESCRIPTOR: an open descriptor for the executable, closed here. + OFFSET: the offset within the file of this executable, for fat files. + MATCH_UUID: if not NULL, UUID that must match. + BASE_ADDRESS: the load address of the executable. + SKIP_SYMTAB: if non-zero, ignore the symbol table; used for dSYM files. + FILELINE_FN: set to the fileline function, by backtrace_dwarf_add. + FOUND_SYM: set to non-zero if we found the symbol table. 
+*/ + +static int +macho_add (struct backtrace_state *state, const char *filename, int descriptor, + off_t offset, const unsigned char *match_uuid, + uintptr_t base_address, int skip_symtab, + backtrace_error_callback error_callback, void *data, + fileline *fileline_fn, int *found_sym) +{ + struct backtrace_view header_view; + struct macho_header_32 header; + off_t hdroffset; + int is_64; + struct backtrace_view cmds_view; + int cmds_view_valid; + struct dwarf_sections dwarf_sections; + int have_dwarf; + unsigned char uuid[MACH_O_UUID_LEN]; + int have_uuid; + size_t cmdoffset; + unsigned int i; + + *found_sym = 0; + + cmds_view_valid = 0; + + /* The 32-bit and 64-bit file headers start out the same, so we can + just always read the 32-bit version. A fat header is shorter but + it will always be followed by data, so it's OK to read extra. */ + + if (!backtrace_get_view (state, descriptor, offset, + sizeof (struct macho_header_32), + error_callback, data, &header_view)) + goto fail; + + memcpy (&header, header_view.data, sizeof header); + + backtrace_release_view (state, &header_view, error_callback, data); + + switch (header.magic) + { + case MACH_O_MH_MAGIC_32: + is_64 = 0; + hdroffset = offset + sizeof (struct macho_header_32); + break; + case MACH_O_MH_MAGIC_64: + is_64 = 1; + hdroffset = offset + sizeof (struct macho_header_64); + break; + case MACH_O_MH_MAGIC_FAT: + case MACH_O_MH_MAGIC_FAT_64: + { + struct macho_header_fat fat_header; + + hdroffset = offset + sizeof (struct macho_header_fat); + memcpy (&fat_header, &header, sizeof fat_header); + return macho_add_fat (state, filename, descriptor, 0, hdroffset, + match_uuid, base_address, skip_symtab, + fat_header.nfat_arch, + header.magic == MACH_O_MH_MAGIC_FAT_64, + error_callback, data, fileline_fn, found_sym); + } + case MACH_O_MH_CIGAM_FAT: + case MACH_O_MH_CIGAM_FAT_64: + { + struct macho_header_fat fat_header; + uint32_t nfat_arch; + + hdroffset = offset + sizeof (struct macho_header_fat); + memcpy 
(&fat_header, &header, sizeof fat_header); + nfat_arch = __builtin_bswap32 (fat_header.nfat_arch); + return macho_add_fat (state, filename, descriptor, 1, hdroffset, + match_uuid, base_address, skip_symtab, + nfat_arch, + header.magic == MACH_O_MH_CIGAM_FAT_64, + error_callback, data, fileline_fn, found_sym); + } + default: + error_callback (data, "executable file is not in Mach-O format", 0); + goto fail; + } + + switch (header.filetype) + { + case MACH_O_MH_EXECUTE: + case MACH_O_MH_DYLIB: + case MACH_O_MH_DSYM: + break; + default: + error_callback (data, "executable file is not an executable", 0); + goto fail; + } + + if (!backtrace_get_view (state, descriptor, hdroffset, header.sizeofcmds, + error_callback, data, &cmds_view)) + goto fail; + cmds_view_valid = 1; + + memset (&dwarf_sections, 0, sizeof dwarf_sections); + have_dwarf = 0; + memset (&uuid, 0, sizeof uuid); + have_uuid = 0; + + cmdoffset = 0; + for (i = 0; i < header.ncmds; ++i) + { + const char *pcmd; + struct macho_load_command load_command; + + if (cmdoffset + sizeof load_command > header.sizeofcmds) + break; + + pcmd = (const char *) cmds_view.data + cmdoffset; + memcpy (&load_command, pcmd, sizeof load_command); + + switch (load_command.cmd) + { + case MACH_O_LC_SEGMENT: + { + struct macho_segment_command segcmd; + + memcpy (&segcmd, pcmd, sizeof segcmd); + if (memcmp (segcmd.segname, + "__DWARF\0\0\0\0\0\0\0\0\0", + MACH_O_NAMELEN) == 0) + { + if (!macho_add_dwarf_segment (state, descriptor, offset, + load_command.cmd, + pcmd + sizeof segcmd, + (load_command.cmdsize + - sizeof segcmd), + segcmd.nsects, error_callback, + data, &dwarf_sections)) + goto fail; + have_dwarf = 1; + } + } + break; + + case MACH_O_LC_SEGMENT_64: + { + struct macho_segment_64_command segcmd; + + memcpy (&segcmd, pcmd, sizeof segcmd); + if (memcmp (segcmd.segname, + "__DWARF\0\0\0\0\0\0\0\0\0", + MACH_O_NAMELEN) == 0) + { + if (!macho_add_dwarf_segment (state, descriptor, offset, + load_command.cmd, + pcmd + sizeof 
segcmd, + (load_command.cmdsize + - sizeof segcmd), + segcmd.nsects, error_callback, + data, &dwarf_sections)) + goto fail; + have_dwarf = 1; + } + } + break; + + case MACH_O_LC_SYMTAB: + if (!skip_symtab) + { + struct macho_symtab_command symcmd; + + memcpy (&symcmd, pcmd, sizeof symcmd); + if (!macho_add_symtab (state, descriptor, base_address, is_64, + offset + symcmd.symoff, symcmd.nsyms, + offset + symcmd.stroff, symcmd.strsize, + error_callback, data)) + goto fail; + + *found_sym = 1; + } + break; + + case MACH_O_LC_UUID: + { + struct macho_uuid_command uuidcmd; + + memcpy (&uuidcmd, pcmd, sizeof uuidcmd); + memcpy (&uuid[0], &uuidcmd.uuid[0], MACH_O_UUID_LEN); + have_uuid = 1; + } + break; + + default: + break; + } + + cmdoffset += load_command.cmdsize; + } + + if (!backtrace_close (descriptor, error_callback, data)) + goto fail; + descriptor = -1; + + backtrace_release_view (state, &cmds_view, error_callback, data); + cmds_view_valid = 0; + + if (match_uuid != NULL) + { + /* If we don't have a UUID, or it doesn't match, just ignore + this file. 
*/ + if (!have_uuid + || memcmp (match_uuid, &uuid[0], MACH_O_UUID_LEN) != 0) + return 1; + } + + if (have_dwarf) + { + int is_big_endian; + + is_big_endian = 0; +#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + is_big_endian = 1; +#endif +#endif + + if (!backtrace_dwarf_add (state, base_address, &dwarf_sections, + is_big_endian, NULL, error_callback, data, + fileline_fn, NULL)) + goto fail; + } + + if (!have_dwarf && have_uuid) + { + if (!macho_add_dsym (state, filename, base_address, &uuid[0], + error_callback, data, fileline_fn)) + goto fail; + } + + return 1; + + fail: + if (cmds_view_valid) + backtrace_release_view (state, &cmds_view, error_callback, data); + if (descriptor != -1) + backtrace_close (descriptor, error_callback, data); + return 0; +} + +#ifdef HAVE_MACH_O_DYLD_H + +/* Initialize the backtrace data we need from a Mach-O executable + using the dyld support functions. This closes descriptor. */ + +int +backtrace_initialize (struct backtrace_state *state, const char *filename, + int descriptor, backtrace_error_callback error_callback, + void *data, fileline *fileline_fn) +{ + uint32_t c; + uint32_t i; + int closed_descriptor; + int found_sym; + fileline macho_fileline_fn; + + closed_descriptor = 0; + found_sym = 0; + macho_fileline_fn = macho_nodebug; + + c = _dyld_image_count (); + for (i = 0; i < c; ++i) + { + uintptr_t base_address; + const char *name; + int d; + fileline mff; + int mfs; + + name = _dyld_get_image_name (i); + if (name == NULL) + continue; + + if (strcmp (name, filename) == 0 && !closed_descriptor) + { + d = descriptor; + closed_descriptor = 1; + } + else + { + int does_not_exist; + + d = backtrace_open (name, error_callback, data, &does_not_exist); + if (d < 0) + continue; + } + + base_address = _dyld_get_image_vmaddr_slide (i); + + mff = macho_nodebug; + if (!macho_add (state, name, d, 0, NULL, base_address, 0, + error_callback, data, &mff, &mfs)) + continue; + + if (mff 
!= macho_nodebug) + macho_fileline_fn = mff; + if (mfs) + found_sym = 1; + } + + if (!closed_descriptor) + backtrace_close (descriptor, error_callback, data); + + if (!state->threaded) + { + if (found_sym) + state->syminfo_fn = macho_syminfo; + else if (state->syminfo_fn == NULL) + state->syminfo_fn = macho_nosyms; + } + else + { + if (found_sym) + backtrace_atomic_store_pointer (&state->syminfo_fn, macho_syminfo); + else + (void) __sync_bool_compare_and_swap (&state->syminfo_fn, NULL, + macho_nosyms); + } + + if (!state->threaded) + *fileline_fn = state->fileline_fn; + else + *fileline_fn = backtrace_atomic_load_pointer (&state->fileline_fn); + + if (*fileline_fn == NULL || *fileline_fn == macho_nodebug) + *fileline_fn = macho_fileline_fn; + + return 1; +} + +#else /* !defined (HAVE_MACH_O_DYLD_H) */ + +/* Initialize the backtrace data we need from a Mach-O executable + without using the dyld support functions. This closes + descriptor. */ + +int +backtrace_initialize (struct backtrace_state *state, const char *filename, + int descriptor, backtrace_error_callback error_callback, + void *data, fileline *fileline_fn) +{ + fileline macho_fileline_fn; + int found_sym; + + macho_fileline_fn = macho_nodebug; + if (!macho_add (state, filename, descriptor, 0, NULL, 0, 0, + error_callback, data, &macho_fileline_fn, &found_sym)) + return 0; + + if (!state->threaded) + { + if (found_sym) + state->syminfo_fn = macho_syminfo; + else if (state->syminfo_fn == NULL) + state->syminfo_fn = macho_nosyms; + } + else + { + if (found_sym) + backtrace_atomic_store_pointer (&state->syminfo_fn, macho_syminfo); + else + (void) __sync_bool_compare_and_swap (&state->syminfo_fn, NULL, + macho_nosyms); + } + + if (!state->threaded) + *fileline_fn = state->fileline_fn; + else + *fileline_fn = backtrace_atomic_load_pointer (&state->fileline_fn); + + if (*fileline_fn == NULL || *fileline_fn == macho_nodebug) + *fileline_fn = macho_fileline_fn; + + return 1; +} + +#endif /* !defined 
(HAVE_MACH_O_DYLD_H) */ diff --git a/contrib/libs/cxxsupp/libcxx/include/experimental/functional b/contrib/libs/cxxsupp/libcxx/include/experimental/functional new file mode 100644 index 0000000000..1291894aa0 --- /dev/null +++ b/contrib/libs/cxxsupp/libcxx/include/experimental/functional @@ -0,0 +1,425 @@ +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _LIBCPP_EXPERIMENTAL_FUNCTIONAL +#define _LIBCPP_EXPERIMENTAL_FUNCTIONAL + +/* + experimental/functional synopsis + +#include <algorithm> + +namespace std { +namespace experimental { +inline namespace fundamentals_v1 { + // 4.3, Searchers + template<class ForwardIterator, class BinaryPredicate = equal_to<>> + class default_searcher; + + template<class RandomAccessIterator, + class Hash = hash<typename iterator_traits<RandomAccessIterator>::value_type>, + class BinaryPredicate = equal_to<>> + class boyer_moore_searcher; + + template<class RandomAccessIterator, + class Hash = hash<typename iterator_traits<RandomAccessIterator>::value_type>, + class BinaryPredicate = equal_to<>> + class boyer_moore_horspool_searcher; + + template<class ForwardIterator, class BinaryPredicate = equal_to<>> + default_searcher<ForwardIterator, BinaryPredicate> + make_default_searcher(ForwardIterator pat_first, ForwardIterator pat_last, + BinaryPredicate pred = BinaryPredicate()); + + template<class RandomAccessIterator, + class Hash = hash<typename iterator_traits<RandomAccessIterator>::value_type>, + class BinaryPredicate = equal_to<>> + boyer_moore_searcher<RandomAccessIterator, Hash, BinaryPredicate> + make_boyer_moore_searcher( + RandomAccessIterator pat_first, RandomAccessIterator 
pat_last, + Hash hf = Hash(), BinaryPredicate pred = BinaryPredicate()); + + template<class RandomAccessIterator, + class Hash = hash<typename iterator_traits<RandomAccessIterator>::value_type>, + class BinaryPredicate = equal_to<>> + boyer_moore_horspool_searcher<RandomAccessIterator, Hash, BinaryPredicate> + make_boyer_moore_horspool_searcher( + RandomAccessIterator pat_first, RandomAccessIterator pat_last, + Hash hf = Hash(), BinaryPredicate pred = BinaryPredicate()); + + } // namespace fundamentals_v1 + } // namespace experimental + +} // namespace std + +*/ + +#include <__debug> +#include <__memory/uses_allocator.h> +#include <array> +#include <experimental/__config> +#include <functional> +#include <type_traits> +#include <unordered_map> +#include <vector> + +#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) +# pragma GCC system_header +#endif + +_LIBCPP_PUSH_MACROS +#include <__undef_macros> + +_LIBCPP_BEGIN_NAMESPACE_LFTS + +#if _LIBCPP_STD_VER > 11 +// default searcher +template<class _ForwardIterator, class _BinaryPredicate = equal_to<>> +class _LIBCPP_TEMPLATE_VIS default_searcher { +public: + _LIBCPP_INLINE_VISIBILITY + default_searcher(_ForwardIterator __f, _ForwardIterator __l, + _BinaryPredicate __p = _BinaryPredicate()) + : __first_(__f), __last_(__l), __pred_(__p) {} + + template <typename _ForwardIterator2> + _LIBCPP_INLINE_VISIBILITY + pair<_ForwardIterator2, _ForwardIterator2> + operator () (_ForwardIterator2 __f, _ForwardIterator2 __l) const + { + return _VSTD::__search(__f, __l, __first_, __last_, __pred_, + typename iterator_traits<_ForwardIterator>::iterator_category(), + typename iterator_traits<_ForwardIterator2>::iterator_category()); + } + +private: + _ForwardIterator __first_; + _ForwardIterator __last_; + _BinaryPredicate __pred_; + }; + +template<class _ForwardIterator, class _BinaryPredicate = equal_to<>> +_LIBCPP_INLINE_VISIBILITY +default_searcher<_ForwardIterator, _BinaryPredicate> +make_default_searcher( _ForwardIterator __f, 
_ForwardIterator __l, _BinaryPredicate __p = _BinaryPredicate ()) +{ + return default_searcher<_ForwardIterator, _BinaryPredicate>(__f, __l, __p); +} + +template<class _Key, class _Value, class _Hash, class _BinaryPredicate, bool /*useArray*/> class _BMSkipTable; + +// General case for BM data searching; use a map +template<class _Key, typename _Value, class _Hash, class _BinaryPredicate> +class _BMSkipTable<_Key, _Value, _Hash, _BinaryPredicate, false> { + typedef _Value value_type; + typedef _Key key_type; + + const _Value __default_value_; + std::unordered_map<_Key, _Value, _Hash, _BinaryPredicate> __table; + +public: + _LIBCPP_INLINE_VISIBILITY + _BMSkipTable(size_t __sz, _Value __default, _Hash __hf, _BinaryPredicate __pred) + : __default_value_(__default), __table(__sz, __hf, __pred) {} + + _LIBCPP_INLINE_VISIBILITY + void insert(const key_type &__key, value_type __val) + { + __table [__key] = __val; // Would skip_.insert (val) be better here? + } + + _LIBCPP_INLINE_VISIBILITY + value_type operator [](const key_type & __key) const + { + auto __it = __table.find (__key); + return __it == __table.end() ? 
__default_value_ : __it->second; + } +}; + + +// Special case small numeric values; use an array +template<class _Key, typename _Value, class _Hash, class _BinaryPredicate> +class _BMSkipTable<_Key, _Value, _Hash, _BinaryPredicate, true> { +private: + typedef _Value value_type; + typedef _Key key_type; + + typedef typename make_unsigned<key_type>::type unsigned_key_type; + typedef std::array<value_type, numeric_limits<unsigned_key_type>::max()> skip_map; + skip_map __table; + +public: + _LIBCPP_INLINE_VISIBILITY + _BMSkipTable(size_t /*__sz*/, _Value __default, _Hash /*__hf*/, _BinaryPredicate /*__pred*/) + { + std::fill_n(__table.begin(), __table.size(), __default); + } + + _LIBCPP_INLINE_VISIBILITY + void insert(key_type __key, value_type __val) + { + __table[static_cast<unsigned_key_type>(__key)] = __val; + } + + _LIBCPP_INLINE_VISIBILITY + value_type operator [](key_type __key) const + { + return __table[static_cast<unsigned_key_type>(__key)]; + } +}; + + +template <class _RandomAccessIterator1, + class _Hash = hash<typename iterator_traits<_RandomAccessIterator1>::value_type>, + class _BinaryPredicate = equal_to<>> +class _LIBCPP_TEMPLATE_VIS boyer_moore_searcher { +private: + typedef typename std::iterator_traits<_RandomAccessIterator1>::difference_type difference_type; + typedef typename std::iterator_traits<_RandomAccessIterator1>::value_type value_type; + typedef _BMSkipTable<value_type, difference_type, _Hash, _BinaryPredicate, + is_integral<value_type>::value && // what about enums? 
+ sizeof(value_type) == 1 && + is_same<_Hash, hash<value_type>>::value && + is_same<_BinaryPredicate, equal_to<>>::value + > skip_table_type; + +public: + boyer_moore_searcher(_RandomAccessIterator1 __f, _RandomAccessIterator1 __l, + _Hash __hf = _Hash(), _BinaryPredicate __pred = _BinaryPredicate()) + : __first_(__f), __last_(__l), __pred_(__pred), + __pattern_length_(_VSTD::distance(__first_, __last_)), + __skip_{make_shared<skip_table_type>(__pattern_length_, -1, __hf, __pred_)}, + __suffix_{make_shared<vector<difference_type>>(__pattern_length_ + 1)} + { + // build the skip table + for ( difference_type __i = 0; __f != __l; ++__f, (void) ++__i ) + __skip_->insert(*__f, __i); + + this->__build_suffix_table ( __first_, __last_, __pred_ ); + } + + template <typename _RandomAccessIterator2> + pair<_RandomAccessIterator2, _RandomAccessIterator2> + operator ()(_RandomAccessIterator2 __f, _RandomAccessIterator2 __l) const + { + static_assert(__is_same_uncvref<typename iterator_traits<_RandomAccessIterator1>::value_type, + typename iterator_traits<_RandomAccessIterator2>::value_type>::value, + "Corpus and Pattern iterators must point to the same type"); + + if (__f == __l ) return make_pair(__l, __l); // empty corpus + if (__first_ == __last_) return make_pair(__f, __f); // empty pattern + + // If the pattern is larger than the corpus, we can't find it! 
+ if ( __pattern_length_ > _VSTD::distance(__f, __l)) + return make_pair(__l, __l); + + // Do the search + return this->__search(__f, __l); + } + +private: + _RandomAccessIterator1 __first_; + _RandomAccessIterator1 __last_; + _BinaryPredicate __pred_; + difference_type __pattern_length_; + shared_ptr<skip_table_type> __skip_; + shared_ptr<vector<difference_type>> __suffix_; + + template <typename _RandomAccessIterator2> + pair<_RandomAccessIterator2, _RandomAccessIterator2> + __search(_RandomAccessIterator2 __f, _RandomAccessIterator2 __l) const + { + _RandomAccessIterator2 __cur = __f; + const _RandomAccessIterator2 __last = __l - __pattern_length_; + const skip_table_type & __skip = *__skip_.get(); + const vector<difference_type> & __suffix = *__suffix_.get(); + + while (__cur <= __last) + { + + // Do we match right where we are? + difference_type __j = __pattern_length_; + while (__pred_(__first_ [__j-1], __cur [__j-1])) { + __j--; + // We matched - we're done! + if ( __j == 0 ) + return make_pair(__cur, __cur + __pattern_length_); + } + + // Since we didn't match, figure out how far to skip forward + difference_type __k = __skip[__cur [ __j - 1 ]]; + difference_type __m = __j - __k - 1; + if (__k < __j && __m > __suffix[ __j ]) + __cur += __m; + else + __cur += __suffix[ __j ]; + } + + return make_pair(__l, __l); // We didn't find anything + } + + + template<typename _Iterator, typename _Container> + void __compute_bm_prefix ( _Iterator __f, _Iterator __l, _BinaryPredicate __pred, _Container &__prefix ) + { + const size_t __count = _VSTD::distance(__f, __l); + + __prefix[0] = 0; + size_t __k = 0; + for ( size_t __i = 1; __i < __count; ++__i ) + { + while ( __k > 0 && !__pred ( __f[__k], __f[__i] )) + __k = __prefix [ __k - 1 ]; + + if ( __pred ( __f[__k], __f[__i] )) + __k++; + __prefix [ __i ] = __k; + } + } + + void __build_suffix_table(_RandomAccessIterator1 __f, _RandomAccessIterator1 __l, + _BinaryPredicate __pred) + { + const size_t __count = 
_VSTD::distance(__f, __l); + vector<difference_type> & __suffix = *__suffix_.get(); + if (__count > 0) + { + vector<value_type> __scratch(__count); + + __compute_bm_prefix(__f, __l, __pred, __scratch); + for ( size_t __i = 0; __i <= __count; __i++ ) + __suffix[__i] = __count - __scratch[__count-1]; + + typedef reverse_iterator<_RandomAccessIterator1> _RevIter; + __compute_bm_prefix(_RevIter(__l), _RevIter(__f), __pred, __scratch); + + for ( size_t __i = 0; __i < __count; __i++ ) + { + const size_t __j = __count - __scratch[__i]; + const difference_type __k = __i - __scratch[__i] + 1; + + if (__suffix[__j] > __k) + __suffix[__j] = __k; + } + } + } + +}; + +template<class _RandomAccessIterator, + class _Hash = hash<typename iterator_traits<_RandomAccessIterator>::value_type>, + class _BinaryPredicate = equal_to<>> +_LIBCPP_INLINE_VISIBILITY +boyer_moore_searcher<_RandomAccessIterator, _Hash, _BinaryPredicate> +make_boyer_moore_searcher( _RandomAccessIterator __f, _RandomAccessIterator __l, + _Hash __hf = _Hash(), _BinaryPredicate __p = _BinaryPredicate ()) +{ + return boyer_moore_searcher<_RandomAccessIterator, _Hash, _BinaryPredicate>(__f, __l, __hf, __p); +} + +// boyer-moore-horspool +template <class _RandomAccessIterator1, + class _Hash = hash<typename iterator_traits<_RandomAccessIterator1>::value_type>, + class _BinaryPredicate = equal_to<>> +class _LIBCPP_TEMPLATE_VIS boyer_moore_horspool_searcher { +private: + typedef typename std::iterator_traits<_RandomAccessIterator1>::difference_type difference_type; + typedef typename std::iterator_traits<_RandomAccessIterator1>::value_type value_type; + typedef _BMSkipTable<value_type, difference_type, _Hash, _BinaryPredicate, + is_integral<value_type>::value && // what about enums? 
+ sizeof(value_type) == 1 && + is_same<_Hash, hash<value_type>>::value && + is_same<_BinaryPredicate, equal_to<>>::value + > skip_table_type; + +public: + boyer_moore_horspool_searcher(_RandomAccessIterator1 __f, _RandomAccessIterator1 __l, + _Hash __hf = _Hash(), _BinaryPredicate __pred = _BinaryPredicate()) + : __first_(__f), __last_(__l), __pred_(__pred), + __pattern_length_(_VSTD::distance(__first_, __last_)), + __skip_{_VSTD::make_shared<skip_table_type>(__pattern_length_, __pattern_length_, __hf, __pred_)} + { + // build the skip table + if ( __f != __l ) + { + __l = __l - 1; + for ( difference_type __i = 0; __f != __l; ++__f, (void) ++__i ) + __skip_->insert(*__f, __pattern_length_ - 1 - __i); + } + } + + template <typename _RandomAccessIterator2> + pair<_RandomAccessIterator2, _RandomAccessIterator2> + operator ()(_RandomAccessIterator2 __f, _RandomAccessIterator2 __l) const + { + static_assert(__is_same_uncvref<typename std::iterator_traits<_RandomAccessIterator1>::value_type, + typename std::iterator_traits<_RandomAccessIterator2>::value_type>::value, + "Corpus and Pattern iterators must point to the same type"); + + if (__f == __l ) return make_pair(__l, __l); // empty corpus + if (__first_ == __last_) return make_pair(__f, __f); // empty pattern + + // If the pattern is larger than the corpus, we can't find it! 
+ if ( __pattern_length_ > _VSTD::distance(__f, __l)) + return make_pair(__l, __l); + + // Do the search + return this->__search(__f, __l); + } + +private: + _RandomAccessIterator1 __first_; + _RandomAccessIterator1 __last_; + _BinaryPredicate __pred_; + difference_type __pattern_length_; + shared_ptr<skip_table_type> __skip_; + + template <typename _RandomAccessIterator2> + pair<_RandomAccessIterator2, _RandomAccessIterator2> + __search ( _RandomAccessIterator2 __f, _RandomAccessIterator2 __l ) const { + _RandomAccessIterator2 __cur = __f; + const _RandomAccessIterator2 __last = __l - __pattern_length_; + const skip_table_type & __skip = *__skip_.get(); + + while (__cur <= __last) + { + // Do we match right where we are? + difference_type __j = __pattern_length_; + while (__pred_(__first_[__j-1], __cur[__j-1])) + { + __j--; + // We matched - we're done! + if ( __j == 0 ) + return make_pair(__cur, __cur + __pattern_length_); + } + __cur += __skip[__cur[__pattern_length_-1]]; + } + + return make_pair(__l, __l); + } +}; + +template<class _RandomAccessIterator, + class _Hash = hash<typename iterator_traits<_RandomAccessIterator>::value_type>, + class _BinaryPredicate = equal_to<>> +_LIBCPP_INLINE_VISIBILITY +boyer_moore_horspool_searcher<_RandomAccessIterator, _Hash, _BinaryPredicate> +make_boyer_moore_horspool_searcher( _RandomAccessIterator __f, _RandomAccessIterator __l, + _Hash __hf = _Hash(), _BinaryPredicate __p = _BinaryPredicate ()) +{ + return boyer_moore_horspool_searcher<_RandomAccessIterator, _Hash, _BinaryPredicate>(__f, __l, __hf, __p); +} + +#endif // _LIBCPP_STD_VER > 11 + +_LIBCPP_END_NAMESPACE_LFTS + +_LIBCPP_POP_MACROS + +#endif /* _LIBCPP_EXPERIMENTAL_FUNCTIONAL */ diff --git a/contrib/libs/sparsehash/src/sparsehash/dense_hash_set b/contrib/libs/sparsehash/src/sparsehash/dense_hash_set new file mode 100644 index 0000000000..050b15d1d5 --- /dev/null +++ b/contrib/libs/sparsehash/src/sparsehash/dense_hash_set @@ -0,0 +1,338 @@ +// Copyright (c) 2005, 
Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// This is just a very thin wrapper over densehashtable.h, just +// like sgi stl's stl_hash_set is a very thin wrapper over +// stl_hashtable. The major thing we define is operator[], because +// we have a concept of a data_type which stl_hashtable doesn't +// (it only has a key and a value). 
+// +// This is more different from dense_hash_map than you might think, +// because all iterators for sets are const (you obviously can't +// change the key, and for sets there is no value). +// +// NOTE: this is exactly like sparse_hash_set.h, with the word +// "sparse" replaced by "dense", except for the addition of +// set_empty_key(). +// +// YOU MUST CALL SET_EMPTY_KEY() IMMEDIATELY AFTER CONSTRUCTION. +// +// Otherwise your program will die in mysterious ways. (Note if you +// use the constructor that takes an InputIterator range, you pass in +// the empty key in the constructor, rather than after. As a result, +// this constructor differs from the standard STL version.) +// +// In other respects, we adhere mostly to the STL semantics for +// hash-map. One important exception is that insert() may invalidate +// iterators entirely -- STL semantics are that insert() may reorder +// iterators, but they all still refer to something valid in the +// hashtable. Not so for us. Likewise, insert() may invalidate +// pointers into the hashtable. (Whether insert invalidates iterators +// and pointers depends on whether it results in a hashtable resize). +// On the plus side, delete() doesn't invalidate iterators or pointers +// at all, or even change the ordering of elements. +// +// Here are a few "power user" tips: +// +// 1) set_deleted_key(): +// If you want to use erase() you must call set_deleted_key(), +// in addition to set_empty_key(), after construction. +// The deleted and empty keys must differ. +// +// 2) resize(0): +// When an item is deleted, its memory isn't freed right +// away. This allows you to iterate over a hashtable, +// and call erase(), without invalidating the iterator. +// To force the memory to be freed, call resize(0). +// For tr1 compatibility, this can also be called as rehash(0). +// +// 3) min_load_factor(0.0) +// Setting the minimum load factor to 0.0 guarantees that +// the hash table will never shrink. 
+// +// Roughly speaking: +// (1) dense_hash_set: fastest, uses the most memory unless entries are small +// (2) sparse_hash_set: slowest, uses the least memory +// (3) hash_set / unordered_set (STL): in the middle +// +// Typically I use sparse_hash_set when I care about space and/or when +// I need to save the hashtable on disk. I use hash_set otherwise. I +// don't personally use dense_hash_set ever; some people use it for +// small sets with lots of lookups. +// +// - dense_hash_set has, typically, about 78% memory overhead (if your +// data takes up X bytes, the hash_set uses .78X more bytes in overhead). +// - sparse_hash_set has about 4 bits overhead per entry. +// - sparse_hash_set can be 3-7 times slower than the others for lookup and, +// especially, inserts. See time_hash_map.cc for details. +// +// See /usr/(local/)?doc/sparsehash-*/dense_hash_set.html +// for information about how to use this class. + +#ifndef _DENSE_HASH_SET_H_ +#define _DENSE_HASH_SET_H_ + +#include <sparsehash/internal/sparseconfig.h> +#include <algorithm> // needed by stl_alloc +#include <functional> // for equal_to<>, select1st<>, etc +#include <memory> // for alloc +#include <utility> // for pair<> +#include <sparsehash/internal/densehashtable.h> // IWYU pragma: export +#include <sparsehash/internal/libc_allocator_with_realloc.h> +#include HASH_FUN_H // for hash<> +_START_GOOGLE_NAMESPACE_ + +template <class Value, + class HashFcn = SPARSEHASH_HASH<Value>, // defined in sparseconfig.h + class EqualKey = std::equal_to<Value>, + class Alloc = libc_allocator_with_realloc<Value> > +class dense_hash_set { + private: + // Apparently identity is not stl-standard, so we define our own + struct Identity { + typedef const Value& result_type; + const Value& operator()(const Value& v) const { return v; } + }; + struct SetKey { + void operator()(Value* value, const Value& new_key) const { + *value = new_key; + } + }; + + // The actual data + typedef dense_hashtable<Value, Value, HashFcn, 
Identity, SetKey, + EqualKey, Alloc> ht; + ht rep; + + public: + typedef typename ht::key_type key_type; + typedef typename ht::value_type value_type; + typedef typename ht::hasher hasher; + typedef typename ht::key_equal key_equal; + typedef Alloc allocator_type; + + typedef typename ht::size_type size_type; + typedef typename ht::difference_type difference_type; + typedef typename ht::const_pointer pointer; + typedef typename ht::const_pointer const_pointer; + typedef typename ht::const_reference reference; + typedef typename ht::const_reference const_reference; + + typedef typename ht::const_iterator iterator; + typedef typename ht::const_iterator const_iterator; + typedef typename ht::const_local_iterator local_iterator; + typedef typename ht::const_local_iterator const_local_iterator; + + + // Iterator functions -- recall all iterators are const + iterator begin() const { return rep.begin(); } + iterator end() const { return rep.end(); } + + // These come from tr1's unordered_set. For us, a bucket has 0 or 1 elements. 
+ local_iterator begin(size_type i) const { return rep.begin(i); } + local_iterator end(size_type i) const { return rep.end(i); } + + + // Accessor functions + allocator_type get_allocator() const { return rep.get_allocator(); } + hasher hash_funct() const { return rep.hash_funct(); } + hasher hash_function() const { return hash_funct(); } // tr1 name + key_equal key_eq() const { return rep.key_eq(); } + + + // Constructors + explicit dense_hash_set(size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, Identity(), SetKey(), alloc) { + } + + template <class InputIterator> + dense_hash_set(InputIterator f, InputIterator l, + const key_type& empty_key_val, + size_type expected_max_items_in_table = 0, + const hasher& hf = hasher(), + const key_equal& eql = key_equal(), + const allocator_type& alloc = allocator_type()) + : rep(expected_max_items_in_table, hf, eql, Identity(), SetKey(), alloc) { + set_empty_key(empty_key_val); + rep.insert(f, l); + } + // We use the default copy constructor + // We use the default operator=() + // We use the default destructor + + void clear() { rep.clear(); } + // This clears the hash set without resizing it down to the minimum + // bucket count, but rather keeps the number of buckets constant + void clear_no_resize() { rep.clear_no_resize(); } + void swap(dense_hash_set& hs) { rep.swap(hs.rep); } + + + // Functions concerning size + size_type size() const { return rep.size(); } + size_type max_size() const { return rep.max_size(); } + bool empty() const { return rep.empty(); } + size_type bucket_count() const { return rep.bucket_count(); } + size_type max_bucket_count() const { return rep.max_bucket_count(); } + + // These are tr1 methods. bucket() is the bucket the key is or would be in. 
+ size_type bucket_size(size_type i) const { return rep.bucket_size(i); } + size_type bucket(const key_type& key) const { return rep.bucket(key); } + float load_factor() const { + return size() * 1.0f / bucket_count(); + } + float max_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return grow; + } + void max_load_factor(float new_grow) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(shrink, new_grow); + } + // These aren't tr1 methods but perhaps ought to be. + float min_load_factor() const { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + return shrink; + } + void min_load_factor(float new_shrink) { + float shrink, grow; + rep.get_resizing_parameters(&shrink, &grow); + rep.set_resizing_parameters(new_shrink, grow); + } + // Deprecated; use min_load_factor() or max_load_factor() instead. + void set_resizing_parameters(float shrink, float grow) { + rep.set_resizing_parameters(shrink, grow); + } + + void resize(size_type hint) { rep.resize(hint); } + void rehash(size_type hint) { resize(hint); } // the tr1 name + + // Lookup routines + iterator find(const key_type& key) const { return rep.find(key); } + + size_type count(const key_type& key) const { return rep.count(key); } + + std::pair<iterator, iterator> equal_range(const key_type& key) const { + return rep.equal_range(key); + } + + + // Insertion routines + std::pair<iterator, bool> insert(const value_type& obj) { + std::pair<typename ht::iterator, bool> p = rep.insert(obj); + return std::pair<iterator, bool>(p.first, p.second); // const to non-const + } + template <class InputIterator> void insert(InputIterator f, InputIterator l) { + rep.insert(f, l); + } + void insert(const_iterator f, const_iterator l) { + rep.insert(f, l); + } + // Required for std::insert_iterator; the passed-in iterator is ignored. 
+ iterator insert(iterator, const value_type& obj) { + return insert(obj).first; + } + + // Deletion and empty routines + // THESE ARE NON-STANDARD! I make you specify an "impossible" key + // value to identify deleted and empty buckets. You can change the + // deleted key as time goes on, or get rid of it entirely to be insert-only. + void set_empty_key(const key_type& key) { rep.set_empty_key(key); } + key_type empty_key() const { return rep.empty_key(); } + + void set_deleted_key(const key_type& key) { rep.set_deleted_key(key); } + void clear_deleted_key() { rep.clear_deleted_key(); } + key_type deleted_key() const { return rep.deleted_key(); } + + // These are standard + size_type erase(const key_type& key) { return rep.erase(key); } + void erase(iterator it) { rep.erase(it); } + void erase(iterator f, iterator l) { rep.erase(f, l); } + + + // Comparison + bool operator==(const dense_hash_set& hs) const { return rep == hs.rep; } + bool operator!=(const dense_hash_set& hs) const { return rep != hs.rep; } + + + // I/O -- this is an add-on for writing metainformation to disk + // + // For maximum flexibility, this does not assume a particular + // file type (though it will probably be a FILE *). We just pass + // the fp through to rep. + + // If your keys and values are simple enough, you can pass this + // serializer to serialize()/unserialize(). "Simple enough" means + // value_type is a POD type that contains no pointers. Note, + // however, we don't try to normalize endianness. + typedef typename ht::NopointerSerializer NopointerSerializer; + + // serializer: a class providing operator()(OUTPUT*, const value_type&) + // (writing value_type to OUTPUT). You can specify a + // NopointerSerializer object if appropriate (see above). 
+ // fp: either a FILE*, OR an ostream*/subclass_of_ostream*, OR a + // pointer to a class providing size_t Write(const void*, size_t), + // which writes a buffer into a stream (which fp presumably + // owns) and returns the number of bytes successfully written. + // Note basic_ostream<not_char> is not currently supported. + template <typename ValueSerializer, typename OUTPUT> + bool serialize(ValueSerializer serializer, OUTPUT* fp) { + return rep.serialize(serializer, fp); + } + + // serializer: a functor providing operator()(INPUT*, value_type*) + // (reading from INPUT and into value_type). You can specify a + // NopointerSerializer object if appropriate (see above). + // fp: either a FILE*, OR an istream*/subclass_of_istream*, OR a + // pointer to a class providing size_t Read(void*, size_t), + // which reads into a buffer from a stream (which fp presumably + // owns) and returns the number of bytes successfully read. + // Note basic_istream<not_char> is not currently supported. + template <typename ValueSerializer, typename INPUT> + bool unserialize(ValueSerializer serializer, INPUT* fp) { + return rep.unserialize(serializer, fp); + } +}; + +template <class Val, class HashFcn, class EqualKey, class Alloc> +inline void swap(dense_hash_set<Val, HashFcn, EqualKey, Alloc>& hs1, + dense_hash_set<Val, HashFcn, EqualKey, Alloc>& hs2) { + hs1.swap(hs2); +} + +_END_GOOGLE_NAMESPACE_ + +#endif /* _DENSE_HASH_SET_H_ */ diff --git a/library/cpp/porto/libporto.cpp b/library/cpp/porto/libporto.cpp new file mode 100644 index 0000000000..8fd8924300 --- /dev/null +++ b/library/cpp/porto/libporto.cpp @@ -0,0 +1,1547 @@ +#include "libporto.hpp" +#include "metrics.hpp" + +#include <google/protobuf/text_format.h> +#include <google/protobuf/io/zero_copy_stream_impl.h> +#include <google/protobuf/io/coded_stream.h> + +extern "C" { +#include <errno.h> +#include <time.h> +#include <unistd.h> +#include <sys/socket.h> +#include <sys/un.h> + +#ifndef __linux__ +#include <fcntl.h> +#else 
+#include <sys/epoll.h> +#endif +} + +namespace Porto { + +TPortoApi::~TPortoApi() { + Disconnect(); +} + +EError TPortoApi::SetError(const TString &prefix, int _errno) { + LastErrorMsg = prefix + ": " + strerror(_errno); + + switch (_errno) { + case ENOENT: + LastError = EError::SocketUnavailable; + break; + case EAGAIN: + LastErrorMsg = prefix + ": Timeout exceeded. Timeout value: " + std::to_string(Timeout); + LastError = EError::SocketTimeout; + break; + case EIO: + case EPIPE: + LastError = EError::SocketError; + break; + default: + LastError = EError::Unknown; + break; + } + + Disconnect(); + return LastError; +} + +TString TPortoApi::GetLastError() const { + return EError_Name(LastError) + ":(" + LastErrorMsg + ")"; +} + +EError TPortoApi::Connect(const char *socket_path) { + struct sockaddr_un peer_addr; + socklen_t peer_addr_size; + + Disconnect(); + +#ifdef __linux__ + Fd = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0); + if (Fd < 0) + return SetError("socket", errno); +#else + Fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (Fd < 0) + return SetError("socket", errno); + if (fcntl(Fd, F_SETFD, FD_CLOEXEC) < 0) + return SetError("fcntl FD_CLOEXEC", errno); +#endif + + if (Timeout > 0 && SetSocketTimeout(3, Timeout)) + return LastError; + + memset(&peer_addr, 0, sizeof(struct sockaddr_un)); + peer_addr.sun_family = AF_UNIX; + strncpy(peer_addr.sun_path, socket_path, strlen(socket_path)); + + peer_addr_size = sizeof(struct sockaddr_un); + if (connect(Fd, (struct sockaddr *) &peer_addr, peer_addr_size) < 0) + return SetError("connect", errno); + + /* Restore async wait state */ + if (!AsyncWaitNames.empty()) { + for (auto &name: AsyncWaitNames) + Req.mutable_asyncwait()->add_name(name); + for (auto &label: AsyncWaitLabels) + Req.mutable_asyncwait()->add_label(label); + if (AsyncWaitTimeout >= 0) + Req.mutable_asyncwait()->set_timeout_ms(AsyncWaitTimeout * 1000); + return Call(); + } + + return EError::Success; +} + +void TPortoApi::Disconnect() { + if (Fd >= 0) + 
close(Fd); + Fd = -1; +} + +EError TPortoApi::SetSocketTimeout(int direction, int timeout) { + struct timeval tv; + + if (Fd < 0) + return EError::Success; + + tv.tv_sec = timeout > 0 ? timeout : 0; + tv.tv_usec = 0; + + if ((direction & 1) && setsockopt(Fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof tv)) + return SetError("setsockopt SO_SNDTIMEO", errno); + + if ((direction & 2) && setsockopt(Fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof tv)) + return SetError("setsockopt SO_RCVTIMEO", errno); + + return EError::Success; +} + +EError TPortoApi::SetTimeout(int timeout) { + Timeout = timeout ? timeout : DEFAULT_TIMEOUT; + return SetSocketTimeout(3, Timeout); +} + +EError TPortoApi::SetDiskTimeout(int timeout) { + DiskTimeout = timeout ? timeout : DEFAULT_DISK_TIMEOUT; + return EError::Success; +} + +EError TPortoApi::Send(const TPortoRequest &req) { + google::protobuf::io::FileOutputStream raw(Fd); + + if (!req.IsInitialized()) { + LastError = EError::InvalidMethod; + LastErrorMsg = "Request is not initialized"; + return EError::InvalidMethod; + } + + { + google::protobuf::io::CodedOutputStream output(&raw); + + output.WriteVarint32(req.ByteSize()); + req.SerializeWithCachedSizes(&output); + } + + raw.Flush(); + + int err = raw.GetErrno(); + if (err) + return SetError("send", err); + + return EError::Success; +} + +EError TPortoApi::Recv(TPortoResponse &rsp) { + google::protobuf::io::FileInputStream raw(Fd); + google::protobuf::io::CodedInputStream input(&raw); + + while (true) { + uint32_t size; + + if (!input.ReadVarint32(&size)) + return SetError("recv", raw.GetErrno() ?: EIO); + + auto prev_limit = input.PushLimit(size); + + rsp.Clear(); + + if (!rsp.ParseFromCodedStream(&input)) + return SetError("recv", raw.GetErrno() ?: EIO); + + input.PopLimit(prev_limit); + + if (rsp.has_asyncwait()) { + if (AsyncWaitCallback) + AsyncWaitCallback(rsp.asyncwait()); + + if (AsyncWaitOneShot) + return EError::Success; + + continue; + } + + return EError::Success; + } +} + +EError 
TPortoApi::Call(const TPortoRequest &req, + TPortoResponse &rsp, + int extra_timeout) { + bool reconnect = AutoReconnect; + EError err = EError::Success; + + if (Fd < 0) { + if (!reconnect) + return SetError("Not connected", EIO); + err = Connect(); + reconnect = false; + } + + if (!err) { + err = Send(req); + if (err == EError::SocketError && reconnect) { + err = Connect(); + if (!err) + err = Send(req); + } + } + + if (!err && extra_timeout && Timeout > 0) + err = SetSocketTimeout(2, extra_timeout > 0 ? (extra_timeout + Timeout) : -1); + + if (!err) + err = Recv(rsp); + + if (extra_timeout && Timeout > 0) { + EError err = SetSocketTimeout(2, Timeout); + (void)err; + } + + if (!err) { + err = LastError = rsp.error(); + LastErrorMsg = rsp.errormsg(); + } + + return err; +} + +EError TPortoApi::Call(int extra_timeout) { + return Call(Req, Rsp, extra_timeout); +} + +EError TPortoApi::Call(const TString &req, + TString &rsp, + int extra_timeout) { + Req.Clear(); + if (!google::protobuf::TextFormat::ParseFromString(req, &Req)) { + LastError = EError::InvalidMethod; + LastErrorMsg = "Cannot parse request"; + rsp = ""; + return EError::InvalidMethod; + } + + EError err = Call(Req, Rsp, extra_timeout); + + rsp = Rsp.DebugString(); + + return err; +} + +EError TPortoApi::GetVersion(TString &tag, TString &revision) { + Req.Clear(); + Req.mutable_version(); + + if (!Call()) { + tag = Rsp.version().tag(); + revision = Rsp.version().revision(); + } + + return LastError; +} + +const TGetSystemResponse *TPortoApi::GetSystem() { + Req.Clear(); + Req.mutable_getsystem(); + if (!Call()) + return &Rsp.getsystem(); + return nullptr; +} + +EError TPortoApi::SetSystem(const TString &key, const TString &val) { + TString rsp; + return Call("SetSystem {" + key + ":" + val + "}", rsp); +} + +/* Container */ + +EError TPortoApi::Create(const TString &name) { + Req.Clear(); + auto req = Req.mutable_create(); + req->set_name(name); + return Call(); +} + +EError 
TPortoApi::CreateWeakContainer(const TString &name) { + Req.Clear(); + auto req = Req.mutable_createweak(); + req->set_name(name); + return Call(); +} + +EError TPortoApi::Destroy(const TString &name) { + Req.Clear(); + auto req = Req.mutable_destroy(); + req->set_name(name); + return Call(); +} + +const TListResponse *TPortoApi::List(const TString &mask) { + Req.Clear(); + auto req = Req.mutable_list(); + + if(!mask.empty()) + req->set_mask(mask); + + if (!Call()) + return &Rsp.list(); + + return nullptr; +} + +EError TPortoApi::List(TVector<TString> &list, const TString &mask) { + Req.Clear(); + auto req = Req.mutable_list(); + if(!mask.empty()) + req->set_mask(mask); + if (!Call()) + list = TVector<TString>(std::begin(Rsp.list().name()), + std::end(Rsp.list().name())); + return LastError; +} + +const TListPropertiesResponse *TPortoApi::ListProperties() { + Req.Clear(); + Req.mutable_listproperties(); + + if (Call()) + return nullptr; + + bool has_data = false; + for (const auto &prop: Rsp.listproperties().list()) { + if (prop.read_only()) { + has_data = true; + break; + } + } + + if (!has_data) { + TPortoRequest req; + TPortoResponse rsp; + + req.mutable_listdataproperties(); + if (!Call(req, rsp)) { + for (const auto &data: rsp.listdataproperties().list()) { + auto d = Rsp.mutable_listproperties()->add_list(); + d->set_name(data.name()); + d->set_desc(data.desc()); + d->set_read_only(true); + } + } + } + + return &Rsp.listproperties(); +} + +EError TPortoApi::ListProperties(TVector<TString> &properties) { + properties.clear(); + auto rsp = ListProperties(); + if (rsp) { + for (auto &prop: rsp->list()) + properties.push_back(prop.name()); + } + return LastError; +} + +const TGetResponse *TPortoApi::Get(const TVector<TString> &names, + const TVector<TString> &vars, + int flags) { + Req.Clear(); + auto get = Req.mutable_get(); + + for (const auto &n : names) + get->add_name(n); + + for (const auto &v : vars) + get->add_variable(v); + + if (flags & GET_NONBLOCK) + 
get->set_nonblock(true); + if (flags & GET_SYNC) + get->set_sync(true); + if (flags & GET_REAL) + get->set_real(true); + + if (!Call()) + return &Rsp.get(); + + return nullptr; +} + +EError TPortoApi::GetContainerSpec(const TString &name, TContainer &container) { + Req.Clear(); + TListContainersRequest req; + auto filter = req.add_filters(); + filter->set_name(name); + + TVector<TContainer> containers; + + auto ret = ListContainersBy(req, containers); + if (containers.empty()) + return EError::ContainerDoesNotExist; + + if (!ret) + container = containers[0]; + + return ret; +} + +EError TPortoApi::ListContainersBy(const TListContainersRequest &listContainersRequest, TVector<TContainer> &containers) { + Req.Clear(); + auto req = Req.mutable_listcontainersby(); + *req = listContainersRequest; + + auto ret = Call(); + if (ret) + return ret; + + for (auto &ct : Rsp.listcontainersby().containers()) + containers.push_back(ct); + + return EError::Success; +} + +EError TPortoApi::CreateFromSpec(const TContainerSpec &container, TVector<TVolumeSpec> volumes, bool start) { + Req.Clear(); + auto req = Req.mutable_createfromspec(); + + auto ct = req->mutable_container(); + *ct = container; + + for (auto &volume : volumes) { + auto v = req->add_volumes(); + *v = volume; + } + + req->set_start(start); + + return Call(); +} + +EError TPortoApi::UpdateFromSpec(const TContainerSpec &container) { + Req.Clear(); + auto req = Req.mutable_updatefromspec(); + + auto ct = req->mutable_container(); + *ct = container; + + return Call(); +} + +EError TPortoApi::GetProperty(const TString &name, + const TString &property, + TString &value, + int flags) { + Req.Clear(); + auto req = Req.mutable_getproperty(); + + req->set_name(name); + req->set_property(property); + if (flags & GET_SYNC) + req->set_sync(true); + if (flags & GET_REAL) + req->set_real(true); + + if (!Call()) + value = Rsp.getproperty().value(); + + return LastError; +} + +EError TPortoApi::SetProperty(const TString &name, + const 
TString &property, + const TString &value) { + Req.Clear(); + auto req = Req.mutable_setproperty(); + + req->set_name(name); + req->set_property(property); + req->set_value(value); + + return Call(); +} + +EError TPortoApi::GetInt(const TString &name, + const TString &property, + const TString &index, + uint64_t &value) { + TString key = property, str; + if (index.size()) + key = property + "[" + index + "]"; + if (!GetProperty(name, key, str)) { + const char *ptr = str.c_str(); + char *end; + errno = 0; + value = strtoull(ptr, &end, 10); + if (errno || end == ptr || *end) { + LastError = EError::InvalidValue; + LastErrorMsg = " value: " + str; + } + } + return LastError; +} + +EError TPortoApi::SetInt(const TString &name, + const TString &property, + const TString &index, + uint64_t value) { + TString key = property; + if (index.size()) + key = property + "[" + index + "]"; + return SetProperty(name, key, ToString(value)); +} + +EError TPortoApi::GetProcMetric(const TVector<TString> &names, + const TString &metric, + TMap<TString, uint64_t> &values) { + auto it = ProcMetrics.find(metric); + + if (it == ProcMetrics.end()) { + LastError = EError::InvalidValue; + LastErrorMsg = " Unknown metric: " + metric; + return LastError; + } + + LastError = it->second->GetValues(names, values, *this); + + if (LastError) + LastErrorMsg = "Unknown error on Get() method"; + + return LastError; +} + +EError TPortoApi::SetLabel(const TString &name, + const TString &label, + const TString &value, + const TString &prev_value) { + Req.Clear(); + auto req = Req.mutable_setlabel(); + + req->set_name(name); + req->set_label(label); + req->set_value(value); + if (prev_value != " ") + req->set_prev_value(prev_value); + + return Call(); +} + +EError TPortoApi::IncLabel(const TString &name, + const TString &label, + int64_t add, + int64_t &result) { + Req.Clear(); + auto req = Req.mutable_inclabel(); + + req->set_name(name); + req->set_label(label); + req->set_add(add); + + EError err = 
Call(); + + if (Rsp.has_inclabel()) + result = Rsp.inclabel().result(); + + return err; +} + +EError TPortoApi::Start(const TString &name) { + Req.Clear(); + auto req = Req.mutable_start(); + + req->set_name(name); + + return Call(); +} + +EError TPortoApi::Stop(const TString &name, int stop_timeout) { + Req.Clear(); + auto req = Req.mutable_stop(); + + req->set_name(name); + if (stop_timeout >= 0) + req->set_timeout_ms(stop_timeout * 1000); + + return Call(stop_timeout > 0 ? stop_timeout : 0); +} + +EError TPortoApi::Kill(const TString &name, int sig) { + Req.Clear(); + auto req = Req.mutable_kill(); + + req->set_name(name); + req->set_sig(sig); + + return Call(); +} + +EError TPortoApi::Pause(const TString &name) { + Req.Clear(); + auto req = Req.mutable_pause(); + + req->set_name(name); + + return Call(); +} + +EError TPortoApi::Resume(const TString &name) { + Req.Clear(); + auto req = Req.mutable_resume(); + + req->set_name(name); + + return Call(); +} + +EError TPortoApi::Respawn(const TString &name) { + Req.Clear(); + auto req = Req.mutable_respawn(); + + req->set_name(name); + + return Call(); +} + +EError TPortoApi::CallWait(TString &result_state, int wait_timeout) { + time_t deadline = 0; + time_t last_retry = 0; + + if (wait_timeout >= 0) { + deadline = time(nullptr) + wait_timeout; + Req.mutable_wait()->set_timeout_ms(wait_timeout * 1000); + } + +retry: + if (!Call(wait_timeout)) { + if (Rsp.wait().has_state()) + result_state = Rsp.wait().state(); + else if (Rsp.wait().name() == "") + result_state = "timeout"; + else + result_state = "dead"; + } else if (LastError == EError::SocketError && AutoReconnect) { + time_t now = time(nullptr); + + if (wait_timeout < 0 || now < deadline) { + if (wait_timeout >= 0) { + wait_timeout = deadline - now; + Req.mutable_wait()->set_timeout_ms(wait_timeout * 1000); + } + if (last_retry == now) + sleep(1); + last_retry = now; + goto retry; + } + + result_state = "timeout"; + } else + result_state = "unknown"; + + return 
LastError; +} + +EError TPortoApi::WaitContainer(const TString &name, + TString &result_state, + int wait_timeout) { + Req.Clear(); + auto req = Req.mutable_wait(); + + req->add_name(name); + + return CallWait(result_state, wait_timeout); +} + +EError TPortoApi::WaitContainers(const TVector<TString> &names, + TString &result_name, + TString &result_state, + int wait_timeout) { + Req.Clear(); + auto req = Req.mutable_wait(); + + for (auto &c : names) + req->add_name(c); + + EError err = CallWait(result_state, wait_timeout); + + result_name = Rsp.wait().name(); + + return err; +} + +const TWaitResponse *TPortoApi::Wait(const TVector<TString> &names, + const TVector<TString> &labels, + int wait_timeout) { + Req.Clear(); + auto req = Req.mutable_wait(); + TString result_state; + + for (auto &c : names) + req->add_name(c); + for (auto &label: labels) + req->add_label(label); + + EError err = CallWait(result_state, wait_timeout); + (void)err; + + if (Rsp.has_wait()) + return &Rsp.wait(); + + return nullptr; +} + +EError TPortoApi::AsyncWait(const TVector<TString> &names, + const TVector<TString> &labels, + TWaitCallback callback, + int wait_timeout, + const TString &targetState) { + Req.Clear(); + auto req = Req.mutable_asyncwait(); + + AsyncWaitNames.clear(); + AsyncWaitLabels.clear(); + AsyncWaitTimeout = wait_timeout; + AsyncWaitCallback = callback; + + for (auto &name: names) + req->add_name(name); + for (auto &label: labels) + req->add_label(label); + if (wait_timeout >= 0) + req->set_timeout_ms(wait_timeout * 1000); + if (!targetState.empty()) { + req->set_target_state(targetState); + AsyncWaitOneShot = true; + } else + AsyncWaitOneShot = false; + + if (Call()) { + AsyncWaitCallback = nullptr; + } else { + AsyncWaitNames = names; + AsyncWaitLabels = labels; + } + + return LastError; +} + +EError TPortoApi::StopAsyncWait(const TVector<TString> &names, + const TVector<TString> &labels, + const TString &targetState) { + Req.Clear(); + auto req = 
Req.mutable_stopasyncwait(); + + AsyncWaitNames.clear(); + AsyncWaitLabels.clear(); + + for (auto &name: names) + req->add_name(name); + for (auto &label: labels) + req->add_label(label); + if (!targetState.empty()) { + req->set_target_state(targetState); + } + + return Call(); +} + +EError TPortoApi::ConvertPath(const TString &path, + const TString &src, + const TString &dest, + TString &res) { + Req.Clear(); + auto req = Req.mutable_convertpath(); + + req->set_path(path); + req->set_source(src); + req->set_destination(dest); + + if (!Call()) + res = Rsp.convertpath().path(); + + return LastError; +} + +EError TPortoApi::AttachProcess(const TString &name, int pid, + const TString &comm) { + Req.Clear(); + auto req = Req.mutable_attachprocess(); + + req->set_name(name); + req->set_pid(pid); + req->set_comm(comm); + + return Call(); +} + +EError TPortoApi::AttachThread(const TString &name, int pid, + const TString &comm) { + Req.Clear(); + auto req = Req.mutable_attachthread(); + + req->set_name(name); + req->set_pid(pid); + req->set_comm(comm); + + return Call(); +} + +EError TPortoApi::LocateProcess(int pid, const TString &comm, + TString &name) { + Req.Clear(); + auto req = Req.mutable_locateprocess(); + + req->set_pid(pid); + req->set_comm(comm); + + if (!Call()) + name = Rsp.locateprocess().name(); + + return LastError; +} + +/* Volume */ + +const TListVolumePropertiesResponse *TPortoApi::ListVolumeProperties() { + Req.Clear(); + Req.mutable_listvolumeproperties(); + + if (!Call()) + return &Rsp.listvolumeproperties(); + + return nullptr; +} + +EError TPortoApi::ListVolumeProperties(TVector<TString> &properties) { + properties.clear(); + auto rsp = ListVolumeProperties(); + if (rsp) { + for (auto &prop: rsp->list()) + properties.push_back(prop.name()); + } + return LastError; +} + +EError TPortoApi::CreateVolume(TString &path, + const TMap<TString, TString> &config) { + Req.Clear(); + auto req = Req.mutable_createvolume(); + + req->set_path(path); + + 
*(req->mutable_properties()) = + google::protobuf::Map<TString, TString>(config.begin(), config.end()); + + if (!Call(DiskTimeout) && path.empty()) + path = Rsp.createvolume().path(); + + return LastError; +} + +EError TPortoApi::TuneVolume(const TString &path, + const TMap<TString, TString> &config) { + Req.Clear(); + auto req = Req.mutable_tunevolume(); + + req->set_path(path); + + *(req->mutable_properties()) = + google::protobuf::Map<TString, TString>(config.begin(), config.end()); + + return Call(DiskTimeout); +} + +EError TPortoApi::LinkVolume(const TString &path, + const TString &container, + const TString &target, + bool read_only, + bool required) { + Req.Clear(); + auto req = (target.empty() && !required) ? Req.mutable_linkvolume() : + Req.mutable_linkvolumetarget(); + + req->set_path(path); + if (!container.empty()) + req->set_container(container); + if (target != "") + req->set_target(target); + if (read_only) + req->set_read_only(read_only); + if (required) + req->set_required(required); + + return Call(); +} + +EError TPortoApi::UnlinkVolume(const TString &path, + const TString &container, + const TString &target, + bool strict) { + Req.Clear(); + auto req = (target == "***") ? 
Req.mutable_unlinkvolume() : + Req.mutable_unlinkvolumetarget(); + + req->set_path(path); + if (!container.empty()) + req->set_container(container); + if (target != "***") + req->set_target(target); + if (strict) + req->set_strict(strict); + + return Call(DiskTimeout); +} + +const TListVolumesResponse * +TPortoApi::ListVolumes(const TString &path, + const TString &container) { + Req.Clear(); + auto req = Req.mutable_listvolumes(); + + if (!path.empty()) + req->set_path(path); + + if (!container.empty()) + req->set_container(container); + + if (Call()) + return nullptr; + + auto list = Rsp.mutable_listvolumes(); + + /* compat */ + for (auto v: *list->mutable_volumes()) { + if (v.links().size()) + break; + for (auto &ct: v.containers()) + v.add_links()->set_container(ct); + } + + return list; +} + +EError TPortoApi::ListVolumes(TVector<TString> &paths) { + Req.Clear(); + auto rsp = ListVolumes(); + paths.clear(); + if (rsp) { + for (auto &v : rsp->volumes()) + paths.push_back(v.path()); + } + return LastError; +} + +const TVolumeDescription *TPortoApi::GetVolumeDesc(const TString &path) { + Req.Clear(); + auto rsp = ListVolumes(path); + + if (rsp && rsp->volumes().size()) + return &rsp->volumes(0); + + return nullptr; +} + +const TVolumeSpec *TPortoApi::GetVolume(const TString &path) { + Req.Clear(); + auto req = Req.mutable_getvolume(); + + req->add_path(path); + + if (!Call() && Rsp.getvolume().volume().size()) + return &Rsp.getvolume().volume(0); + + return nullptr; +} + +const TGetVolumeResponse *TPortoApi::GetVolumes(uint64_t changed_since) { + Req.Clear(); + auto req = Req.mutable_getvolume(); + + if (changed_since) + req->set_changed_since(changed_since); + + if (!Call() && Rsp.has_getvolume()) + return &Rsp.getvolume(); + + return nullptr; +} + + +EError TPortoApi::ListVolumesBy(const TGetVolumeRequest &getVolumeRequest, TVector<TVolumeSpec> &volumes) { + Req.Clear(); + auto req = Req.mutable_getvolume(); + *req = getVolumeRequest; + + auto ret = Call(); + if 
(ret) + return ret; + + for (auto volume : Rsp.getvolume().volume()) + volumes.push_back(volume); + return EError::Success; +} + +EError TPortoApi::CreateVolumeFromSpec(const TVolumeSpec &volume, TVolumeSpec &resultSpec) { + Req.Clear(); + auto req = Req.mutable_newvolume(); + auto vol = req->mutable_volume(); + *vol = volume; + + auto ret = Call(); + if (ret) + return ret; + + resultSpec = Rsp.newvolume().volume(); + + return ret; +} + +/* Layer */ + +EError TPortoApi::ImportLayer(const TString &layer, + const TString &tarball, + bool merge, + const TString &place, + const TString &private_value, + bool verboseError) { + Req.Clear(); + auto req = Req.mutable_importlayer(); + + req->set_layer(layer); + req->set_tarball(tarball); + req->set_merge(merge); + req->set_verbose_error(verboseError); + if (place.size()) + req->set_place(place); + if (private_value.size()) + req->set_private_value(private_value); + + return Call(DiskTimeout); +} + +EError TPortoApi::ExportLayer(const TString &volume, + const TString &tarball, + const TString &compress) { + Req.Clear(); + auto req = Req.mutable_exportlayer(); + + req->set_volume(volume); + req->set_tarball(tarball); + if (compress.size()) + req->set_compress(compress); + + return Call(DiskTimeout); +} + +EError TPortoApi::ReExportLayer(const TString &layer, + const TString &tarball, + const TString &compress) { + Req.Clear(); + auto req = Req.mutable_exportlayer(); + + req->set_volume(""); + req->set_layer(layer); + req->set_tarball(tarball); + if (compress.size()) + req->set_compress(compress); + + return Call(DiskTimeout); +} + +EError TPortoApi::RemoveLayer(const TString &layer, + const TString &place, + bool async) { + Req.Clear(); + auto req = Req.mutable_removelayer(); + + req->set_layer(layer); + req->set_async(async); + if (place.size()) + req->set_place(place); + + return Call(DiskTimeout); +} + +const TListLayersResponse *TPortoApi::ListLayers(const TString &place, + const TString &mask) { + Req.Clear(); + auto req 
= Req.mutable_listlayers(); + + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + + if (Call()) + return nullptr; + + auto list = Rsp.mutable_listlayers(); + + /* compat conversion */ + if (!list->layers().size() && list->layer().size()) { + for (auto &name: list->layer()) { + auto l = list->add_layers(); + l->set_name(name); + l->set_owner_user(""); + l->set_owner_group(""); + l->set_last_usage(0); + l->set_private_value(""); + } + } + + return list; +} + +EError TPortoApi::ListLayers(TVector<TString> &layers, + const TString &place, + const TString &mask) { + Req.Clear(); + auto req = Req.mutable_listlayers(); + + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + + if (!Call()) + layers = TVector<TString>(std::begin(Rsp.listlayers().layer()), + std::end(Rsp.listlayers().layer())); + + return LastError; +} + +EError TPortoApi::GetLayerPrivate(TString &private_value, + const TString &layer, + const TString &place) { + Req.Clear(); + auto req = Req.mutable_getlayerprivate(); + + req->set_layer(layer); + if (place.size()) + req->set_place(place); + + if (!Call()) + private_value = Rsp.getlayerprivate().private_value(); + + return LastError; +} + +EError TPortoApi::SetLayerPrivate(const TString &private_value, + const TString &layer, + const TString &place) { + Req.Clear(); + auto req = Req.mutable_setlayerprivate(); + + req->set_layer(layer); + req->set_private_value(private_value); + if (place.size()) + req->set_place(place); + + return Call(); +} + +/* Docker images */ + +DockerImage::DockerImage(const TDockerImage &i) { + Id = i.id(); + for (const auto &tag: i.tags()) + Tags.emplace_back(tag); + for (const auto &digest: i.digests()) + Digests.emplace_back(digest); + for (const auto &layer: i.layers()) + Layers.emplace_back(layer); + if (i.has_size()) + Size = i.size(); + if (i.has_config()) { + auto &cfg = i.config(); + for (const auto &cmd: cfg.cmd()) + Config.Cmd.emplace_back(cmd); + for 
(const auto &env: cfg.env()) + Config.Env.emplace_back(env); + } +} + +EError TPortoApi::DockerImageStatus(DockerImage &image, + const TString &name, + const TString &place) { + auto req = Req.mutable_dockerimagestatus(); + req->set_name(name); + if (!place.empty()) + req->set_place(place); + EError ret = Call(); + if (!ret && Rsp.dockerimagestatus().has_image()) + image = DockerImage(Rsp.dockerimagestatus().image()); + return ret; +} + +EError TPortoApi::ListDockerImages(std::vector<DockerImage> &images, + const TString &place, + const TString &mask) { + auto req = Req.mutable_listdockerimages(); + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + EError ret = Call(); + if (!ret) { + for (const auto &i: Rsp.listdockerimages().images()) + images.emplace_back(i); + } + return ret; +} + +EError TPortoApi::PullDockerImage(DockerImage &image, + const TString &name, + const TString &place, + const TString &auth_token, + const TString &auth_path, + const TString &auth_service) { + auto req = Req.mutable_pulldockerimage(); + req->set_name(name); + if (place.size()) + req->set_place(place); + if (auth_token.size()) + req->set_auth_token(auth_token); + if (auth_path.size()) + req->set_auth_path(auth_path); + if (auth_service.size()) + req->set_auth_service(auth_service); + EError ret = Call(); + if (!ret && Rsp.pulldockerimage().has_image()) + image = DockerImage(Rsp.pulldockerimage().image()); + return ret; +} + +EError TPortoApi::RemoveDockerImage(const TString &name, + const TString &place) { + auto req = Req.mutable_removedockerimage(); + req->set_name(name); + if (place.size()) + req->set_place(place); + return Call(); +} + +/* Storage */ + +const TListStoragesResponse *TPortoApi::ListStorages(const TString &place, + const TString &mask) { + Req.Clear(); + auto req = Req.mutable_liststorages(); + + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + + if (Call()) + return nullptr; + + return 
&Rsp.liststorages(); +} + +EError TPortoApi::ListStorages(TVector<TString> &storages, + const TString &place, + const TString &mask) { + Req.Clear(); + auto req = Req.mutable_liststorages(); + + if (place.size()) + req->set_place(place); + if (mask.size()) + req->set_mask(mask); + + if (!Call()) { + storages.clear(); + for (auto &storage: Rsp.liststorages().storages()) + storages.push_back(storage.name()); + } + + return LastError; +} + +EError TPortoApi::RemoveStorage(const TString &storage, + const TString &place) { + Req.Clear(); + auto req = Req.mutable_removestorage(); + + req->set_name(storage); + if (place.size()) + req->set_place(place); + + return Call(DiskTimeout); +} + +EError TPortoApi::ImportStorage(const TString &storage, + const TString &archive, + const TString &place, + const TString &compression, + const TString &private_value) { + Req.Clear(); + auto req = Req.mutable_importstorage(); + + req->set_name(storage); + req->set_tarball(archive); + if (place.size()) + req->set_place(place); + if (compression.size()) + req->set_compress(compression); + if (private_value.size()) + req->set_private_value(private_value); + + return Call(DiskTimeout); +} + +EError TPortoApi::ExportStorage(const TString &storage, + const TString &archive, + const TString &place, + const TString &compression) { + Req.Clear(); + auto req = Req.mutable_exportstorage(); + + req->set_name(storage); + req->set_tarball(archive); + if (place.size()) + req->set_place(place); + if (compression.size()) + req->set_compress(compression); + + return Call(DiskTimeout); +} + +#ifdef __linux__ +void TAsyncWaiter::MainCallback(const TWaitResponse &event) { + CallbacksCount++; + + auto it = AsyncCallbacks.find(event.name()); + if (it != AsyncCallbacks.end() && it->second.State == event.state()) { + it->second.Callback(event); + AsyncCallbacks.erase(it); + } +} + +int TAsyncWaiter::Repair() { + for (const auto &it : AsyncCallbacks) { + int ret = Api.AsyncWait({it.first}, {}, GetMainCallback(), 
-1, it.second.State); + if (ret) + return ret; + } + return 0; +} + +void TAsyncWaiter::WatchDog() { + int ret; + auto apiFd = Api.Fd; + + while (true) { + struct epoll_event events[2]; + int nfds = epoll_wait(EpollFd, events, 2, -1); + + if (nfds < 0) { + if (errno == EINTR) + continue; + + Fatal("Can not make epoll_wait", errno); + return; + } + + for (int n = 0; n < nfds; ++n) { + if (events[n].data.fd == apiFd) { + TPortoResponse rsp; + ret = Api.Recv(rsp); + // portod reloaded - async_wait must be repaired + if (ret == EError::SocketError) { + ret = Api.Connect(); + if (ret) { + Fatal("Can not connect to porto api", ret); + return; + } + + ret = Repair(); + if (ret) { + Fatal("Can not repair", ret); + return; + } + + apiFd = Api.Fd; + + struct epoll_event portoEv; + portoEv.events = EPOLLIN; + portoEv.data.fd = apiFd; + if (epoll_ctl(EpollFd, EPOLL_CTL_ADD, apiFd, &portoEv)) { + Fatal("Can not epoll_ctl", errno); + return; + } + } + } else if (events[n].data.fd == Sock) { + ERequestType requestType = static_cast<ERequestType>(RecvInt(Sock)); + + switch (requestType) { + case ERequestType::Add: + HandleAddRequest(); + break; + case ERequestType::Del: + HandleDelRequest(); + break; + case ERequestType::Stop: + return; + case ERequestType::None: + default: + Fatal("Unknown request", static_cast<int>(requestType)); + } + } + } + } +} + +void TAsyncWaiter::SendInt(int fd, int value) { + int ret = write(fd, &value, sizeof(value)); + if (ret != sizeof(value)) + Fatal("Can not send int", errno); +} + +int TAsyncWaiter::RecvInt(int fd) { + int value; + int ret = read(fd, &value, sizeof(value)); + if (ret != sizeof(value)) + Fatal("Can not recv int", errno); + + return value; +} + +void TAsyncWaiter::HandleAddRequest() { + int ret = 0; + + auto it = AsyncCallbacks.find(ReqCt); + if (it != AsyncCallbacks.end()) { + ret = Api.StopAsyncWait({ReqCt}, {}, it->second.State); + AsyncCallbacks.erase(it); + } + + AsyncCallbacks.insert(std::make_pair(ReqCt, 
TCallbackData({ReqCallback, ReqState}))); + + ret = Api.AsyncWait({ReqCt}, {}, GetMainCallback(), -1, ReqState); + SendInt(Sock, ret); +} + +void TAsyncWaiter::HandleDelRequest() { + int ret = 0; + + auto it = AsyncCallbacks.find(ReqCt); + if (it != AsyncCallbacks.end()) { + ret = Api.StopAsyncWait({ReqCt}, {}, it->second.State); + AsyncCallbacks.erase(it); + } + + SendInt(Sock, ret); +} + +TAsyncWaiter::TAsyncWaiter(std::function<void(const TString &error, int ret)> fatalCallback) + : CallbacksCount(0ul) + , FatalCallback(fatalCallback) +{ + int socketPair[2]; + int ret = socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, socketPair); + if (ret) { + Fatal("Can not make socketpair", ret); + return; + } + + MasterSock = socketPair[0]; + Sock = socketPair[1]; + + ret = Api.Connect(); + if (ret) { + Fatal("Can not connect to porto api", ret); + return; + } + + auto apiFd = Api.Fd; + + EpollFd = epoll_create1(EPOLL_CLOEXEC); + + if (EpollFd == -1) { + Fatal("Can not epoll_create", errno); + return; + } + + struct epoll_event pairEv; + pairEv.events = EPOLLIN; + pairEv.data.fd = Sock; + + struct epoll_event portoEv; + portoEv.events = EPOLLIN; + portoEv.data.fd = apiFd; + + if (epoll_ctl(EpollFd, EPOLL_CTL_ADD, Sock, &pairEv)) { + Fatal("Can not epoll_ctl", errno); + return; + } + + if (epoll_ctl(EpollFd, EPOLL_CTL_ADD, apiFd, &portoEv)) { + Fatal("Can not epoll_ctl", errno); + return; + } + + WatchDogThread = std::unique_ptr<std::thread>(new std::thread(&TAsyncWaiter::WatchDog, this)); +} + +TAsyncWaiter::~TAsyncWaiter() { + SendInt(MasterSock, static_cast<int>(ERequestType::Stop)); + WatchDogThread->join(); + + // pedantic check, that porto api is watching by epoll + if (epoll_ctl(EpollFd, EPOLL_CTL_DEL, Api.Fd, nullptr) || epoll_ctl(EpollFd, EPOLL_CTL_DEL, Sock, nullptr)) { + Fatal("Can not epoll_ctl_del", errno); + } + + close(EpollFd); + close(Sock); + close(MasterSock); +} + +int TAsyncWaiter::Add(const TString &ct, const TString &state, TWaitCallback callback) { 
+ if (FatalError) + return -1; + + ReqCt = ct; + ReqState = state; + ReqCallback = callback; + + SendInt(MasterSock, static_cast<int>(ERequestType::Add)); + return RecvInt(MasterSock); +} + +int TAsyncWaiter::Remove(const TString &ct) { + if (FatalError) + return -1; + + ReqCt = ct; + + SendInt(MasterSock, static_cast<int>(ERequestType::Del)); + return RecvInt(MasterSock); +} +#endif + +} /* namespace Porto */ diff --git a/library/cpp/porto/libporto.hpp b/library/cpp/porto/libporto.hpp new file mode 100644 index 0000000000..e30f22a41e --- /dev/null +++ b/library/cpp/porto/libporto.hpp @@ -0,0 +1,492 @@ +#pragma once + +#include <atomic> +#include <thread> + +#include <util/string/cast.h> +#include <util/generic/hash.h> +#include <util/generic/map.h> +#include <util/generic/vector.h> + +#include <library/cpp/porto/proto/rpc.pb.h> + +namespace Porto { + +constexpr int INFINITE_TIMEOUT = -1; +constexpr int DEFAULT_TIMEOUT = 300; // 5min +constexpr int DEFAULT_DISK_TIMEOUT = 900; // 15min + +constexpr char SOCKET_PATH[] = "/run/portod.socket"; + +typedef std::function<void(const TWaitResponse &event)> TWaitCallback; + +enum { + GET_NONBLOCK = 1, // try lock container state + GET_SYNC = 2, // refresh cached values, cache ttl 5s + GET_REAL = 4, // no faked or inherited values +}; + +struct DockerImage { + std::string Id; + std::vector<std::string> Tags; + std::vector<std::string> Digests; + std::vector<std::string> Layers; + uint64_t Size; + struct Config { + std::vector<std::string> Cmd; + std::vector<std::string> Env; + } Config; + + DockerImage() = default; + DockerImage(const TDockerImage &i); + + DockerImage(const DockerImage &i) = default; + DockerImage(DockerImage &&i) = default; + + DockerImage& operator=(const DockerImage &i) = default; + DockerImage& operator=(DockerImage &&i) = default; +}; + +class TPortoApi { +#ifdef __linux__ + friend class TAsyncWaiter; +#endif +private: + int Fd = -1; + int Timeout = DEFAULT_TIMEOUT; + int DiskTimeout = 
DEFAULT_DISK_TIMEOUT; + bool AutoReconnect = true; + + EError LastError = EError::Success; + TString LastErrorMsg; + + /* + * These keep last request and response. Method might return + * pointers to Rsp innards -> pointers valid until next call. + */ + TPortoRequest Req; + TPortoResponse Rsp; + + std::vector<TString> AsyncWaitNames; + std::vector<TString> AsyncWaitLabels; + int AsyncWaitTimeout = INFINITE_TIMEOUT; + TWaitCallback AsyncWaitCallback; + bool AsyncWaitOneShot = false; + + EError SetError(const TString &prefix, int _errno) Y_WARN_UNUSED_RESULT; + + EError SetSocketTimeout(int direction, int timeout) Y_WARN_UNUSED_RESULT; + + EError Send(const TPortoRequest &req) Y_WARN_UNUSED_RESULT; + + EError Recv(TPortoResponse &rsp) Y_WARN_UNUSED_RESULT; + + EError Call(int extra_timeout = 0) Y_WARN_UNUSED_RESULT; + + EError CallWait(TString &result_state, int wait_timeout) Y_WARN_UNUSED_RESULT; + +public: + TPortoApi() { } + ~TPortoApi(); + + int GetFd() const { + return Fd; + } + + bool Connected() const { + return Fd >= 0; + } + + EError Connect(const char *socket_path = SOCKET_PATH) Y_WARN_UNUSED_RESULT; + void Disconnect(); + + /* Requires signal(SIGPIPE, SIG_IGN) */ + void SetAutoReconnect(bool auto_reconnect) { + AutoReconnect = auto_reconnect; + } + + /* Request and response timeout in seconds */ + int GetTimeout() const { + return Timeout; + } + EError SetTimeout(int timeout); + + /* Extra timeout for disk operations in seconds */ + int GetDiskTimeout() const { + return DiskTimeout; + } + EError SetDiskTimeout(int timeout); + + EError Error() const Y_WARN_UNUSED_RESULT { + return LastError; + } + + EError GetLastError(TString &msg) const Y_WARN_UNUSED_RESULT { + msg = LastErrorMsg; + return LastError; + } + + /* Returns "LastError:(LastErrorMsg)" */ + TString GetLastError() const Y_WARN_UNUSED_RESULT; + + /* Returns text protobuf */ + TString GetLastRequest() const { + return Req.DebugString(); + } + TString GetLastResponse() const { + return 
Rsp.DebugString(); + } + + /* To be used for next changed_since */ + uint64_t ResponseTimestamp() const Y_WARN_UNUSED_RESULT { + return Rsp.timestamp(); + } + + // extra_timeout: 0 - none, -1 - infinite + EError Call(const TPortoRequest &req, + TPortoResponse &rsp, + int extra_timeout = 0) Y_WARN_UNUSED_RESULT; + + EError Call(const TString &req, + TString &rsp, + int extra_timeout = 0) Y_WARN_UNUSED_RESULT; + + /* System */ + + EError GetVersion(TString &tag, TString &revision) Y_WARN_UNUSED_RESULT; + + const TGetSystemResponse *GetSystem(); + + EError SetSystem(const TString &key, const TString &val) Y_WARN_UNUSED_RESULT; + + /* Container */ + + const TListPropertiesResponse *ListProperties(); + + EError ListProperties(TVector<TString> &properties) Y_WARN_UNUSED_RESULT; + + const TListResponse *List(const TString &mask = ""); + + EError List(TVector<TString> &names, const TString &mask = "") Y_WARN_UNUSED_RESULT; + + EError Create(const TString &name) Y_WARN_UNUSED_RESULT; + + EError CreateWeakContainer(const TString &name) Y_WARN_UNUSED_RESULT; + + EError Destroy(const TString &name) Y_WARN_UNUSED_RESULT; + + EError Start(const TString &name)Y_WARN_UNUSED_RESULT; + + // stop_timeout: time between SIGTERM and SIGKILL, -1 - default + EError Stop(const TString &name, int stop_timeout = -1) Y_WARN_UNUSED_RESULT; + + EError Kill(const TString &name, int sig = 9) Y_WARN_UNUSED_RESULT; + + EError Pause(const TString &name) Y_WARN_UNUSED_RESULT; + + EError Resume(const TString &name) Y_WARN_UNUSED_RESULT; + + EError Respawn(const TString &name) Y_WARN_UNUSED_RESULT; + + // wait_timeout: 0 - nonblock, -1 - infinite + EError WaitContainer(const TString &name, + TString &result_state, + int wait_timeout = INFINITE_TIMEOUT) Y_WARN_UNUSED_RESULT; + + EError WaitContainers(const TVector<TString> &names, + TString &result_name, + TString &result_state, + int wait_timeout = INFINITE_TIMEOUT) Y_WARN_UNUSED_RESULT; + + const TWaitResponse *Wait(const TVector<TString> &names, + 
const TVector<TString> &labels, + int wait_timeout = INFINITE_TIMEOUT) Y_WARN_UNUSED_RESULT; + + EError AsyncWait(const TVector<TString> &names, + const TVector<TString> &labels, + TWaitCallback callbacks, + int wait_timeout = INFINITE_TIMEOUT, + const TString &targetState = "") Y_WARN_UNUSED_RESULT; + + EError StopAsyncWait(const TVector<TString> &names, + const TVector<TString> &labels, + const TString &targetState = "") Y_WARN_UNUSED_RESULT; + + const TGetResponse *Get(const TVector<TString> &names, + const TVector<TString> &properties, + int flags = 0) Y_WARN_UNUSED_RESULT; + + /* Porto v5 api */ + EError GetContainerSpec(const TString &name, TContainer &container) Y_WARN_UNUSED_RESULT ; + EError ListContainersBy(const TListContainersRequest &listContainersRequest, TVector<TContainer> &containers) Y_WARN_UNUSED_RESULT; + EError CreateFromSpec(const TContainerSpec &container, TVector<TVolumeSpec> volumes, bool start = false) Y_WARN_UNUSED_RESULT; + EError UpdateFromSpec(const TContainerSpec &container) Y_WARN_UNUSED_RESULT; + + EError GetProperty(const TString &name, + const TString &property, + TString &value, + int flags = 0) Y_WARN_UNUSED_RESULT; + + EError GetProperty(const TString &name, + const TString &property, + const TString &index, + TString &value, + int flags = 0) Y_WARN_UNUSED_RESULT { + return GetProperty(name, property + "[" + index + "]", value, flags); + } + + EError SetProperty(const TString &name, + const TString &property, + const TString &value) Y_WARN_UNUSED_RESULT; + + EError SetProperty(const TString &name, + const TString &property, + const TString &index, + const TString &value) Y_WARN_UNUSED_RESULT { + return SetProperty(name, property + "[" + index + "]", value); + } + + EError GetInt(const TString &name, + const TString &property, + const TString &index, + uint64_t &value) Y_WARN_UNUSED_RESULT; + + EError GetInt(const TString &name, + const TString &property, + uint64_t &value) Y_WARN_UNUSED_RESULT { + return GetInt(name, property, 
"", value); + } + + EError SetInt(const TString &name, + const TString &property, + const TString &index, + uint64_t value) Y_WARN_UNUSED_RESULT; + + EError SetInt(const TString &name, + const TString &property, + uint64_t value) Y_WARN_UNUSED_RESULT { + return SetInt(name, property, "", value); + } + + EError GetProcMetric(const TVector<TString> &names, + const TString &metric, + TMap<TString, uint64_t> &values); + + EError GetLabel(const TString &name, + const TString &label, + TString &value) Y_WARN_UNUSED_RESULT { + return GetProperty(name, "labels", label, value); + } + + EError SetLabel(const TString &name, + const TString &label, + const TString &value, + const TString &prev_value = " ") Y_WARN_UNUSED_RESULT; + + EError IncLabel(const TString &name, + const TString &label, + int64_t add, + int64_t &result) Y_WARN_UNUSED_RESULT; + + EError IncLabel(const TString &name, + const TString &label, + int64_t add = 1) Y_WARN_UNUSED_RESULT { + int64_t result; + return IncLabel(name, label, add, result); + } + + EError ConvertPath(const TString &path, + const TString &src_name, + const TString &dst_name, + TString &result_path) Y_WARN_UNUSED_RESULT; + + EError AttachProcess(const TString &name, int pid, + const TString &comm = "") Y_WARN_UNUSED_RESULT; + + EError AttachThread(const TString &name, int pid, + const TString &comm = "") Y_WARN_UNUSED_RESULT; + + EError LocateProcess(int pid, + const TString &comm /* = "" */, + TString &name) Y_WARN_UNUSED_RESULT; + + /* Volume */ + + const TListVolumePropertiesResponse *ListVolumeProperties(); + + EError ListVolumeProperties(TVector<TString> &properties) Y_WARN_UNUSED_RESULT; + + const TListVolumesResponse *ListVolumes(const TString &path = "", + const TString &container = ""); + + EError ListVolumes(TVector<TString> &paths) Y_WARN_UNUSED_RESULT; + + const TVolumeDescription *GetVolumeDesc(const TString &path); + + /* Porto v5 api */ + EError ListVolumesBy(const TGetVolumeRequest &getVolumeRequest, TVector<TVolumeSpec> 
&volumes) Y_WARN_UNUSED_RESULT; + EError CreateVolumeFromSpec(const TVolumeSpec &volume, TVolumeSpec &resultSpec) Y_WARN_UNUSED_RESULT; + + const TVolumeSpec *GetVolume(const TString &path); + + const TGetVolumeResponse *GetVolumes(uint64_t changed_since = 0); + + EError CreateVolume(TString &path, + const TMap<TString, TString> &config) Y_WARN_UNUSED_RESULT; + + EError LinkVolume(const TString &path, + const TString &container = "", + const TString &target = "", + bool read_only = false, + bool required = false) Y_WARN_UNUSED_RESULT; + + EError UnlinkVolume(const TString &path, + const TString &container = "", + const TString &target = "***", + bool strict = false) Y_WARN_UNUSED_RESULT; + + EError TuneVolume(const TString &path, + const TMap<TString, TString> &config) Y_WARN_UNUSED_RESULT; + + /* Layer */ + + const TListLayersResponse *ListLayers(const TString &place = "", + const TString &mask = ""); + + EError ListLayers(TVector<TString> &layers, + const TString &place = "", + const TString &mask = "") Y_WARN_UNUSED_RESULT; + + EError ImportLayer(const TString &layer, + const TString &tarball, + bool merge = false, + const TString &place = "", + const TString &private_value = "", + bool verboseError = false) Y_WARN_UNUSED_RESULT; + + EError ExportLayer(const TString &volume, + const TString &tarball, + const TString &compress = "") Y_WARN_UNUSED_RESULT; + + EError ReExportLayer(const TString &layer, + const TString &tarball, + const TString &compress = "") Y_WARN_UNUSED_RESULT; + + EError RemoveLayer(const TString &layer, + const TString &place = "", + bool async = false) Y_WARN_UNUSED_RESULT; + + EError GetLayerPrivate(TString &private_value, + const TString &layer, + const TString &place = "") Y_WARN_UNUSED_RESULT; + + EError SetLayerPrivate(const TString &private_value, + const TString &layer, + const TString &place = "") Y_WARN_UNUSED_RESULT; + + /* Docker images */ + + EError DockerImageStatus(DockerImage &image, + const TString &name, + const TString 
&place = "") Y_WARN_UNUSED_RESULT; + + EError ListDockerImages(std::vector<DockerImage> &images, + const TString &place = "", + const TString &mask = "") Y_WARN_UNUSED_RESULT; + + EError PullDockerImage(DockerImage &image, + const TString &name, + const TString &place = "", + const TString &auth_token = "", + const TString &auth_host = "", + const TString &auth_service = "") Y_WARN_UNUSED_RESULT; + + EError RemoveDockerImage(const TString &name, + const TString &place = ""); + + /* Storage */ + + const TListStoragesResponse *ListStorages(const TString &place = "", + const TString &mask = ""); + + EError ListStorages(TVector<TString> &storages, + const TString &place = "", + const TString &mask = "") Y_WARN_UNUSED_RESULT; + + EError RemoveStorage(const TString &storage, + const TString &place = "") Y_WARN_UNUSED_RESULT; + + EError ImportStorage(const TString &storage, + const TString &archive, + const TString &place = "", + const TString &compression = "", + const TString &private_value = "") Y_WARN_UNUSED_RESULT; + + EError ExportStorage(const TString &storage, + const TString &archive, + const TString &place = "", + const TString &compression = "") Y_WARN_UNUSED_RESULT; +}; + +#ifdef __linux__ +class TAsyncWaiter { + struct TCallbackData { + const TWaitCallback Callback; + const TString State; + }; + + enum class ERequestType { + None, + Add, + Del, + Stop, + }; + + THashMap<TString, TCallbackData> AsyncCallbacks; + std::unique_ptr<std::thread> WatchDogThread; + std::atomic<uint64_t> CallbacksCount; + int EpollFd = -1; + TPortoApi Api; + + int Sock, MasterSock; + TString ReqCt; + TString ReqState; + TWaitCallback ReqCallback; + + std::function<void(const TString &error, int ret)> FatalCallback; + bool FatalError = false; + + void MainCallback(const TWaitResponse &event); + inline TWaitCallback GetMainCallback() { + return [this](const TWaitResponse &event) { + MainCallback(event); + }; + } + + int Repair(); + void WatchDog(); + + void SendInt(int fd, int value); + 
int RecvInt(int fd); + + void HandleAddRequest(); + void HandleDelRequest(); + + void Fatal(const TString &error, int ret) { + FatalError = true; + FatalCallback(error, ret); + } + + public: + TAsyncWaiter(std::function<void(const TString &error, int ret)> fatalCallback); + ~TAsyncWaiter(); + + int Add(const TString &ct, const TString &state, TWaitCallback callback); + int Remove(const TString &ct); + uint64_t InvocationCount() const { + return CallbacksCount; + } +}; +#endif + +} /* namespace Porto */ diff --git a/library/cpp/porto/metrics.cpp b/library/cpp/porto/metrics.cpp new file mode 100644 index 0000000000..7d17d0aee4 --- /dev/null +++ b/library/cpp/porto/metrics.cpp @@ -0,0 +1,183 @@ +#include "metrics.hpp" + +#include <util/folder/path.h> +#include <util/generic/maybe.h> +#include <util/stream/file.h> + +namespace Porto { + +TMap<TString, TMetric*> ProcMetrics; + +TMetric::TMetric(const TString& name, EMetric metric) { + Name = name; + Metric = metric; + ProcMetrics[name] = this; +} + +void TMetric::ClearValues(const TVector<TString>& names, TMap<TString, uint64_t>& values) const { + values.clear(); + + for (const auto&name : names) + values[name] = 0; +} + +EError TMetric::GetValues(const TVector<TString>& names, TMap<TString, uint64_t>& values, TPortoApi& api) const { + ClearValues(names, values); + + int procFd = open("/proc", O_RDONLY | O_CLOEXEC | O_DIRECTORY | O_NOCTTY); + TFileHandle procFdHandle(procFd); + if (procFd == -1) + return EError::Unknown; + + TVector<TString> tids; + TidSnapshot(tids); + + auto getResponse = api.Get(names, TVector<TString>{"cgroups[freezer]"}); + + if (getResponse == nullptr) + return EError::Unknown; + + const auto containersCgroups = GetCtFreezerCgroups(getResponse); + + for (const auto& tid : tids) { + const TString tidCgroup = GetFreezerCgroup(procFd, tid); + if (tidCgroup == "") + continue; + + TMaybe<uint64_t> metricValue; + + for (const auto& keyval : containersCgroups) { + const TString& containerCgroup = 
keyval.second; + if (MatchCgroups(tidCgroup, containerCgroup)) { + if (!metricValue) + metricValue = GetMetric(procFd, tid); + values[keyval.first] += *metricValue; + } + } + } + + return EError::Success; +} + +uint64_t TMetric::GetTidSchedMetricValue(int procFd, const TString& tid, const TString& metricName) const { + const TString schedPath = tid + "/sched"; + try { + int fd = openat(procFd, schedPath.c_str(), O_RDONLY | O_CLOEXEC | O_NOCTTY, 0); + TFile file(fd); + if (!file.IsOpen()) + return 0ul; + + TIFStream iStream(file); + TString line; + while (iStream.ReadLine(line)) { + auto metricPos = line.find(metricName); + + if (metricPos != TString::npos) { + auto valuePos = metricPos; + + while (valuePos < line.size() && !::isdigit(line[valuePos])) + ++valuePos; + + TString value = line.substr(valuePos); + if (!value.empty() && IsNumber(value)) + return IntFromString<uint64_t, 10>(value); + } + } + } + catch(...) {} + + return 0ul; +} + +void TMetric::GetPidTasks(const TString& pid, TVector<TString>& tids) const { + TFsPath task("/proc/" + pid + "/task"); + TVector<TString> rawTids; + + try { + task.ListNames(rawTids); + } + catch(...) {} + + for (const auto& tid : rawTids) { + tids.push_back(tid); + } +} + +void TMetric::TidSnapshot(TVector<TString>& tids) const { + TFsPath proc("/proc"); + TVector<TString> rawPids; + + try { + proc.ListNames(rawPids); + } + catch(...) 
{} + + for (const auto& pid : rawPids) { + if (IsNumber(pid)) + GetPidTasks(pid, tids); + } +} + +TString TMetric::GetFreezerCgroup(int procFd, const TString& tid) const { + const TString cgroupPath = tid + "/cgroup"; + try { + int fd = openat(procFd, cgroupPath.c_str(), O_RDONLY | O_CLOEXEC | O_NOCTTY, 0); + TFile file(fd); + if (!file.IsOpen()) + return TString(); + + TIFStream iStream(file); + TString line; + + while (iStream.ReadLine(line)) { + static const TString freezer = ":freezer:"; + auto freezerPos = line.find(freezer); + + if (freezerPos != TString::npos) { + line = line.substr(freezerPos + freezer.size()); + return line; + } + } + } + catch(...){} + + return TString(); +} + +TMap<TString, TString> TMetric::GetCtFreezerCgroups(const TGetResponse* response) const { + TMap<TString, TString> containersProps; + + for (const auto& ctGetListResponse : response->list()) { + for (const auto& keyval : ctGetListResponse.keyval()) { + if (!keyval.error()) { + TString value = keyval.value(); + static const TString freezerPath = "/sys/fs/cgroup/freezer"; + + if (value.find(freezerPath) != TString::npos) + value = value.substr(freezerPath.size()); + + containersProps[ctGetListResponse.name()] = value; + } + } + } + + return containersProps; +} + +bool TMetric::MatchCgroups(const TString& tidCgroup, const TString& ctCgroup) const { + if (tidCgroup.size() <= ctCgroup.size()) + return tidCgroup == ctCgroup; + return ctCgroup == tidCgroup.substr(0, ctCgroup.size()) && tidCgroup[ctCgroup.size()] == '/'; +} + +class TCtxsw : public TMetric { +public: + TCtxsw() : TMetric(M_CTXSW, EMetric::CTXSW) + {} + + uint64_t GetMetric(int procFd, const TString& tid) const override { + return GetTidSchedMetricValue(procFd, tid, "nr_switches"); + } +} static Ctxsw; + +} /* namespace Porto */ diff --git a/library/cpp/porto/metrics.hpp b/library/cpp/porto/metrics.hpp new file mode 100644 index 0000000000..5b2ffde8d9 --- /dev/null +++ b/library/cpp/porto/metrics.hpp @@ -0,0 +1,50 @@ 
+#pragma once + +#include "libporto.hpp" + +#include <util/generic/map.h> +#include <util/generic/vector.h> +#include <util/string/cast.h> +#include <util/string/type.h> + +#include <library/cpp/porto/proto/rpc.pb.h> +namespace Porto { + +constexpr const char *M_CTXSW = "ctxsw"; + +enum class EMetric { + NONE, + CTXSW, +}; + +class TMetric { +public: + TString Name; + EMetric Metric; + + TMetric(const TString& name, EMetric metric); + + void ClearValues(const TVector<TString>& names, TMap<TString, uint64_t>& values) const; + EError GetValues(const TVector<TString>& names, TMap<TString, uint64_t>& values, TPortoApi& api) const; + + // Returns value of metric from /proc/tid/sched for some tid + uint64_t GetTidSchedMetricValue(int procFd, const TString& tid, const TString& metricName) const; + + void TidSnapshot(TVector<TString>& tids) const; + void GetPidTasks(const TString& pid, TVector<TString>& tids) const; + + // Returns freezer cgroup from /proc/tid/cgroup + TString GetFreezerCgroup(int procFd, const TString& tid) const; + + // Resurns clean cgroup[freezer] for containers names + TMap<TString, TString> GetCtFreezerCgroups(const TGetResponse* response) const; + + // Verify inclusion of container cgroup in process cgroup + bool MatchCgroups(const TString& tidCgroup, const TString& ctCgroup) const; + +private: + virtual uint64_t GetMetric(int procFd, const TString& tid) const = 0; +}; + +extern TMap<TString, TMetric*> ProcMetrics; +} /* namespace Porto */ diff --git a/library/cpp/porto/proto/rpc.proto b/library/cpp/porto/proto/rpc.proto new file mode 100644 index 0000000000..5c2e9fdbc3 --- /dev/null +++ b/library/cpp/porto/proto/rpc.proto @@ -0,0 +1,1606 @@ +syntax = "proto2"; + +option go_package = "github.com/ydb-platform/ydb/library/cpp/porto/proto;myapi"; + +/* + Portod daemon listens on /run/portod.socket unix socket. 
+ + Request: Varint32 length, TPortoRequest request + Response: Varint32 length, TPortoResponse response + + Command is defined by optional nested message field. + Result will be in nested message with the same name. + + Push notification is sent as out of order response. + + Access level depends on client container and uid. + + See details in porto.md or manpage porto + + TContainer, TVolume and related methods are Porto v5 API. +*/ + +package Porto; + +// List of error codes +enum EError { + // No errors occurred. + Success = 0; + + // Unclassified error, usually unexpected syscall fail. + Unknown = 1; + + // Unknown method or bad request. + InvalidMethod = 2; + + // Container with specified name already exists. + ContainerAlreadyExists = 3; + + // Container with specified name doesn't exist. + ContainerDoesNotExist = 4; + + // Unknown property specified. + InvalidProperty = 5; + + // Unknown data specified. + InvalidData = 6; + + // Invalid value of property or data. + InvalidValue = 7; + + // Can't perform specified operation in current container state. + InvalidState = 8; + + // Permanent failure: old kernel version, missing feature, configuration, etc. + NotSupported = 9; + + // Temporary failure: too many objects, not enough memory, etc. + ResourceNotAvailable = 10; + + // Insufficient rights for performing requested operation. + Permission = 11; + + // Can't create new volume with specified name, because there is already one. + VolumeAlreadyExists = 12; + + // Volume with specified name doesn't exist. + VolumeNotFound = 13; + + // Not enough disk space. + NoSpace = 14; + + // Object in use. + Busy = 15; + + // Volume already linked with container. + VolumeAlreadyLinked = 16; + + // Volume not linked with container. + VolumeNotLinked = 17; + + // Layer with this name already exists. + LayerAlreadyExists = 18; + + // Layer with this name not found. + LayerNotFound = 19; + + // Property has no value, data source permanently not available. 
+ NoValue = 20; + + // Volume under construction or destruction. + VolumeNotReady = 21; + + // Cannot parse or execute command. + InvalidCommand = 22; + + // Error code is lost or came from future. + LostError = 23; + + // Device node not found. + DeviceNotFound = 24; + + // Path does not match restrictions or does not exist. + InvalidPath = 25; + + // Wrong or unusable ip address. + InvalidNetworkAddress = 26; + + // Porto in system maintenance state. + PortoFrozen = 27; + + // Label with this name is not set. + LabelNotFound = 28; + + // Label name does not meet restrictions. + InvalidLabel = 29; + + // Errors in tar, on archive extraction + HelperError = 30; + HelperFatalError = 31; + + // Generic object not found. + NotFound = 404; + + // Reserved error code for client library. + SocketError = 502; + + // Reserved error code for client library. + SocketUnavailable = 503; + + // Reserved error code for client library. + SocketTimeout = 504; + + // Portod closes client connections on reload + PortodReloaded = 505; + + // Reserved error code for taints. + Taint = 666; + + // Reserved error codes 700-800 to docker + Docker = 700; + DockerImageNotFound = 701; + + // Internal error code, not for users. 
+ Queued = 1000; +} + + +message TPortoRequest { + + /* System methods */ + + // Get portod version + optional TVersionRequest Version = 14; + + // Get portod statistics + optional TGetSystemRequest GetSystem = 300; + + // Change portod state (for host root user only) + optional TSetSystemRequest SetSystem = 301; + + /* Container methods */ + + // Create new container + optional TCreateRequest Create = 1; + + // Create new container and auto destroy when client disconnects + optional TCreateRequest CreateWeak = 17; + + // Force kill all and destroy container and nested containers + optional TDestroyRequest Destroy = 2; + + // List container names in current namespace + optional TListRequest List = 3; + + // Start container and parents if needed + optional TStartRequest Start = 7; + + // Kill all and stop container + optional TStopRequest Stop = 8; + + // Freeze execution + optional TPauseRequest Pause = 9; + + // Resume execution + optional TResumeRequest Resume = 10; + + // Send signal to main process + optional TKillRequest Kill = 13; + + // Restart dead container + optional TRespawnRequest Respawn = 18; + + // Wait for process finish or change of labels + optional TWaitRequest Wait = 16; + + // Subscribe to push notifications + optional TWaitRequest AsyncWait = 19; + optional TWaitRequest StopAsyncWait = 128; + + /* Container properties */ + + // List supported container properties + optional TListPropertiesRequest ListProperties = 11; + + // Get one property + optional TGetPropertyRequest GetProperty = 4; + + // Set one property + optional TSetPropertyRequest SetProperty = 5; + + // Deprecated, now data properties are also read-only properties + optional TListDataPropertiesRequest ListDataProperties = 12; + optional TGetDataPropertyRequest GetDataProperty = 6; + + // Get multiple properties for multiple containers + optional TGetRequest Get = 15; + + /* Container API based on TContainer (Porto v5 API) */ + + // Create, configure and start container with volumes 
+ optional TCreateFromSpecRequest CreateFromSpec = 230; + + // Set multiple container properties + optional TUpdateFromSpecRequest UpdateFromSpec = 231; + + // Get multiple properties for multiple containers + optional TListContainersRequest ListContainersBy = 232; + + // Modify symlink in container + optional TSetSymlinkRequest SetSymlink = 125; + + /* Container labels - user defined key-value */ + + // Find containers with labels + optional TFindLabelRequest FindLabel = 20; + + // Atomic compare and set for label + optional TSetLabelRequest SetLabel = 21; + + // Atomic add and return for counter in label + optional TIncLabelRequest IncLabel = 22; + + /* Volume methods */ + + optional TListVolumePropertiesRequest ListVolumeProperties = 103; + + // List layers and their properties + optional TListVolumesRequest ListVolumes = 107; + + // Create, configure and build volume + optional TCreateVolumeRequest CreateVolume = 104; + + // Change volume properties - for now only resize + optional TTuneVolumeRequest TuneVolume = 108; + + // Volume API based on TVolume (Porto v5 API) + optional TNewVolumeRequest NewVolume = 126; + optional TGetVolumeRequest GetVolume = 127; + + // Add link between container and volume + optional TLinkVolumeRequest LinkVolume = 105; + + // Same as LinkVolume but fails if target is not supported + optional TLinkVolumeRequest LinkVolumeTarget = 120; + + // Del link between container and volume + optional TUnlinkVolumeRequest UnlinkVolume = 106; + + // Same as UnlinkVolume but fails if target is not supported + optional TUnlinkVolumeRequest UnlinkVolumeTarget = 121; + + /* Layer methods */ + + // Import layer from tarball + optional TImportLayerRequest ImportLayer = 110; + + // Remove layer + optional TRemoveLayerRequest RemoveLayer = 111; + + // List layers + optional TListLayersRequest ListLayers = 112; + + // Export volume or layer into tarball + optional TExportLayerRequest ExportLayer = 113; + + // Get/set layer private (user defined string) + 
optional TGetLayerPrivateRequest GetLayerPrivate = 114; + optional TSetLayerPrivateRequest SetLayerPrivate = 115; + + /* Storage methods */ + + // Volume creation creates required storage if missing + + // List storages and meta storages + optional TListStoragesRequest ListStorages = 116; + + optional TRemoveStorageRequest RemoveStorage = 117; + + // Import storage from tarball + optional TImportStorageRequest ImportStorage = 118; + + // Export storage into tarball + optional TExportStorageRequest ExportStorage = 119; + + // Meta storage (bundle for storages and layers) + + optional TMetaStorage CreateMetaStorage = 122; + optional TMetaStorage ResizeMetaStorage = 123; + optional TMetaStorage RemoveMetaStorage = 124; + + // Convert path between containers + optional TConvertPathRequest ConvertPath = 200; + + /* Process methods */ + + // Attach process to nested container + optional TAttachProcessRequest AttachProcess = 201; + + // Find container for process + optional TLocateProcessRequest LocateProcess = 202; + + // Attach one thread to nexted container + optional TAttachProcessRequest AttachThread = 203; + + /* Docker images API */ + + optional TDockerImageStatusRequest dockerImageStatus = 303; + optional TDockerImageListRequest listDockerImages = 304; + optional TDockerImagePullRequest pullDockerImage = 305; + optional TDockerImageRemoveRequest removeDockerImage = 306; +} + + +message TPortoResponse { + // Actually always set, hack for adding new error codes + optional EError error = 1 [ default = LostError ]; + + // Human readable comment - must be shown to user as is + optional string errorMsg = 2; + + optional uint64 timestamp = 1000; // for next changed_since + + /* System methods */ + + optional TVersionResponse Version = 8; + + optional TGetSystemResponse GetSystem = 300; + optional TSetSystemResponse SetSystem = 301; + + /* Container methods */ + + optional TListResponse List = 3; + + optional TWaitResponse Wait = 11; + + optional TWaitResponse AsyncWait = 
19; + + /* Container properties */ + + optional TListPropertiesResponse ListProperties = 6; + + optional TGetPropertyResponse GetProperty = 4; + + + // Deprecated + optional TListDataPropertiesResponse ListDataProperties = 7; + optional TGetDataPropertyResponse GetDataProperty = 5; + + optional TGetResponse Get = 10; + + /* Container API based on TContainer (Porto v5 API) */ + + optional TListContainersResponse ListContainersBy = 232; + + /* Container Labels */ + + optional TFindLabelResponse FindLabel = 20; + optional TSetLabelResponse SetLabel = 21; + optional TIncLabelResponse IncLabel = 22; + + /* Volume methods */ + + optional TListVolumePropertiesResponse ListVolumeProperties = 12; + + optional TListVolumesResponse ListVolumes = 9; + + optional TVolumeDescription CreateVolume = 13; + + optional TNewVolumeResponse NewVolume = 126; + + optional TGetVolumeResponse GetVolume = 127; + + optional TListLayersResponse ListLayers = 14; + + optional TGetLayerPrivateResponse GetLayerPrivate = 16; + + // List storages and meta storages + optional TListStoragesResponse ListStorages = 17; + + optional TConvertPathResponse ConvertPath = 15; + + // Process + optional TLocateProcessResponse LocateProcess = 18; + + /* Docker images API */ + + optional TDockerImageStatusResponse dockerImageStatus = 302; + optional TDockerImageListResponse listDockerImages = 303; + optional TDockerImagePullResponse pullDockerImage = 304; +} + + +// Common objects + + +message TStringMap { + message TStringMapEntry { + optional string key = 1; + optional string val = 2; + } + // TODO replace with map + // map<string, string> map = 1; + repeated TStringMapEntry map = 1; + optional bool merge = 2; // in, default: replace +} + + +message TUintMap { + message TUintMapEntry { + optional string key = 1; + optional uint64 val = 2; + } + // TODO replace with map + // map<string, uint64> map = 1; + repeated TUintMapEntry map = 1; + optional bool merge = 2; // in, default: replace +} + + +message TError { 
+ optional EError error = 1 [ default = LostError ]; + optional string msg = 2; +} + + +message TCred { + optional string user = 1; // requires user or uid or both + optional fixed32 uid = 2; + optional string group = 3; + optional fixed32 gid = 4; + repeated fixed32 grp = 5; // out, supplementary groups +} + + +message TCapabilities { + repeated string cap = 1; + optional string hex = 2; // out +} + + +message TContainerCommandArgv { + repeated string argv = 1; +} + + +// Container + + +message TContainerEnvVar { + optional string name = 1; //required + optional string value = 2; + optional bool unset = 3; // out + optional string salt = 4; + optional string hash = 5; +} + +message TContainerEnv { + repeated TContainerEnvVar var = 1; + optional bool merge = 2; // in, default: replace +} + + +message TContainerUlimit { + optional string type = 1; //required + optional bool unlimited = 2; + optional uint64 soft = 3; + optional uint64 hard = 4; + optional bool inherited = 5; // out +} + +message TContainerUlimits { + repeated TContainerUlimit ulimit = 1; + optional bool merge = 2; // in, default: replace +} + + +message TContainerControllers { + repeated string controller = 1; +} + + +message TContainerCgroup { + optional string controller = 1; //required + optional string path = 2; //required + optional bool inherited = 3; +} + +message TContainerCgroups { + repeated TContainerCgroup cgroup = 1; +} + + +message TContainerCpuSet { + optional string policy = 1; // inherit|set|node|reserve|threads|cores + optional uint32 arg = 2; // for node|reserve|threads|cores + optional string list = 3; // for set + repeated uint32 cpu = 4; // for set (used if list isn't set) + optional uint32 count = 5; // out + optional string mems = 6; +} + + +message TContainerBindMount { + optional string source = 1; //required + optional string target = 2; //required + repeated string flag = 3; +} + +message TContainerBindMounts { + repeated TContainerBindMount bind = 1; +} + + +message 
TContainerVolumeLink { + optional string volume = 1; //required + optional string target = 2; + optional bool required = 3; + optional bool read_only = 4; +} + +message TContainerVolumeLinks { + repeated TContainerVolumeLink link = 1; +} + + +message TContainerVolumes { + repeated string volume = 1; +} + + +message TContainerPlace { + optional string place = 1; //required + optional string alias = 2; +} + +message TContainerPlaceConfig { + repeated TContainerPlace cfg = 1; +} + + +message TContainerDevice { + optional string device = 1; //required + optional string access = 2; //required + optional string path = 3; + optional string mode = 4; + optional string user = 5; + optional string group = 6; +} + +message TContainerDevices { + repeated TContainerDevice device = 1; + optional bool merge = 2; // in, default: replace +} + + +message TContainerNetOption { + optional string opt = 1; //required + repeated string arg = 2; +} + +message TContainerNetConfig { + repeated TContainerNetOption cfg = 1; + optional bool inherited = 2; // out +} + + +message TContainerIpLimit { + optional string policy = 1; //required any|none|some + repeated string ip = 2; +} + + +message TContainerIpConfig { + message TContainerIp { + optional string dev = 1; //required + optional string ip = 2; //required + } + repeated TContainerIp cfg = 1; +} + + +message TVmStat { + optional uint64 count = 1; + optional uint64 size = 2; + optional uint64 max_size = 3; + optional uint64 used = 4; + optional uint64 max_used = 5; + optional uint64 anon = 6; + optional uint64 file = 7; + optional uint64 shmem = 8; + optional uint64 huge = 9; + optional uint64 swap = 10; + optional uint64 data = 11; + optional uint64 stack = 12; + optional uint64 code = 13; + optional uint64 locked = 14; + optional uint64 table = 15; +} + +message TContainerStatus { + optional string absolute_name = 1; // out, "/porto/..." 
+ optional string state = 2; // out + optional uint64 id = 3; // out + optional uint32 level = 4; // out + optional string parent = 5; // out, "/porto/..." + + optional string absolute_namespace = 6; // out + + optional int32 root_pid = 7; // out + optional int32 exit_status = 8; // out + optional int32 exit_code = 9; // out + optional bool core_dumped = 10; // out + optional TError start_error = 11; // out + optional uint64 time = 12; // out + optional uint64 dead_time = 13; // out + + optional TCapabilities capabilities_allowed = 14; // out + optional TCapabilities capabilities_ambient_allowed = 15; // out + optional string root_path = 16; // out, in client namespace + optional uint64 stdout_offset = 17; // out + optional uint64 stderr_offset = 18; // out + optional string std_err = 69; // out + optional string std_out = 70; // out + + optional uint64 creation_time = 19; // out + optional uint64 start_time = 20; // out + optional uint64 death_time = 21; // out + optional uint64 change_time = 22; // out + optional bool no_changes = 23; // out, change_time < changed_since + optional string extra_properties = 73; + + optional TContainerCgroups cgroups = 24; // out + optional TContainerCpuSet cpu_set_affinity = 25; // out + + optional uint64 cpu_usage = 26; // out + optional uint64 cpu_usage_system = 27; // out + optional uint64 cpu_wait = 28; // out + optional uint64 cpu_throttled = 29; // out + + optional uint64 process_count = 30; // out + optional uint64 thread_count = 31; // out + + optional TUintMap io_read = 32; // out, bytes + optional TUintMap io_write = 33; // out, bytes + optional TUintMap io_ops = 34; // out, ops + optional TUintMap io_read_ops = 341; // out, ops + optional TUintMap io_write_ops = 342; // out, ops + optional TUintMap io_time = 35; // out, ns + optional TUintMap io_pressure = 351; // out + + optional TUintMap place_usage = 36; + optional uint64 memory_usage = 37; // out, bytes + + optional uint64 memory_guarantee_total = 38; // out + + 
optional uint64 memory_limit_total = 39; // out + + optional uint64 anon_limit_total = 40; + optional uint64 anon_usage = 41; // out, bytes + optional double cpu_guarantee_total = 42; + optional double cpu_guarantee_bound = 421; + optional double cpu_limit_total = 422; + optional double cpu_limit_bound = 423; + + optional uint64 cache_usage = 43; // out, bytes + + optional uint64 hugetlb_usage = 44; // out, bytes + optional uint64 hugetlb_limit = 45; + + optional uint64 minor_faults = 46; // out + optional uint64 major_faults = 47; // out + optional uint64 memory_reclaimed = 48; // out + optional TVmStat virtual_memory = 49; // out + + optional uint64 shmem_usage = 71; // out, bytes + optional uint64 mlock_usage = 72; // out, bytes + + optional uint64 oom_kills = 50; // out + optional uint64 oom_kills_total = 51; // out + optional bool oom_killed = 52; // out + + optional TUintMap net_bytes = 54; // out + optional TUintMap net_packets = 55; // out + optional TUintMap net_drops = 56; // out + optional TUintMap net_overlimits = 57; // out + optional TUintMap net_rx_bytes = 58; // out + optional TUintMap net_rx_packets = 59; // out + optional TUintMap net_rx_drops = 60; // out + optional TUintMap net_tx_bytes = 61; // out + optional TUintMap net_tx_packets = 62; // out + optional TUintMap net_tx_drops = 63; // out + + optional TContainerVolumeLinks volumes_linked = 64; // out + optional TContainerVolumes volumes_owned = 65; + + repeated TError error = 66; // out + repeated TError warning = 67; // out + repeated TError taint = 68; // out +} + +message TContainerSpec { + optional string name = 1; // required / in client namespace + optional bool weak = 2; + optional string private = 3; + optional TStringMap labels = 4; + + optional string command = 5; + optional TContainerCommandArgv command_argv = 76; + optional TContainerEnv env = 6; + optional TContainerEnv env_secret = 90; // in, out hides values + optional TContainerUlimits ulimit = 7; + optional string 
core_command = 8; + + optional bool isolate = 9; + optional string virt_mode = 10; + optional string enable_porto = 11; + optional string porto_namespace = 12; + optional string cgroupfs = 78; + optional bool userns = 79; + + optional uint64 aging_time = 13; + + optional TCred task_cred = 14; + optional string user = 15; + optional string group = 16; + + optional TCred owner_cred = 17; + optional string owner_user = 18; + optional string owner_group = 19; + optional string owner_containers = 77; + + optional TCapabilities capabilities = 20; + optional TCapabilities capabilities_ambient = 21; + + optional string root = 22; // in parent namespace + optional bool root_readonly = 23; + optional TContainerBindMounts bind = 24; + optional TStringMap symlink = 25; + optional TContainerDevices devices = 26; + optional TContainerPlaceConfig place = 27; + optional TUintMap place_limit = 28; + + optional string cwd = 29; + optional string stdin_path = 30; + optional string stdout_path = 31; + optional string stderr_path = 32; + optional uint64 stdout_limit = 33; + optional uint32 umask = 34; + + optional bool respawn = 35; + optional uint64 respawn_count = 36; + optional int64 max_respawns = 37; + optional uint64 respawn_delay = 38; + + optional TContainerControllers controllers = 39; + + optional string cpu_policy = 40; // normal|idle|batch|high|rt + optional double cpu_weight = 41; // 0.01 .. 100 + + optional double cpu_guarantee = 42; // in cores + optional double cpu_limit = 43; // in cores + optional double cpu_limit_total = 44; // deprecated (value moved to TContainerStatus) + optional uint64 cpu_period = 45; // ns + + optional TContainerCpuSet cpu_set = 46; + + optional uint64 thread_limit = 47; + + optional string io_policy = 48; // none|rt|high|normal|batch|idle + optional double io_weight = 49; // 0.01 .. 
100 + + optional TUintMap io_limit = 50; // bps + optional TUintMap io_guarantee = 84; // bps + optional TUintMap io_ops_limit = 51; // iops + optional TUintMap io_ops_guarantee = 85; // iops + + optional uint64 memory_guarantee = 52; // bytes + + optional uint64 memory_limit = 53; // bytes + + optional uint64 anon_limit = 54; + optional uint64 anon_max_usage = 55; + + optional uint64 dirty_limit = 56; + + optional uint64 hugetlb_limit = 57; + + optional bool recharge_on_pgfault = 58; + optional bool pressurize_on_death = 59; + optional bool anon_only = 60; + + optional int32 oom_score_adj = 61; // -1000 .. +1000 + optional bool oom_is_fatal = 62; + + optional TContainerNetConfig net = 63; + optional TContainerIpLimit ip_limit = 64; + optional TContainerIpConfig ip = 65; + optional TContainerIpConfig default_gw = 66; + optional string hostname = 67; + optional string resolv_conf = 68; + optional string etc_hosts = 69; + optional TStringMap sysctl = 70; + optional TUintMap net_guarantee = 71; // bytes per second + optional TUintMap net_limit = 72; // bytes per second + optional TUintMap net_rx_limit = 73; // bytes per second + + optional TContainerVolumes volumes_required = 75; +} + +message TContainer { + optional TContainerSpec spec = 1; //required + optional TContainerStatus status = 2; + optional TError error = 3; +} + + +// Volumes + +message TVolumeDescription { + required string path = 1; // path in client namespace + map<string, string> properties = 2; + repeated string containers = 3; // linked containers (legacy) + repeated TVolumeLink links = 4; // linked containers with details + + optional uint64 change_time = 5; // sec since epoch + optional bool no_changes = 6; // change_time < changed_since +} + + +message TVolumeLink { + optional string container = 1; + optional string target = 2; // absolute path in container, default: anon + optional bool required = 3; // container cannot work without it + optional bool read_only = 4; + optional string host_target 
= 5; // out, absolute path in host + optional bool container_root = 6; // in, set container root + optional bool container_cwd = 7; // in, set container cwd +} + +message TVolumeResource { + optional uint64 limit = 1; // bytes or inodes + optional uint64 guarantee = 2; // bytes or inodes + optional uint64 usage = 3; // out, bytes or inodes + optional uint64 available = 4; // out, bytes or inodes +} + +message TVolumeDirectory { + optional string path = 1; // relative path in volume + optional TCred cred = 2; // default: volume cred + optional fixed32 permissions = 3; // default: volume permissions +} + +message TVolumeSymlink { + optional string path = 1; // relative path in volume + optional string target_path = 2; +} + +message TVolumeShare { + optional string path = 1; // relative path in volume + optional string origin_path = 2; // absolute path to origin + optional bool cow = 3; // default: mutable share +} + +// Structured Volume description (Porto V5 API) + +message TVolumeSpec { + optional string path = 1; // path in container, default: auto + optional string container = 2; // defines root for paths, default: self (client container) + repeated TVolumeLink links = 3; // initial links, default: anon link to self + + optional string id = 4; // out + optional string state = 5; // out + + optional string private_value = 6; // at most 4096 bytes + + optional string device_name = 7; // out + + optional string backend = 10; // default: auto + optional string place = 11; // path in host or alias, default from client container + optional string storage = 12; // persistent storage, path or name, default: non-persistent + repeated string layers = 13; // name or path + optional bool read_only = 14; + + // defines root directory user, group and permissions + optional TCred cred = 20; // default: self task cred + optional fixed32 permissions = 21; // default: 0775 + + optional TVolumeResource space = 22; + optional TVolumeResource inodes = 23; + + optional TCred owner = 
30; // default: self owner + optional string owner_container = 31; // default: self + optional string place_key = 32; // out, key for place_limit + optional string creator = 33; // out + optional bool auto_path = 34; // out + optional uint32 device_index = 35; // out + optional uint64 build_time = 37; // out, sec since epoch + + // customization at creation + repeated TVolumeDirectory directories = 40; // in + repeated TVolumeSymlink symlinks = 41; // in + repeated TVolumeShare shares = 42; // in + + optional uint64 change_time = 50; // out, sec since epoch + optional bool no_changes = 51; // out, change_time < changed_since +} + + +message TLayer { + optional string name = 1; // name or meta/name + optional string owner_user = 2; + optional string owner_group = 3; + optional uint64 last_usage = 4; // out, sec since last usage + optional string private_value = 5; +} + + +message TStorage { + optional string name = 1; // name or meta/name + optional string owner_user = 2; + optional string owner_group = 3; + optional uint64 last_usage = 4; // out, sec since last usage + optional string private_value = 5; +} + + +message TMetaStorage { + optional string name = 1; + optional string place = 2; + optional string private_value = 3; + optional uint64 space_limit = 4; // bytes + optional uint64 inode_limit = 5; // inodes + + optional uint64 space_used = 6; // out, bytes + optional uint64 space_available = 7; // out, bytes + optional uint64 inode_used = 8; // out, inodes + optional uint64 inode_available = 9; // out, inodes + optional string owner_user = 10; // out + optional string owner_group = 11; // out + optional uint64 last_usage = 12; // out, sec since last usage +} + + +// COMMANDS + +// System + +// Get porto version +message TVersionRequest { +} + +message TVersionResponse { + optional string tag = 1; + optional string revision = 2; +} + + +// Get porto statistics +message TGetSystemRequest { +} + +message TGetSystemResponse { + optional string porto_version = 1; 
+ optional string porto_revision = 2; + optional string kernel_version = 3; + + optional fixed64 errors = 4; + optional fixed64 warnings = 5; + optional fixed64 porto_starts = 6; + optional fixed64 porto_uptime = 7; + optional fixed64 master_uptime = 8; + optional fixed64 taints = 9; + + optional bool frozen = 10; + optional bool verbose = 100; + optional bool debug = 101; + optional fixed64 log_lines = 102; + optional fixed64 log_bytes = 103; + + optional fixed64 stream_rotate_bytes = 104; + optional fixed64 stream_rotate_errors = 105; + + optional fixed64 log_lines_lost = 106; + optional fixed64 log_bytes_lost = 107; + optional fixed64 log_open = 108; + + optional fixed64 container_count = 200; + optional fixed64 container_limit = 201; + optional fixed64 container_running = 202; + optional fixed64 container_created = 203; + optional fixed64 container_started = 204; + optional fixed64 container_start_failed = 205; + optional fixed64 container_oom = 206; + optional fixed64 container_buried = 207; + optional fixed64 container_lost = 208; + optional fixed64 container_tainted = 209; + + optional fixed64 volume_count = 300; + optional fixed64 volume_limit = 301; + optional fixed64 volume_created = 303; + optional fixed64 volume_failed = 304; + optional fixed64 volume_links = 305; + optional fixed64 volume_links_mounted = 306; + optional fixed64 volume_lost = 307; + + optional fixed64 layer_import = 390; + optional fixed64 layer_export = 391; + optional fixed64 layer_remove = 392; + + optional fixed64 client_count = 400; + optional fixed64 client_max = 401; + optional fixed64 client_connected = 402; + + optional fixed64 request_queued = 500; + optional fixed64 request_completed = 501; + optional fixed64 request_failed = 502; + optional fixed64 request_threads = 503; + optional fixed64 request_longer_1s = 504; + optional fixed64 request_longer_3s = 505; + optional fixed64 request_longer_30s = 506; + optional fixed64 request_longer_5m = 507; + + optional fixed64 
fail_system = 600; + optional fixed64 fail_invalid_value = 601; + optional fixed64 fail_invalid_command = 602; + optional fixed64 fail_memory_guarantee = 603; + optional fixed64 fail_invalid_netaddr = 604; + + optional fixed64 porto_crash = 666; + + optional fixed64 network_count = 700; + optional fixed64 network_created = 701; + optional fixed64 network_problems = 702; + optional fixed64 network_repairs = 703; +} + + +// Change porto state +message TSetSystemRequest { + optional bool frozen = 10; + optional bool verbose = 100; + optional bool debug = 101; +} + +message TSetSystemResponse { +} + +message TCreateFromSpecRequest { + optional TContainerSpec container = 1; //required + repeated TVolumeSpec volumes = 2; + optional bool start = 3; +} + +message TUpdateFromSpecRequest { + optional TContainerSpec container = 1; //required + optional bool start = 2; +} + +message TListContainersFilter { + optional string name = 1; // name or wildcards, default: all + optional TStringMap labels = 2; + optional uint64 changed_since = 3; // change_time >= changed_since +} + +message TStreamDumpOptions { + optional uint64 stdstream_offset = 2; // default: 0 + optional uint64 stdstream_limit = 3; // default: 8Mb +} + +message TListContainersFieldOptions { + repeated string properties = 1; // property names, default: all + optional TStreamDumpOptions stdout_options = 2; // for GetIndexed stdout + optional TStreamDumpOptions stderr_options = 3; // for GetIndexed stderr +} + +message TListContainersRequest { + repeated TListContainersFilter filters = 1; + optional TListContainersFieldOptions field_options = 2; +} + +message TListContainersResponse { + repeated TContainer containers = 1; +} + +// List available properties +message TListPropertiesRequest { +} + +message TListPropertiesResponse { + message TContainerPropertyListEntry { + optional string name = 1; + optional string desc = 2; + optional bool read_only = 3; + optional bool dynamic = 4; + } + repeated 
TContainerPropertyListEntry list = 1; +} + + +// deprecated, use ListProperties +message TListDataPropertiesRequest { +} + +message TListDataPropertiesResponse { + message TContainerDataListEntry { + optional string name = 1; + optional string desc = 2; + } + repeated TContainerDataListEntry list = 1; +} + + +// Create stopped container +message TCreateRequest { + optional string name = 1; +} + + +// Stop and destroy container +message TDestroyRequest { + optional string name = 1; +} + + +// List container names +message TListRequest { + optional string mask = 1; + optional uint64 changed_since = 2; // change_time >= changed_since +} + +message TListResponse { + repeated string name = 1; + optional string absolute_namespace = 2; +} + + +// Read one property +message TGetPropertyRequest { + optional string name = 1; + optional string property = 2; + // update cached counters + optional bool sync = 3; + optional bool real = 4; +} + +message TGetPropertyResponse { + optional string value = 1; +} + + +// Alias for GetProperty, deprecated +message TGetDataPropertyRequest { + optional string name = 1; + optional string data = 2; + // update cached counters + optional bool sync = 3; + optional bool real = 4; +} + +message TGetDataPropertyResponse { + optional string value = 1; +} + + +// Change one property +message TSetPropertyRequest { + optional string name = 1; + optional string property = 2; + optional string value = 3; +} + + +// Get multiple properties/data of many containers with one request +message TGetRequest { + // list of containers or wildcards, "***" - all + repeated string name = 1; + + // list of properties/data + repeated string variable = 2; + + // do not wait busy containers + optional bool nonblock = 3; + + // update cached counters + optional bool sync = 4; + optional bool real = 5; + + // change_time >= changed_since + optional uint64 changed_since = 6; +} + +message TGetResponse { + message TContainerGetValueResponse { + optional string variable = 
1; + optional EError error = 2; + optional string errorMsg = 3; + optional string value = 4; + } + + message TContainerGetListResponse { + optional string name = 1; + repeated TContainerGetValueResponse keyval = 2; + + optional uint64 change_time = 3; + optional bool no_changes = 4; // change_time < changed_since + } + + repeated TContainerGetListResponse list = 1; +} + + +// Start stopped container +message TStartRequest { + optional string name = 1; +} + + +// Restart dead container +message TRespawnRequest { + optional string name = 1; +} + + +// Stop dead or running container +message TStopRequest { + optional string name = 1; + // Timeout in 1/1000 seconds between SIGTERM and SIGKILL, default 30s + optional uint32 timeout_ms = 2; +} + + +// Freeze running container +message TPauseRequest { + optional string name = 1; +} + + +// Unfreeze paused container +message TResumeRequest { + optional string name = 1; +} + + +// Translate filesystem path between containers +message TConvertPathRequest { + optional string path = 1; + optional string source = 2; + optional string destination = 3; +} + +message TConvertPathResponse { + optional string path = 1; +} + + +// Wait while container(s) is/are in running state +message TWaitRequest { + // list of containers or wildcards, "***" - all + repeated string name = 1; + + // timeout in 1/1000 seconds, 0 - nonblock + optional uint32 timeout_ms = 2; + + // list of label names or wildcards + repeated string label = 3; + + // async wait with target_state works only once + optional string target_state = 4; +} + +message TWaitResponse { + optional string name = 1; // container name + optional string state = 2; // container state or "timeout" + optional uint64 when = 3; // unix time stamp in seconds + optional string label = 4; + optional string value = 5; +} + + +// Send signal main process in container +message TKillRequest { + optional string name = 1; + optional int32 sig = 2; +} + + +// Move process into container +message 
TAttachProcessRequest { + optional string name = 1; + optional uint32 pid = 2; + optional string comm = 3; // ignored if empty +} + + +// Determine container by pid +message TLocateProcessRequest { + optional uint32 pid = 1; + optional string comm = 2; // ignored if empty +} + +message TLocateProcessResponse { + optional string name = 1; +} + + +// Labels + + +message TFindLabelRequest { + optional string mask = 1; // containers name or wildcard + optional string state = 2; // filter by container state + optional string label = 3; // label name or wildcard + optional string value = 4; // filter by label value +} + +message TFindLabelResponse { + message TFindLabelEntry { + optional string name = 1; + optional string state = 2; + optional string label = 3; + optional string value = 4; + } + repeated TFindLabelEntry list = 1; +} + + +message TSetLabelRequest { + optional string name = 1; + optional string label = 2; + optional string value = 3; + optional string prev_value = 4; // fail with Busy if does not match + optional string state = 5; // fail with InvalidState if not match +} + +message TSetLabelResponse { + optional string prev_value = 1; + optional string state = 2; +} + + +message TIncLabelRequest { + optional string name = 1; + optional string label = 2; // missing label starts from 0 + optional int64 add = 3 [ default = 1]; +} + +message TIncLabelResponse { + optional int64 result = 1; +} + + +message TSetSymlinkRequest { + optional string container = 1; + optional string symlink = 2; + optional string target = 3; +} + + +// Volumes + + +message TNewVolumeRequest { + optional TVolumeSpec volume = 1; +} + +message TNewVolumeResponse { + optional TVolumeSpec volume = 1; +} + + +message TGetVolumeRequest { + optional string container = 1; // get paths in container, default: self (client container) + repeated string path = 2; // volume path in container, default: all + optional uint64 changed_since = 3; // change_time >= changed_since + repeated string label 
= 4; // labels or wildcards +} + +message TGetVolumeResponse { + repeated TVolumeSpec volume = 1; +} + + +// List available volume properties +message TListVolumePropertiesRequest { +} + +message TListVolumePropertiesResponse { + message TVolumePropertyDescription { + optional string name = 1; + optional string desc = 2; + } + repeated TVolumePropertyDescription list = 1; +} + + +// Create new volume +// "createVolume" returns TVolumeDescription in "volume" +message TCreateVolumeRequest { + optional string path = 1; + map<string, string> properties = 2; +} + + +message TLinkVolumeRequest { + optional string path = 1; + optional string container = 2; // default - self (client container) + optional string target = 3; // path in container, "" - anon + optional bool required = 4; // stop container at fail + optional bool read_only = 5; +} + + +message TUnlinkVolumeRequest { + optional string path = 1; + optional string container = 2; // default - self, "***" - all + optional bool strict = 3; // non-lazy umount + optional string target = 4; // path in container, "" - anon, default - "***" - all +} + + +message TListVolumesRequest { + optional string path = 1; + optional string container = 2; + optional uint64 changed_since = 3; // change_time >= changed_since +} + +message TListVolumesResponse { + repeated TVolumeDescription volumes = 1; +} + + +message TTuneVolumeRequest { + optional string path = 1; + map<string, string> properties = 2; +} + +// Layers + + +message TListLayersRequest { + optional string place = 1; // default from client container + optional string mask = 2; +} + +message TListLayersResponse { + repeated string layer = 1; // layer names (legacy) + repeated TLayer layers = 2; // layer with description +} + + +message TImportLayerRequest { + optional string layer = 1; + optional string tarball = 2; + optional bool merge = 3; + optional string place = 4; + optional string private_value = 5; + optional string compress = 6; + optional bool verbose_error = 
7; +} + + +message TExportLayerRequest { + optional string volume = 1; + optional string tarball = 2; + optional string layer = 3; + optional string place = 4; + optional string compress = 5; +} + + +message TRemoveLayerRequest { + optional string layer = 1; + optional string place = 2; + optional bool async = 3; +} + + +message TGetLayerPrivateRequest { + optional string layer = 1; + optional string place = 2; +} + +message TGetLayerPrivateResponse { + optional string private_value = 1; +} + + +message TSetLayerPrivateRequest { + optional string layer = 1; + optional string place = 2; + optional string private_value = 3; +} + + +// Storages + + +message TListStoragesRequest { + optional string place = 1; + optional string mask = 2; // "name" - storage, "name/" - meta-storage +} + +message TListStoragesResponse { + repeated TStorage storages = 1; + repeated TMetaStorage meta_storages = 2; +} + + +message TRemoveStorageRequest { + optional string name = 1; + optional string place = 2; +} + + +message TImportStorageRequest { + optional string name = 1; + optional string tarball = 2; + optional string place = 3; + optional string private_value = 5; + optional string compress = 6; +} + + +message TExportStorageRequest { + optional string name = 1; + optional string tarball = 2; + optional string place = 3; + optional string compress = 4; +} + + +// Docker images API + + +message TDockerImageConfig { + repeated string cmd = 1; + repeated string env = 2; +} + +message TDockerImage { + required string id = 1; + repeated string tags = 2; + repeated string digests = 3; + repeated string layers = 4; + optional uint64 size = 5; + optional TDockerImageConfig config = 6; +} + + +message TDockerImageStatusRequest { + required string name = 1; + optional string place = 2; +} + +message TDockerImageStatusResponse { + optional TDockerImage image = 1; +} + + +message TDockerImageListRequest { + optional string place = 1; + optional string mask = 2; +} + +message 
TDockerImageListResponse { + repeated TDockerImage images = 1; +} + + +message TDockerImagePullRequest { + required string name = 1; + optional string place = 2; + optional string auth_token = 3; + optional string auth_path = 4; + optional string auth_service = 5; +} + +message TDockerImagePullResponse { + optional TDockerImage image = 1; +} + + +message TDockerImageRemoveRequest { + required string name = 1; + optional string place = 2; +} diff --git a/library/cpp/porto/proto/ya.make b/library/cpp/porto/proto/ya.make new file mode 100644 index 0000000000..525a807ee0 --- /dev/null +++ b/library/cpp/porto/proto/ya.make @@ -0,0 +1,5 @@ +PROTO_LIBRARY() +INCLUDE_TAGS(GO_PROTO) +SRCS(rpc.proto) +END() + diff --git a/library/cpp/porto/ya.make b/library/cpp/porto/ya.make new file mode 100644 index 0000000000..e1ccbac281 --- /dev/null +++ b/library/cpp/porto/ya.make @@ -0,0 +1,17 @@ +LIBRARY() + +BUILD_ONLY_IF(WARNING WARNING LINUX) + +PEERDIR( + library/cpp/porto/proto + contrib/libs/protobuf +) + +SRCS( + libporto.cpp + metrics.cpp +) + +END() + +RECURSE_FOR_TESTS(ut) diff --git a/library/cpp/yt/mlock/README.md b/library/cpp/yt/mlock/README.md new file mode 100644 index 0000000000..b61b6072c4 --- /dev/null +++ b/library/cpp/yt/mlock/README.md @@ -0,0 +1,11 @@ +# mlock + +MlockFileMappings подгружает и лочит в память все страницы исполняемого файла. + +В отличие от вызова mlockall, функция не лочит другие страницы процесса. +mlockall явно выделяет физическую память под все vma. Типичный процесс сначала +стартует и инициализирует аллокатор, а потом уже вызывает функцию для mlock страниц. +Аллокатор при старте выделяет большие диапазоны через mmap, но реально их не использует. +Поэтому mlockall приводит к повышенному потреблению памяти. + +Также, в отличие от mlockall, функция может подгрузить страницы в память сразу. 
diff --git a/library/cpp/yt/mlock/mlock.h b/library/cpp/yt/mlock/mlock.h
new file mode 100644
index 0000000000..035fc47e37
--- /dev/null
+++ b/library/cpp/yt/mlock/mlock.h
@@ -0,0 +1,16 @@
+#pragma once
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Locks all readable file-backed memory mappings of the current process
+// into memory. If |populate| is true, every locked mapping is additionally
+// read through, one access per 4096 bytes.
+// Returns false if /proc/self/maps cannot be opened or a mapping could not
+// be locked; the stub built for other platforms always returns false.
+bool MlockFileMappings(bool populate = true);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT
diff --git a/library/cpp/yt/mlock/mlock_linux.cpp b/library/cpp/yt/mlock/mlock_linux.cpp
new file mode 100644
index 0000000000..8791869f95
--- /dev/null
+++ b/library/cpp/yt/mlock/mlock_linux.cpp
@@ -0,0 +1,101 @@
+#include "mlock.h"
+
+#include <stdio.h>
+#include <sys/mman.h>
+#include <stdint.h>
+#include <inttypes.h>
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Reads one byte out of every 4096-byte step of [ptr, ptr + size) so that the
+// touched pages get faulted in. The volatile access cannot be optimized away.
+void PopulateFile(void* ptr, size_t size)
+{
+    constexpr size_t PageSize = 4096;
+
+    auto* begin = static_cast<volatile char*>(ptr);
+    for (auto* current = begin; current < begin + size; current += PageSize) {
+        *current;
+    }
+}
+
+// Locks every readable file-backed mapping listed in /proc/self/maps with
+// mlock(2); if |populate| is set, each locked mapping is also passed to
+// PopulateFile. Returns false if the maps file cannot be opened or if mlock
+// fails for at least one mapping (the remaining mappings are still processed).
+bool MlockFileMappings(bool populate)
+{
+    auto* file = ::fopen("/proc/self/maps", "r");
+    if (!file) {
+        return false;
+    }
+
+    // Each line of /proc/<pid>/maps has the following format:
+    // address           perms offset  dev   inode   path
+    // E.g.
+    // 08048000-08056000 r-xp 00000000 03:0c 64593   /usr/sbin/gpm
+
+    bool failed = false;
+    while (true) {
+        char line[1024];
+        if (!fgets(line, sizeof(line), file)) {
+            break;
+        }
+
+        char addressStr[64];
+        char permsStr[64];
+        char offsetStr[64];
+        char devStr[64];
+        // Inode numbers do not necessarily fit into int.
+        unsigned long long inode = 0;
+        // The %63s widths keep sscanf within the 64-byte buffers above.
+        if (sscanf(line, "%63s %63s %63s %63s %llu",
+            addressStr,
+            permsStr,
+            offsetStr,
+            devStr,
+            &inode) != 5)
+        {
+            continue;
+        }
+
+        // Inode 0 means that the mapping is not backed by a file.
+        if (inode == 0) {
+            continue;
+        }
+
+        // Only readable mappings are locked.
+        if (permsStr[0] != 'r') {
+            continue;
+        }
+
+        // SCNxPTR is the scanf conversion for uintptr_t; the PRI* macros are
+        // meant for printf only.
+        uintptr_t startAddress = 0;
+        uintptr_t endAddress = 0;
+        if (sscanf(addressStr, "%" SCNxPTR "-%" SCNxPTR,
+            &startAddress,
+            &endAddress) != 2)
+        {
+            continue;
+        }
+
+        if (::mlock(reinterpret_cast<const void*>(startAddress), endAddress - startAddress) != 0) {
+            failed = true;
+            continue;
+        }
+
+        if (populate) {
+            PopulateFile(reinterpret_cast<void*>(startAddress), endAddress - startAddress);
+        }
+    }
+
+    ::fclose(file);
+    return !failed;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT
diff --git a/library/cpp/yt/mlock/mlock_other.cpp b/library/cpp/yt/mlock/mlock_other.cpp
new file mode 100644
index 0000000000..269c5c3cb9
--- /dev/null
+++ b/library/cpp/yt/mlock/mlock_other.cpp
@@ -0,0 +1,14 @@
+#include "mlock.h"
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+bool MlockFileMappings(bool /* populate */)
+{
+    return false;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT
diff --git a/library/cpp/yt/mlock/ya.make b/library/cpp/yt/mlock/ya.make
new file mode 100644
index 0000000000..2603d128ed
--- /dev/null
+++ b/library/cpp/yt/mlock/ya.make
@@ -0,0 +1,16 @@
+LIBRARY()
+
+INCLUDE(${ARCADIA_ROOT}/library/cpp/yt/ya_cpp.make.inc)
+
+IF (OS_LINUX AND NOT SANITIZER_TYPE)
+    SRCS(mlock_linux.cpp)
+ELSE()
+    SRCS(mlock_other.cpp)
+ENDIF()
+
+END()
+
+IF (OS_LINUX AND NOT SANITIZER_TYPE)
+    RECURSE(unittests)
+ENDIF()
+
diff --git 
a/library/cpp/yt/stockpile/README.md b/library/cpp/yt/stockpile/README.md new file mode 100644 index 0000000000..6ee4cd1b1f --- /dev/null +++ b/library/cpp/yt/stockpile/README.md @@ -0,0 +1,12 @@ +# stockpile + +При приближении к лимиту памяти в memory cgroup, Linux запускает механизм direct reclaim, +чтобы освободить память. По опыту YT, direct reclaim очень сильно замедляет работу +всего процесса. + +Проблема возникает не только тогда, когда память занята анонимными страницами. 50% памяти контейнера +может быть занято не dirty страницами page cache, но проблема всё равно будет проявляться. Например, +если процесс активно читает файлы с диска без O_DIRECT, вся память очень быстро будет забита. + +Чтобы бороться с этой проблемой, в яндексовом ядре добавлена ручка `madvise(MADV_STOCKPILE)`. +Больше подробностей в https://st.yandex-team.ru/KERNEL-186
\ No newline at end of file diff --git a/library/cpp/yt/stockpile/stockpile.h b/library/cpp/yt/stockpile/stockpile.h new file mode 100644 index 0000000000..1df9591de4 --- /dev/null +++ b/library/cpp/yt/stockpile/stockpile.h @@ -0,0 +1,29 @@ +#pragma once + +#include <util/system/types.h> + +#include <util/generic/size_literals.h> + +#include <util/datetime/base.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +struct TStockpileOptions +{ + static constexpr i64 DefaultBufferSize = 4_GBs; + i64 BufferSize = DefaultBufferSize; + + static constexpr int DefaultThreadCount = 4; + int ThreadCount = DefaultThreadCount; + + static constexpr TDuration DefaultPeriod = TDuration::MilliSeconds(10); + TDuration Period = DefaultPeriod; +}; + +void ConfigureStockpile(const TStockpileOptions& options); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/stockpile/stockpile_linux.cpp b/library/cpp/yt/stockpile/stockpile_linux.cpp new file mode 100644 index 0000000000..3ee83d9334 --- /dev/null +++ b/library/cpp/yt/stockpile/stockpile_linux.cpp @@ -0,0 +1,42 @@ +#include "stockpile.h" + +#include <thread> +#include <mutex> + +#include <sys/mman.h> + +#include <util/system/thread.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +void RunStockpile(const TStockpileOptions& options) +{ + TThread::SetCurrentThreadName("Stockpile"); + + constexpr int MADV_STOCKPILE = 0x59410004; + + while (true) { + ::madvise(nullptr, options.BufferSize, MADV_STOCKPILE); + Sleep(options.Period); + } +} + +} // namespace + +void ConfigureStockpile(const TStockpileOptions& options) +{ + static std::once_flag OnceFlag; + std::call_once(OnceFlag, [options] { + for (int i = 0; i < options.ThreadCount; i++) { + std::thread(RunStockpile, options).detach(); + } + }); +} + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/stockpile/stockpile_other.cpp b/library/cpp/yt/stockpile/stockpile_other.cpp new file mode 100644 index 0000000000..3495d9c1cb --- /dev/null +++ b/library/cpp/yt/stockpile/stockpile_other.cpp @@ -0,0 +1,12 @@ +#include "stockpile.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +void ConfigureStockpile(const TStockpileOptions& /*options*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/stockpile/ya.make b/library/cpp/yt/stockpile/ya.make new file mode 100644 index 0000000000..39d51aaf97 --- /dev/null +++ b/library/cpp/yt/stockpile/ya.make @@ -0,0 +1,11 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/library/cpp/yt/ya_cpp.make.inc) + +IF (OS_LINUX AND NOT SANITIZER_TYPE) + SRCS(stockpile_linux.cpp) +ELSE() + SRCS(stockpile_other.cpp) +ENDIF() + +END() diff --git a/yt/yt/client/cell_master_client/public.h b/yt/yt/client/cell_master_client/public.h new file mode 100644 index 0000000000..df4920bf0d --- /dev/null +++ b/yt/yt/client/cell_master_client/public.h @@ -0,0 +1,40 @@ +#pragma once + +#include "public.h" + +#include <library/cpp/yt/misc/enum.h> + +namespace NYT::NCellMasterClient { + +/////////////////////////////////////////////////////////////////////////////// + +namespace NProto { + +class TCellDirectory; + +} // namespace NProto + +//////////////////////////////////////////////////////////////////////////////// + +// Keep these two enums consistent. 
+ +DEFINE_BIT_ENUM(EMasterCellRoles, + ((None) (0x0000)) + ((CypressNodeHost) (0x0001)) + ((TransactionCoordinator) (0x0002)) + ((ChunkHost) (0x0004)) + ((DedicatedChunkHost) (0x0008)) + ((ExTransactionCoordinator) (0x0010)) +); + +DEFINE_ENUM(EMasterCellRole, + ((CypressNodeHost) (0x0001)) + ((TransactionCoordinator) (0x0002)) + ((ChunkHost) (0x0004)) + ((DedicatedChunkHost) (0x0008)) + ((ExTransactionCoordinator) (0x0010)) +); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NCellMasterClient diff --git a/yt/yt/client/chaos_client/helpers.h b/yt/yt/client/chaos_client/helpers.h new file mode 100644 index 0000000000..b007fc8902 --- /dev/null +++ b/yt/yt/client/chaos_client/helpers.h @@ -0,0 +1,22 @@ +#pragma once + +#include "public.h" + +namespace NYT::NChaosClient { + +//////////////////////////////////////////////////////////////////////////////// + +TReplicationCardId MakeReplicationCardId(NObjectClient::TObjectId randomId); +TReplicaId MakeReplicaId(TReplicationCardId replicationCardId, TReplicaIdIndex index); +TReplicationCardId ReplicationCardIdFromReplicaId(TReplicaId replicaId); +TReplicationCardId ReplicationCardIdFromUpstreamReplicaIdOrNull(TReplicaId upstreamReplicaId); +TReplicationCardId MakeReplicationCardCollocationId(NObjectClient::TObjectId randomId); + +NObjectClient::TCellTag GetSiblingChaosCellTag(NObjectClient::TCellTag cellTag); + +bool IsOrderedTabletReplicationProgress(const TReplicationProgress& progress); +void ValidateOrderedTabletReplicationProgress(const TReplicationProgress& progress); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NChaosClient diff --git a/yt/yt/client/formats/arrow_writer.cpp b/yt/yt/client/formats/arrow_writer.cpp new file mode 100644 index 0000000000..352587348c --- /dev/null +++ b/yt/yt/client/formats/arrow_writer.cpp @@ -0,0 +1,1065 @@ +#include "arrow_writer.h" + +#include "public.h" +#include 
"schemaless_writer_adapter.h" + +#include <yt/yt/client/arrow/fbs/Message.fbs.h> +#include <yt/yt/client/arrow/fbs/Schema.fbs.h> + +#include <yt/yt/client/table_client/columnar.h> +#include <yt/yt/client/table_client/logical_type.h> +#include <yt/yt/client/table_client/name_table.h> +#include <yt/yt/client/table_client/public.h> +#include <yt/yt/client/table_client/row_batch.h> +#include <yt/yt/client/table_client/schema.h> + +#include <yt/yt/library/column_converters/column_converter.h> + +#include <yt/yt/core/concurrency/async_stream.h> +#include <yt/yt/core/concurrency/public.h> + +#include <yt/yt/core/misc/blob_output.h> +#include <yt/yt/core/misc/error.h> +#include <yt/yt/core/misc/range.h> + +#include <vector> + +namespace NYT::NFormats { + +using namespace NTableClient; +using namespace NComplexTypes; + +static const auto& Logger = FormatsLogger; + +using TBodyWriter = std::function<void(TMutableRef)>; +using TBatchColumn = IUnversionedColumnarRowBatch::TColumn; + +//////////////////////////////////////////////////////////////////////////////// + +struct TTypedBatchColumn +{ + const TBatchColumn* Column; + TLogicalTypePtr Type; +}; + +//////////////////////////////////////////////////////////////////////////////// + +constexpr i64 ArrowAlignment = 8; + +flatbuffers::Offset<flatbuffers::String> SerializeString( + flatbuffers::FlatBufferBuilder* flatbufBuilder, + const TString& str) +{ + return flatbufBuilder->CreateString(str.data(), str.length()); +} + +std::tuple<org::apache::arrow::flatbuf::Type, flatbuffers::Offset<void>> SerializeColumnType( + flatbuffers::FlatBufferBuilder* flatbufBuilder, + TColumnSchema schema) +{ + auto simpleType = CastToV1Type(schema.LogicalType()).first; + switch (simpleType) { + case ESimpleLogicalValueType::Null: + return std::make_tuple( + org::apache::arrow::flatbuf::Type_Null, + org::apache::arrow::flatbuf::CreateNull(*flatbufBuilder) + .Union()); + + case ESimpleLogicalValueType::Int64: + case 
ESimpleLogicalValueType::Uint64: + case ESimpleLogicalValueType::Int8: + case ESimpleLogicalValueType::Uint8: + case ESimpleLogicalValueType::Int16: + case ESimpleLogicalValueType::Uint16: + case ESimpleLogicalValueType::Int32: + case ESimpleLogicalValueType::Uint32: + return std::make_tuple( + org::apache::arrow::flatbuf::Type_Int, + org::apache::arrow::flatbuf::CreateInt( + *flatbufBuilder, + GetIntegralTypeBitWidth(simpleType), + IsIntegralTypeSigned(simpleType)) + .Union()); + + case ESimpleLogicalValueType::Double: + return std::make_tuple( + org::apache::arrow::flatbuf::Type_FloatingPoint, + org::apache::arrow::flatbuf::CreateFloatingPoint( + *flatbufBuilder, + org::apache::arrow::flatbuf::Precision_DOUBLE) + .Union()); + + case ESimpleLogicalValueType::Boolean: + return std::make_tuple( + org::apache::arrow::flatbuf::Type_Bool, + org::apache::arrow::flatbuf::CreateBool(*flatbufBuilder) + .Union()); + + case ESimpleLogicalValueType::String: + case ESimpleLogicalValueType::Any: + return std::make_tuple( + org::apache::arrow::flatbuf::Type_Binary, + org::apache::arrow::flatbuf::CreateBinary(*flatbufBuilder) + .Union()); + + case ESimpleLogicalValueType::Utf8: + return std::make_tuple( + org::apache::arrow::flatbuf::Type_Utf8, + org::apache::arrow::flatbuf::CreateUtf8(*flatbufBuilder) + .Union()); + + // TODO(babenko): the following types are not supported: + // Date + // Datetime + // Interval + // Timestamp + + default: + THROW_ERROR_EXCEPTION("Column %v has type %Qlv that is not currently supported by Arrow encoder", + schema.GetDiagnosticNameString(), + simpleType); + } +} + +bool IsRleButNotDictionaryEncodedStringLikeColumn(const TBatchColumn& column) +{ + auto simpleType = CastToV1Type(column.Type).first; + return IsStringLikeType(simpleType) && + column.Rle && + !column.Rle->ValueColumn->Dictionary; +} + +bool IsRleAndDictionaryEncodedColumn(const TBatchColumn& column) +{ + return column.Rle && + column.Rle->ValueColumn->Dictionary; +} + +bool 
IsDictionaryEncodedColumn(const TBatchColumn& column) +{ + return column.Dictionary || + IsRleAndDictionaryEncodedColumn(column) || + IsRleButNotDictionaryEncodedStringLikeColumn(column); +} + + +struct TRecordBatchBodyPart +{ + i64 Size; + TBodyWriter Writer; +}; + +struct TRecordBatchSerializationContext final +{ + explicit TRecordBatchSerializationContext(flatbuffers::FlatBufferBuilder* flatbufBuilder) + : FlatbufBuilder(flatbufBuilder) + {} + + void AddFieldNode(i64 length, i64 nullCount) + { + FieldNodes.emplace_back(length, nullCount); + } + + void AddBuffer(i64 size, TBodyWriter writer) + { + YT_LOG_DEBUG("Buffer registered (Offset: %v, Size: %v)", + CurrentBodyOffset, + size); + + Buffers.emplace_back(CurrentBodyOffset, size); + CurrentBodyOffset += AlignUp<i64>(size, ArrowAlignment); + Parts.push_back(TRecordBatchBodyPart{size, std::move(writer)}); + } + + flatbuffers::FlatBufferBuilder* const FlatbufBuilder; + + i64 CurrentBodyOffset = 0; + std::vector<org::apache::arrow::flatbuf::FieldNode> FieldNodes; + std::vector<org::apache::arrow::flatbuf::Buffer> Buffers; + std::vector<TRecordBatchBodyPart> Parts; +}; + +template <class T> +TMutableRange<T> GetTypedValues(TMutableRef ref) +{ + return MakeMutableRange( + reinterpret_cast<T*>(ref.Begin()), + reinterpret_cast<T*>(ref.End())); +} + +void SerializeColumnPrologue( + const TTypedBatchColumn& typedColumn, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + if (column->NullBitmap || + column->Rle && column->Rle->ValueColumn->NullBitmap) + { + if (column->Rle) { + const auto* valueColumn = column->Rle->ValueColumn; + auto rleIndexes = column->GetTypedValues<ui64>(); + + context->AddFieldNode( + column->ValueCount, + CountOnesInRleBitmap( + valueColumn->NullBitmap->Data, + rleIndexes, + column->StartIndex, + column->StartIndex + column->ValueCount)); + + context->AddBuffer( + GetBitmapByteSize(column->ValueCount), + [=] (TMutableRef dstRef) { + 
BuildValidityBitmapFromRleNullBitmap( + valueColumn->NullBitmap->Data, + rleIndexes, + column->StartIndex, + column->StartIndex + column->ValueCount, + dstRef); + }); + } else { + context->AddFieldNode( + column->ValueCount, + CountOnesInBitmap( + column->NullBitmap->Data, + column->StartIndex, + column->StartIndex + column->ValueCount)); + + context->AddBuffer( + GetBitmapByteSize(column->ValueCount), + [=] (TMutableRef dstRef) { + CopyBitmapRangeToBitmapNegated( + column->NullBitmap->Data, + column->StartIndex, + column->StartIndex + column->ValueCount, + dstRef); + }); + } + } else { + context->AddFieldNode( + column->ValueCount, + 0); + + context->AddBuffer( + 0, + [=] (TMutableRef /*dstRef*/) { + }); + } +} + +void SerializeRleButNotDictionaryEncodedStringLikeColumn( + const TTypedBatchColumn& typedColumn, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + YT_VERIFY(column->Values); + YT_VERIFY(column->Values->BitWidth == 64); + YT_VERIFY(column->Values->BaseValue == 0); + YT_VERIFY(!column->Values->ZigZagEncoded); + + YT_LOG_DEBUG("Adding RLE but not dictionary-encoded string-like column (ColumnId: %v, StartIndex: %v, ValueCount: %v)", + column->Id, + column->StartIndex, + column->ValueCount); + + SerializeColumnPrologue(typedColumn, context); + + auto rleIndexes = column->GetTypedValues<ui64>(); + + context->AddBuffer( + sizeof(ui32) * column->ValueCount, + [=] (TMutableRef dstRef) { + BuildIotaDictionaryIndexesFromRleIndexes( + rleIndexes, + column->StartIndex, + column->StartIndex + column->ValueCount, + GetTypedValues<ui32>(dstRef)); + }); +} + +void SerializeDictionaryColumn( + const TTypedBatchColumn& typedColumn, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + YT_VERIFY(column->Values); + YT_VERIFY(column->Dictionary->ZeroMeansNull); + YT_VERIFY(column->Values->BitWidth == 32); + YT_VERIFY(column->Values->BaseValue == 0); + YT_VERIFY(!column->Values->ZigZagEncoded); 
+ + YT_LOG_DEBUG("Adding dictionary column (ColumnId: %v, StartIndex: %v, ValueCount: %v, Rle: %v)", + column->Id, + column->StartIndex, + column->ValueCount, + column->Rle.has_value()); + + auto relevantDictionaryIndexes = column->GetRelevantTypedValues<ui32>(); + + context->AddFieldNode( + column->ValueCount, + CountNullsInDictionaryIndexesWithZeroNull(relevantDictionaryIndexes)); + + context->AddBuffer( + GetBitmapByteSize(column->ValueCount), + [=] (TMutableRef dstRef) { + BuildValidityBitmapFromDictionaryIndexesWithZeroNull( + relevantDictionaryIndexes, + dstRef); + }); + + context->AddBuffer( + sizeof(ui32) * column->ValueCount, + [=] (TMutableRef dstRef) { + BuildDictionaryIndexesFromDictionaryIndexesWithZeroNull( + relevantDictionaryIndexes, + GetTypedValues<ui32>(dstRef)); + }); +} + +void SerializeRleDictionaryColumn( + const TTypedBatchColumn& typedColumn, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + YT_VERIFY(column->Values); + YT_VERIFY(column->Values->BitWidth == 64); + YT_VERIFY(column->Values->BaseValue == 0); + YT_VERIFY(!column->Values->ZigZagEncoded); + YT_VERIFY(column->Rle->ValueColumn->Dictionary->ZeroMeansNull); + YT_VERIFY(column->Rle->ValueColumn->Values->BitWidth == 32); + YT_VERIFY(column->Rle->ValueColumn->Values->BaseValue == 0); + YT_VERIFY(!column->Rle->ValueColumn->Values->ZigZagEncoded); + + YT_LOG_DEBUG("Adding dictionary column (ColumnId: %v, StartIndex: %v, ValueCount: %v, Rle: %v)", + column->Id, + column->StartIndex, + column->ValueCount, + column->Rle.has_value()); + + auto dictionaryIndexes = column->Rle->ValueColumn->GetTypedValues<ui32>(); + auto rleIndexes = column->GetTypedValues<ui64>(); + + context->AddFieldNode( + column->ValueCount, + CountNullsInRleDictionaryIndexesWithZeroNull( + dictionaryIndexes, + rleIndexes, + column->StartIndex, + column->StartIndex + column->ValueCount)); + + context->AddBuffer( + GetBitmapByteSize(column->ValueCount), + [=] (TMutableRef dstRef) 
{ + BuildValidityBitmapFromRleDictionaryIndexesWithZeroNull( + dictionaryIndexes, + rleIndexes, + column->StartIndex, + column->StartIndex + column->ValueCount, + dstRef); + }); + + context->AddBuffer( + sizeof(ui32) * column->ValueCount, + [=] (TMutableRef dstRef) { + BuildDictionaryIndexesFromRleDictionaryIndexesWithZeroNull( + dictionaryIndexes, + rleIndexes, + column->StartIndex, + column->StartIndex + column->ValueCount, + GetTypedValues<ui32>(dstRef)); + }); +} + +void SerializeIntegerColumn( + const TTypedBatchColumn& typedColumn, + ESimpleLogicalValueType simpleType, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + YT_VERIFY(column->Values); + + YT_LOG_DEBUG("Adding integer column (ColumnId: %v, StartIndex: %v, ValueCount: %v, Rle: %v)", + column->Id, + column->StartIndex, + column->ValueCount, + column->Rle.has_value()); + + SerializeColumnPrologue(typedColumn, context); + + context->AddBuffer( + column->ValueCount * GetIntegralTypeByteSize(simpleType), + [=] (TMutableRef dstRef) { + const auto* valueColumn = column->Rle + ? column->Rle->ValueColumn + : column; + auto values = valueColumn->GetTypedValues<ui64>(); + + auto rleIndexes = column->Rle + ? 
column->GetTypedValues<ui64>() + : TRange<ui64>(); + + switch (simpleType) { +#define XX(cppType, ytType) \ + case ESimpleLogicalValueType::ytType: { \ + auto dstValues = GetTypedValues<cppType>(dstRef); \ + auto* currentOutput = dstValues.Begin(); \ + DecodeIntegerVector( \ + column->StartIndex, \ + column->StartIndex + column->ValueCount, \ + valueColumn->Values->BaseValue, \ + valueColumn->Values->ZigZagEncoded, \ + TRange<ui32>(), \ + rleIndexes, \ + [&] (auto index) { \ + return values[index]; \ + }, \ + [&] (auto value) { \ + *currentOutput++ = value; \ + }); \ + break; \ + } + + XX(i8, Int8) + XX(i16, Int16) + XX(i32, Int32) + XX(i64, Int64) + XX(ui8, Uint8) + XX(ui16, Uint16) + XX(ui32, Uint32) + XX(ui64, Uint64) + +#undef XX + + default: + THROW_ERROR_EXCEPTION("Integer column %v has unexpected type %Qlv", + typedColumn.Column->Id, + simpleType); + } + }); +} + +void SerializeDoubleColumn( + const TTypedBatchColumn& typedColumn, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + YT_VERIFY(column->Values); + YT_VERIFY(column->Values->BitWidth == 64); + YT_VERIFY(column->Values->BaseValue == 0); + YT_VERIFY(!column->Values->ZigZagEncoded); + + YT_LOG_DEBUG("Adding double column (ColumnId: %v, StartIndex: %v, ValueCount: %v)", + column->Id, + column->StartIndex, + column->ValueCount, + column->Rle.has_value()); + + SerializeColumnPrologue(typedColumn, context); + + context->AddBuffer( + column->ValueCount * sizeof(double), + [=] (TMutableRef dstRef) { + auto relevantValues = column->GetRelevantTypedValues<double>(); + ::memcpy( + dstRef.Begin(), + relevantValues.Begin(), + column->ValueCount * sizeof(double)); + }); +} + +void SerializeStringLikeColumn( + const TTypedBatchColumn& typedColumn, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + YT_VERIFY(column->Values); + YT_VERIFY(column->Values->BaseValue == 0); + YT_VERIFY(column->Values->BitWidth == 32); + 
YT_VERIFY(column->Values->ZigZagEncoded); + YT_VERIFY(column->Strings); + YT_VERIFY(column->Strings->AvgLength); + YT_VERIFY(!column->Rle); + + auto startIndex = column->StartIndex; + auto endIndex = startIndex + column->ValueCount; + auto stringData = column->Strings->Data; + auto avgLength = *column->Strings->AvgLength; + + auto offsets = column->GetTypedValues<ui32>(); + auto startOffset = DecodeStringOffset(offsets, avgLength, startIndex); + auto endOffset = DecodeStringOffset(offsets, avgLength, endIndex); + auto stringsSize = endOffset - startOffset; + + YT_LOG_DEBUG("Adding string-like column (ColumnId: %v, StartIndex: %v, ValueCount: %v, StartOffset: %v, EndOffset: %v, StringsSize: %v)", + column->Id, + column->StartIndex, + column->ValueCount, + startOffset, + endOffset, + stringsSize); + + SerializeColumnPrologue(typedColumn, context); + + context->AddBuffer( + sizeof(i32) * (column->ValueCount + 1), + [=] (TMutableRef dstRef) { + DecodeStringOffsets( + offsets, + avgLength, + startIndex, + endIndex, + GetTypedValues<ui32>(dstRef)); + }); + + context->AddBuffer( + stringsSize, + [=] (TMutableRef dstRef) { + ::memcpy( + dstRef.Begin(), + stringData.Begin() + startOffset, + stringsSize); + }); +} + +void SerializeBooleanColumn( + const TTypedBatchColumn& typedColumn, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + YT_VERIFY(column->Values); + YT_VERIFY(!column->Values->ZigZagEncoded); + YT_VERIFY(column->Values->BaseValue == 0); + YT_VERIFY(column->Values->BitWidth == 1); + + YT_LOG_DEBUG("Adding boolean column (ColumnId: %v, StartIndex: %v, ValueCount: %v)", + column->Id, + column->StartIndex, + column->ValueCount); + + SerializeColumnPrologue(typedColumn, context); + + context->AddBuffer( + GetBitmapByteSize(column->ValueCount), + [=] (TMutableRef dstRef) { + CopyBitmapRangeToBitmap( + column->Values->Data, + column->StartIndex, + column->StartIndex + column->ValueCount, + dstRef); + }); +} + +void 
SerializeColumn( + const TTypedBatchColumn& typedColumn, + TRecordBatchSerializationContext* context) +{ + const auto* column = typedColumn.Column; + + if (IsRleButNotDictionaryEncodedStringLikeColumn(*typedColumn.Column)) { + SerializeRleButNotDictionaryEncodedStringLikeColumn(typedColumn, context); + return; + } + + if (column->Dictionary) { + SerializeDictionaryColumn(typedColumn, context); + return; + } + + if (column->Rle && column->Rle->ValueColumn->Dictionary) { + SerializeRleDictionaryColumn(typedColumn, context); + return; + } + + auto simpleType = CastToV1Type(typedColumn.Type).first; + if (IsIntegralType(simpleType)) { + SerializeIntegerColumn(typedColumn, simpleType, context); + } else if (simpleType == ESimpleLogicalValueType::Double) { + SerializeDoubleColumn(typedColumn, context); + } else if (IsStringLikeType(simpleType)) { + SerializeStringLikeColumn(typedColumn, context); + } else if (simpleType == ESimpleLogicalValueType::Boolean) { + SerializeBooleanColumn(typedColumn, context); + } else if (simpleType == ESimpleLogicalValueType::Null) { + // No buffers are allocated for null columns. 
+ } else { + THROW_ERROR_EXCEPTION("Column %v has unexpected type %Qlv", + typedColumn.Column->Id, + simpleType); + } +} + +auto SerializeRecordBatch( + flatbuffers::FlatBufferBuilder* flatbufBuilder, + int length, + TRange<TTypedBatchColumn> typedColumns) +{ + auto context = New<TRecordBatchSerializationContext>(flatbufBuilder); + + for (const auto& typedColumn : typedColumns) { + SerializeColumn(typedColumn, context.Get()); + } + + auto fieldNodesOffset = flatbufBuilder->CreateVectorOfStructs(context->FieldNodes); + + auto buffersOffset = flatbufBuilder->CreateVectorOfStructs(context->Buffers); + + auto recordBatchOffset = org::apache::arrow::flatbuf::CreateRecordBatch( + *flatbufBuilder, + length, + fieldNodesOffset, + buffersOffset); + + auto totalSize = context->CurrentBodyOffset; + + return std::make_tuple( + recordBatchOffset, + totalSize, + [context = std::move(context)] (TMutableRef dstRef) { + char* current = dstRef.Begin(); + for (const auto& part : context->Parts) { + part.Writer(TMutableRef(current, current + part.Size)); + current += AlignUp<i64>(part.Size, ArrowAlignment); + } + YT_VERIFY(current == dstRef.End()); + }); +} +/////////////////////////////////////////////////////////////////////////////// + +class TArrowWriter + : public TSchemalessFormatWriterBase +{ +public: + TArrowWriter( + TNameTablePtr nameTable, + const std::vector<NTableClient::TTableSchemaPtr>& tableSchemas, + NConcurrency::IAsyncOutputStreamPtr output, + bool enableContextSaving, + TControlAttributesConfigPtr controlAttributesConfig, + int keyColumnCount) + : TSchemalessFormatWriterBase( + std::move(nameTable), + std::move(output), + enableContextSaving, + std::move(controlAttributesConfig), + keyColumnCount) + { + YT_VERIFY(tableSchemas.size() > 0); + + auto tableSchema = tableSchemas[0]; + auto columnCount = NameTable_->GetSize(); + + for (int columnIndex = 0; columnIndex < columnCount; columnIndex++) { + ColumnSchemas_.push_back(GetColumnSchema(tableSchema, columnIndex)); + 
} + } + +private: + void Reset() + { + Messages_.clear(); + TypedColumns_.clear(); + NumberOfRows_ = 0; + } + + void DoWrite(TRange<TUnversionedRow> rows) override + { + Reset(); + + auto convertedColumns = NColumnConverters::ConvertRowsToColumns(rows, ColumnSchemas_); + + std::vector<const TBatchColumn*> rootColumns; + rootColumns.reserve( std::ssize(convertedColumns)); + for (ssize_t columnIndex = 0; columnIndex < std::ssize(convertedColumns); columnIndex++) { + rootColumns.push_back(convertedColumns[columnIndex].RootColumn); + } + NumberOfRows_ = rows.size(); + PrepareColumns(rootColumns); + Encode(); + } + + void DoWriteBatch(NTableClient::IUnversionedRowBatchPtr rowBatch) override + { + auto columnarBatch = rowBatch->TryAsColumnar(); + if (!columnarBatch) { + YT_LOG_DEBUG("Encoding non-columnar batch; running write rows"); + DoWrite(rowBatch->MaterializeRows()); + } else { + YT_LOG_DEBUG("Encoding columnar batch"); + Reset(); + NumberOfRows_ = rowBatch->GetRowCount(); + PrepareColumns(columnarBatch->MaterializeColumns()); + Encode(); + } + } + + void Encode() + { + auto output = GetOutputStream(); + if (IsSchemaMessageNeeded()) { + if (!IsFirstBatch_) { + RegisterEosMarker(); + } + ResetArrowDictionaries(); + PrepareSchema(); + } + IsFirstBatch_ = false; + PrepareDictionaryBatches(); + PrepareRecordBatch(); + + WritePayload(output); + TryFlushBuffer(true); + } + +private: + bool IsFirstBatch_ = true; + size_t NumberOfRows_ = 0; + std::vector<TTypedBatchColumn> TypedColumns_; + std::vector<TColumnSchema> ColumnSchemas_; + std::vector<IUnversionedColumnarRowBatch::TDictionaryId> ArrowDictionaryIds_; + + struct TMessage + { + std::optional<flatbuffers::FlatBufferBuilder> FlatbufBuilder; + i64 BodySize; + TBodyWriter BodyWriter; + }; + + std::vector<TMessage> Messages_; + + bool CheckIfSystemColumnEnable(int columnIndex) + { + return ControlAttributesConfig_->EnableTableIndex && IsTableIndexColumnId(columnIndex) || + ControlAttributesConfig_->EnableRangeIndex && 
IsRangeIndexColumnId(columnIndex) || + ControlAttributesConfig_->EnableRowIndex && IsRowIndexColumnId(columnIndex) || + ControlAttributesConfig_->EnableTabletIndex && IsTabletIndexColumnId(columnIndex); + } + + bool CheckIfTypeIsNotNull(int columnIndex) + { + YT_VERIFY(columnIndex >= 0 && columnIndex < std::ssize(ColumnSchemas_)); + return CastToV1Type(ColumnSchemas_[columnIndex].LogicalType()).first != ESimpleLogicalValueType::Null; + } + + TColumnSchema GetColumnSchema(NTableClient::TTableSchemaPtr& tableSchema, int columnIndex) + { + YT_VERIFY(columnIndex >= 0); + auto name = NameTable_->GetName(columnIndex); + auto columnSchema = tableSchema->FindColumn(name); + if (!columnSchema) { + if (IsSystemColumnId(columnIndex) && CheckIfSystemColumnEnable(columnIndex)) { + return TColumnSchema(TString(name), EValueType::Int64); + } + return TColumnSchema(TString(name), EValueType::Null); + } + return *columnSchema; + } + + void PrepareColumns(const TRange<const TBatchColumn*>& batchColumns) + { + TypedColumns_.reserve(batchColumns.Size()); + for (const auto* column : batchColumns) { + if (CheckIfTypeIsNotNull(column->Id)) { + YT_VERIFY(column->Id >= 0 && column->Id < std::ssize(ColumnSchemas_)); + TypedColumns_.push_back(TTypedBatchColumn{ + column, + ColumnSchemas_[column->Id].LogicalType()}); + } + } + } + + bool IsSchemaMessageNeeded() + { + if (IsFirstBatch_) { + return true; + } + YT_VERIFY(ArrowDictionaryIds_.size() == TypedColumns_.size()); + bool result = false; + for (int index = 0; index < std::ssize(TypedColumns_); ++index) { + bool currentDictionary = IsDictionaryEncodedColumn(*TypedColumns_[index].Column); + bool previousDictionary = ArrowDictionaryIds_[index] != IUnversionedColumnarRowBatch::NullDictionaryId; + if (currentDictionary != previousDictionary) { + result = true; + } + } + return result; + } + + void ResetArrowDictionaries() + { + ArrowDictionaryIds_.assign(TypedColumns_.size(), IUnversionedColumnarRowBatch::NullDictionaryId); + } + + void 
RegisterEosMarker() + { + YT_LOG_DEBUG("EOS marker registered"); + + Messages_.push_back(TMessage{ + std::nullopt, + 0, + TBodyWriter()}); + } + + void RegisterMessage( + [[maybe_unused]] org::apache::arrow::flatbuf::MessageHeader type, + flatbuffers::FlatBufferBuilder&& flatbufBuilder, + i64 bodySize = 0, + std::function<void(TMutableRef)> bodyWriter = nullptr) + { + YT_LOG_DEBUG("Message registered (Type: %v, MessageSize: %v, BodySize: %v)", + org::apache::arrow::flatbuf::EnumNamesMessageHeader()[type], + flatbufBuilder.GetSize(), + bodySize); + + YT_VERIFY((bodySize % ArrowAlignment) == 0); + Messages_.push_back(TMessage{ + std::move(flatbufBuilder), + bodySize, + std::move(bodyWriter)}); + } + + void PrepareSchema() + { + flatbuffers::FlatBufferBuilder flatbufBuilder; + + int arrowDictionaryIdCounter = 0; + std::vector<flatbuffers::Offset<org::apache::arrow::flatbuf::Field>> fieldOffsets; + for (int columnIndex = 0; columnIndex < std::ssize(TypedColumns_); columnIndex++) { + const auto& typedColumn = TypedColumns_[columnIndex]; + YT_VERIFY(typedColumn.Column->Id >= 0 && typedColumn.Column->Id < std::ssize(ColumnSchemas_)); + auto columnSchema = ColumnSchemas_[typedColumn.Column->Id]; + auto nameOffset = SerializeString(&flatbufBuilder, columnSchema.Name()); + + auto [typeType, typeOffset] = SerializeColumnType(&flatbufBuilder, columnSchema); + + flatbuffers::Offset<org::apache::arrow::flatbuf::DictionaryEncoding> dictionaryEncodingOffset; + auto index_type_offset = org::apache::arrow::flatbuf::CreateInt(flatbufBuilder, 32, false); + + if (IsDictionaryEncodedColumn(*typedColumn.Column)) { + dictionaryEncodingOffset = org::apache::arrow::flatbuf::CreateDictionaryEncoding( + flatbufBuilder, + arrowDictionaryIdCounter++, + index_type_offset); + } + + auto fieldOffset = org::apache::arrow::flatbuf::CreateField( + flatbufBuilder, + nameOffset, + columnSchema.LogicalType()->IsNullable(), + typeType, + typeOffset, + dictionaryEncodingOffset); + + 
fieldOffsets.push_back(fieldOffset); + } + + auto fieldsOffset = flatbufBuilder.CreateVector(fieldOffsets); + + auto schemaOffset = org::apache::arrow::flatbuf::CreateSchema( + flatbufBuilder, + org::apache::arrow::flatbuf::Endianness_Little, + fieldsOffset); + + auto messageOffset = org::apache::arrow::flatbuf::CreateMessage( + flatbufBuilder, + org::apache::arrow::flatbuf::MetadataVersion_V4, + org::apache::arrow::flatbuf::MessageHeader_Schema, + schemaOffset.Union(), + 0); + + flatbufBuilder.Finish(messageOffset); + + RegisterMessage( + org::apache::arrow::flatbuf::MessageHeader_Schema, + std::move(flatbufBuilder)); + } + + void PrepareDictionaryBatches() + { + int arrowDictionaryIdCounter = 0; + auto prepareDictionaryBatch = [&] ( + int columnIndex, + IUnversionedColumnarRowBatch::TDictionaryId ytDictionaryId, + const TBatchColumn* dictionaryColumn) { + int arrowDictionaryId = arrowDictionaryIdCounter++; + const auto& typedColumn = TypedColumns_[columnIndex]; + auto previousYTDictionaryId = ArrowDictionaryIds_[columnIndex]; + if (ytDictionaryId == previousYTDictionaryId) { + YT_LOG_DEBUG("Reusing previous dictionary (ColumnId: %v, YTDictionaryId: %v, ArrowDictionaryId: %v)", + typedColumn.Column->Id, + ytDictionaryId, + arrowDictionaryId); + } else { + YT_LOG_DEBUG("Sending new dictionary (ColumnId: %v, YTDictionaryId: %v, ArrowDictionaryId: %v)", + typedColumn.Column->Id, + ytDictionaryId, + arrowDictionaryId); + PrepareDictionaryBatch( + TTypedBatchColumn{dictionaryColumn, typedColumn.Type}, + arrowDictionaryId); + ArrowDictionaryIds_[columnIndex] = ytDictionaryId; + } + }; + + for (int columnIndex = 0; columnIndex < std::ssize(TypedColumns_); ++columnIndex) { + const auto& typedColumn = TypedColumns_[columnIndex]; + if (typedColumn.Column->Dictionary) { + YT_LOG_DEBUG("Adding dictionary batch for dictionary-encoded column (ColumnId: %v)", + typedColumn.Column->Id); + prepareDictionaryBatch( + columnIndex, + typedColumn.Column->Dictionary->DictionaryId, + 
typedColumn.Column->Dictionary->ValueColumn); + } else if (IsRleButNotDictionaryEncodedStringLikeColumn(*typedColumn.Column)) { + YT_LOG_DEBUG("Adding dictionary batch for RLE but not dictionary-encoded string-like column (ColumnId: %v)", + typedColumn.Column->Id); + prepareDictionaryBatch( + columnIndex, + IUnversionedColumnarRowBatch::GenerateDictionaryId(), // any unique one will do + typedColumn.Column->Rle->ValueColumn); + } else if (IsRleAndDictionaryEncodedColumn(*typedColumn.Column)) { + YT_LOG_DEBUG("Adding dictionary batch for RLE and dictionary-encoded column (ColumnId: %v)", + typedColumn.Column->Id); + prepareDictionaryBatch( + columnIndex, + typedColumn.Column->Rle->ValueColumn->Dictionary->DictionaryId, + typedColumn.Column->Rle->ValueColumn->Dictionary->ValueColumn); + } + } + } + + void PrepareDictionaryBatch( + const TTypedBatchColumn& typedColumn, + int arrowDictionaryId) + { + flatbuffers::FlatBufferBuilder flatbufBuilder; + + auto [recordBatchOffset, bodySize, bodyWriter] = SerializeRecordBatch( + &flatbufBuilder, + typedColumn.Column->ValueCount, + MakeRange({typedColumn})); + + auto dictionaryBatchOffset = org::apache::arrow::flatbuf::CreateDictionaryBatch( + flatbufBuilder, + arrowDictionaryId, + recordBatchOffset); + + auto messageOffset = org::apache::arrow::flatbuf::CreateMessage( + flatbufBuilder, + org::apache::arrow::flatbuf::MetadataVersion_V4, + org::apache::arrow::flatbuf::MessageHeader_DictionaryBatch, + dictionaryBatchOffset.Union(), + bodySize); + + flatbufBuilder.Finish(messageOffset); + + RegisterMessage( + org::apache::arrow::flatbuf::MessageHeader_DictionaryBatch, + std::move(flatbufBuilder), + bodySize, + std::move(bodyWriter)); + } + + void PrepareRecordBatch() + { + flatbuffers::FlatBufferBuilder flatbufBuilder; + + auto [recordBatchOffset, bodySize, bodyWriter] = SerializeRecordBatch( + &flatbufBuilder, + NumberOfRows_, + TypedColumns_); + + auto messageOffset = org::apache::arrow::flatbuf::CreateMessage( + 
flatbufBuilder, + org::apache::arrow::flatbuf::MetadataVersion_V4, + org::apache::arrow::flatbuf::MessageHeader_RecordBatch, + recordBatchOffset.Union(), + bodySize); + + flatbufBuilder.Finish(messageOffset); + + RegisterMessage( + org::apache::arrow::flatbuf::MessageHeader_RecordBatch, + std::move(flatbufBuilder), + bodySize, + std::move(bodyWriter)); + } + + i64 GetPayloadSize() const + { + i64 size = 0; + for (const auto& message : Messages_) { + size += sizeof(ui32); // continuation indicator + size += sizeof(ui32); // metadata size + if (message.FlatbufBuilder) { + size += AlignUp<i64>(message.FlatbufBuilder->GetSize(), ArrowAlignment); // metadata message + size += AlignUp<i64>(message.BodySize, ArrowAlignment); // body + } + } + return size; + } + + void WritePayload(TBlobOutput* output) + { + YT_LOG_DEBUG("Started writing payload"); + for (const auto& message : Messages_) { + // Continuation indicator + ui32 constMax = 0xFFFFFFFF; + output->Write(&constMax, sizeof(ui32)); + + if (message.FlatbufBuilder) { + auto metadataSize = message.FlatbufBuilder->GetSize(); + + auto metadataPtr = message.FlatbufBuilder->GetBufferPointer(); + + + ui32 metadataSz = AlignUp<i64>(metadataSize, ArrowAlignment); + + output->Write(&metadataSz, sizeof(ui32)); + output->Write(metadataPtr, metadataSize); + + // Body + if (message.BodyWriter) { + TString current; + current.resize(message.BodySize); + // Double copying. 
+ message.BodyWriter(TMutableRef::FromString(current)); + output->Write(current.data(), message.BodySize); + } else { + YT_VERIFY(message.BodySize == 0); + } + } else { + // EOS marker + ui32 zero = 0; + output->Write(&zero, sizeof(ui32)); + } + } + + YT_LOG_DEBUG("Finished writing payload"); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +ISchemalessFormatWriterPtr CreateWriterForArrow( + NTableClient::TNameTablePtr nameTable, + const std::vector<NTableClient::TTableSchemaPtr>& schemas, + NConcurrency::IAsyncOutputStreamPtr output, + bool enableContextSaving, + TControlAttributesConfigPtr controlAttributesConfig, + int keyColumnCount) +{ + auto result = New<TArrowWriter>( + std::move(nameTable), + schemas, + std::move(output), + enableContextSaving, + std::move(controlAttributesConfig), + keyColumnCount); + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NFormats diff --git a/yt/yt/client/formats/arrow_writer.h b/yt/yt/client/formats/arrow_writer.h new file mode 100644 index 0000000000..16aacc4722 --- /dev/null +++ b/yt/yt/client/formats/arrow_writer.h @@ -0,0 +1,26 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/table_client/public.h> + +#include <yt/yt/core/concurrency/public.h> + +#include <yt/yt/core/ytree/public.h> + + +namespace NYT::NFormats { + +//////////////////////////////////////////////////////////////////////////////// + +ISchemalessFormatWriterPtr CreateWriterForArrow( + NTableClient::TNameTablePtr nameTable, + const std::vector<NTableClient::TTableSchemaPtr>& schemas, + NConcurrency::IAsyncOutputStreamPtr output, + bool enableContextSaving, + TControlAttributesConfigPtr controlAttributesConfig, + int keyColumnCount); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NFormats diff --git a/yt/yt/client/formats/public.h b/yt/yt/client/formats/public.h 
index 4087a81d44..af572cc98f 100644 --- a/yt/yt/client/formats/public.h +++ b/yt/yt/client/formats/public.h @@ -2,6 +2,8 @@ #include <yt/yt/core/misc/public.h> +#include <yt/yt/core/logging/log.h> + namespace NYT::NFormats { //////////////////////////////////////////////////////////////////////////////// @@ -88,4 +90,8 @@ class TFormat; //////////////////////////////////////////////////////////////////////////////// +inline const NLogging::TLogger FormatsLogger("Formats"); + +//////////////////////////////////////////////////////////////////////////////// + } // namespace NYT::NFormats diff --git a/yt/yt/client/formats/ya.make b/yt/yt/client/formats/ya.make index 18eb0e8384..a171707b02 100644 --- a/yt/yt/client/formats/ya.make +++ b/yt/yt/client/formats/ya.make @@ -3,6 +3,7 @@ LIBRARY() INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) SRCS( + arrow_writer.cpp config.cpp dsv_parser.cpp dsv_writer.cpp @@ -37,7 +38,9 @@ SRCS( PEERDIR( yt/yt/client + yt/yt/client/arrow/fbs yt/yt/library/skiff_ext + yt/yt/library/column_converters yt/yt_proto/yt/formats library/cpp/string_utils/base64 ) diff --git a/yt/yt/client/job_proxy/public.h b/yt/yt/client/job_proxy/public.h new file mode 100644 index 0000000000..dc12e5ed5e --- /dev/null +++ b/yt/yt/client/job_proxy/public.h @@ -0,0 +1,25 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +namespace NYT::NJobProxy { + +//////////////////////////////////////////////////////////////////////////////// + +YT_DEFINE_ERROR_ENUM( + ((MemoryLimitExceeded) (1200)) + ((MemoryCheckFailed) (1201)) + ((JobTimeLimitExceeded) (1202)) + ((UnsupportedJobType) (1203)) + ((JobNotPrepared) (1204)) + ((UserJobFailed) (1205)) + ((UserJobProducedCoreFiles) (1206)) + ((ShallowMergeFailed) (1207)) + ((JobNotRunning) (1208)) + ((InterruptionUnsupported) (1209)) + ((InterruptionTimeout) (1210)) +); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NJobProxy diff --git 
a/yt/yt/client/rpc/helpers-inl.h b/yt/yt/client/rpc/helpers-inl.h new file mode 100644 index 0000000000..1a1041bda6 --- /dev/null +++ b/yt/yt/client/rpc/helpers-inl.h @@ -0,0 +1,43 @@ +#ifndef HELPERS_INL_H_ +#error "Direct inclusion of this file is not allowed, include helpers.h" +// For the sake of sane code completion. +#include "helpers.h" +#endif + +#include <yt/yt_proto/yt/client/misc/proto/workload.pb.h> + +#include <yt/yt/core/rpc/service.h> + +namespace NYT::NRpc { + +//////////////////////////////////////////////////////////////////////////////// + +template <class TContextPtr> +TWorkloadDescriptor GetRequestWorkloadDescriptor(const TContextPtr& context) +{ + using NYT::FromProto; + const auto& header = context->GetRequestHeader(); + auto extensionId = NYT::NProto::TWorkloadDescriptorExt::workload_descriptor; + if (header.HasExtension(extensionId)) { + return FromProto<TWorkloadDescriptor>(header.GetExtension(extensionId)); + } + // COMPAT(babenko): drop descriptor from request body + return FromProto<TWorkloadDescriptor>(context->Request().workload_descriptor()); +} + +template <class TRequestPtr> +void SetRequestWorkloadDescriptor( + const TRequestPtr& request, + const TWorkloadDescriptor& workloadDescriptor) +{ + using NYT::ToProto; + auto extensionId = NYT::NProto::TWorkloadDescriptorExt::workload_descriptor; + auto* ext = request->Header().MutableExtension(extensionId); + ToProto(ext, workloadDescriptor); + // COMPAT(babenko): drop descriptor from request body + ToProto(request->mutable_workload_descriptor(), workloadDescriptor); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NRpc diff --git a/yt/yt/client/rpc/helpers.h b/yt/yt/client/rpc/helpers.h new file mode 100644 index 0000000000..53c0e7fce0 --- /dev/null +++ b/yt/yt/client/rpc/helpers.h @@ -0,0 +1,25 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/misc/workload.h> + +namespace NYT::NRpc { + 
+//////////////////////////////////////////////////////////////////////////////// + +template <class TContextPtr> +TWorkloadDescriptor GetRequestWorkloadDescriptor( + const TContextPtr& context); +template <class TRequestPtr> +void SetRequestWorkloadDescriptor( + const TRequestPtr& request, + const TWorkloadDescriptor& workloadDescriptor); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NRpc + +#define HELPERS_INL_H_ +#include "helpers-inl.h" +#undef HELPERS_INL_H_ diff --git a/yt/yt/client/rpc/public.h b/yt/yt/client/rpc/public.h new file mode 100644 index 0000000000..c237da11bf --- /dev/null +++ b/yt/yt/client/rpc/public.h @@ -0,0 +1,11 @@ +#pragma once + +#include <yt/yt/core/rpc/public.h> + +namespace NYT::NRpc { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NRpc diff --git a/yt/yt/client/table_chunk_format/public.h b/yt/yt/client/table_chunk_format/public.h new file mode 100644 index 0000000000..d4495252b5 --- /dev/null +++ b/yt/yt/client/table_chunk_format/public.h @@ -0,0 +1,15 @@ +#pragma once + +namespace NYT::NTableChunkFormat { + +namespace NProto { + +class TSegmentMeta; +class TTimestampSegmentMeta; +class TIntegerSegmentMeta; +class TStringSegmentMeta; +class TDenseVersionedSegmentMeta; + +} // namespace NProto + +} // namespace NYT::NTableChunkFormat diff --git a/yt/yt/client/table_client/record_codegen_deps.h b/yt/yt/client/table_client/record_codegen_deps.h new file mode 100644 index 0000000000..441dadc45f --- /dev/null +++ b/yt/yt/client/table_client/record_codegen_deps.h @@ -0,0 +1,4 @@ +#pragma once + +#include "record_codegen_h.h" +#include "record_codegen_cpp.h" diff --git a/yt/yt/client/table_client/record_codegen_h.h b/yt/yt/client/table_client/record_codegen_h.h new file mode 100644 index 0000000000..40c8d963c2 --- /dev/null +++ 
b/yt/yt/client/table_client/record_codegen_h.h @@ -0,0 +1,12 @@ +#pragma once + +#include <yt/yt/client/table_client/record_descriptor.h> + +#include <yt/yt/core/misc/singleton.h> + +#include <library/cpp/yt/memory/range.h> +#include <library/cpp/yt/memory/shared_range.h> + +#include <initializer_list> +#include <optional> +#include <vector> diff --git a/yt/yt/client/table_client/record_descriptor.h b/yt/yt/client/table_client/record_descriptor.h new file mode 100644 index 0000000000..374d5177f6 --- /dev/null +++ b/yt/yt/client/table_client/record_descriptor.h @@ -0,0 +1,19 @@ +#pragma once + +#include "public.h" + +namespace NYT::NTableClient { + +//////////////////////////////////////////////////////////////////////////////// + +struct IRecordDescriptor +{ + virtual ~IRecordDescriptor() = default; + + virtual const TTableSchemaPtr& GetSchema() const = 0; + virtual const TNameTablePtr& GetNameTable() const = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NTableClient diff --git a/yt/yt/client/tablet_client/dynamic_value.h b/yt/yt/client/tablet_client/dynamic_value.h new file mode 100644 index 0000000000..58fa5a6eec --- /dev/null +++ b/yt/yt/client/tablet_client/dynamic_value.h @@ -0,0 +1,51 @@ +#pragma once + +#include <yt/yt/client/table_client/row_base.h> +#include <yt/yt/client/table_client/unversioned_value.h> + +namespace NYT::NTabletClient { + +//////////////////////////////////////////////////////////////////////////////// + +// NB: 4-aligned. +struct TDynamicString +{ + ui32 Length; + char Data[1]; // the actual length is above +}; + +// NB: TDynamicValueData must be binary compatible with TUnversionedValueData for all simple types. +union TDynamicValueData +{ + //! |Int64| value. + i64 Int64; + //! |Uint64| value. + ui64 Uint64; + //! |Double| value. + double Double; + //! |Boolean| value. + bool Boolean; + //! String value for |String| type or YSON-encoded value for |Any| type. 
+ TDynamicString* String; +}; + +static_assert( + sizeof(TDynamicValueData) == sizeof(NTableClient::TUnversionedValueData), + "TDynamicValueData and TUnversionedValueData must be of the same size."); + +struct TDynamicValue +{ + TDynamicValueData Data; + ui32 Revision; + bool Null; + NTableClient::EValueFlags Flags; + char Padding[2]; +}; + +static_assert( + sizeof(TDynamicValue) == 16, + "Wrong TDynamicValue size."); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NTabletClient diff --git a/yt/yt/client/unittests/arrow_writer_ut.cpp b/yt/yt/client/unittests/arrow_writer_ut.cpp new file mode 100644 index 0000000000..668f30d04e --- /dev/null +++ b/yt/yt/client/unittests/arrow_writer_ut.cpp @@ -0,0 +1,942 @@ +#include <yt/yt/core/test_framework/framework.h> + +#include <yt/yt/client/formats/arrow_writer.h> +#include <yt/yt/client/formats/config.h> +#include <yt/yt/client/formats/format.h> + +#include <yt/yt/client/table_client/helpers.h> +#include <yt/yt/client/table_client/name_table.h> +#include <yt/yt/client/table_client/unversioned_row.h> +#include <yt/yt/client/table_client/validate_logical_type.h> + +#include <yt/yt/ytlib/chunk_client/chunk_reader.h> +#include <yt/yt/ytlib/chunk_client/chunk_reader_options.h> +#include <yt/yt/ytlib/chunk_client/chunk_reader_statistics.h> +#include <yt/yt/ytlib/chunk_client/memory_reader.h> +#include <yt/yt/ytlib/chunk_client/memory_writer.h> + +#include <yt/yt/ytlib/table_client/cached_versioned_chunk_meta.h> +#include <yt/yt/ytlib/table_client/chunk_state.h> +#include <yt/yt/ytlib/table_client/config.h> +#include <yt/yt/ytlib/table_client/schemaless_chunk_writer.h> +#include <yt/yt/ytlib/table_client/schemaless_multi_chunk_reader.h> + +#include <yt/yt/library/named_value/named_value.h> + +#include <util/stream/null.h> +#include <util/string/hex.h> + +#include <contrib/libs/apache/arrow/cpp/src/arrow/api.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/io/api.h> 
+#include <contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h> +#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/api.h> + +#include <stdlib.h> + +namespace NYT::NTableClient { + +namespace { + +using namespace NChunkClient; +using namespace NFormats; +using namespace NNamedValue; +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +IUnversionedRowBatchPtr MakeColumnarRowBatch( + TRange<NTableClient::TUnversionedRow> rows, + TTableSchemaPtr Schema_) +{ + + auto memoryWriter = New<TMemoryWriter>(); + + auto config = New<TChunkWriterConfig>(); + config->Postprocess(); + config->BlockSize = 256; + config->Postprocess(); + + auto options = New<TChunkWriterOptions>(); + options->OptimizeFor = EOptimizeFor::Scan; + options->Postprocess(); + + auto chunkWriter = CreateSchemalessChunkWriter( + config, + options, + Schema_, + /*nameTable*/ nullptr, + memoryWriter, + /*dataSink*/ std::nullopt); + + TUnversionedRowsBuilder builder; + + chunkWriter->Write(rows); + chunkWriter->Close().Get().IsOK(); + + auto MemoryReader_ = CreateMemoryReader( + memoryWriter->GetChunkMeta(), + memoryWriter->GetBlocks()); + + NChunkClient::NProto::TChunkSpec ChunkSpec_; + ToProto(ChunkSpec_.mutable_chunk_id(), NullChunkId); + ChunkSpec_.set_table_row_index(42); + + auto ChunkMeta_ = New<TColumnarChunkMeta>(*memoryWriter->GetChunkMeta()); + + auto ChunkState_ = New<TChunkState>(TChunkState{ + .BlockCache = GetNullBlockCache(), + .ChunkSpec = ChunkSpec_, + .TableSchema = Schema_, + }); + + auto schemalessRangeChunkReader = CreateSchemalessRangeChunkReader( + ChunkState_, + ChunkMeta_, + TChunkReaderConfig::GetDefault(), + TChunkReaderOptions::GetDefault(), + MemoryReader_, + TNameTable::FromSchema(*Schema_), + /* chunkReadOptions */ {}, + /* sortColumns */ {}, + /* omittedInaccessibleColumns */ {}, + TColumnFilter(), + TReadRange()); + + TRowBatchReadOptions opt{ + .MaxRowsPerRead = static_cast<i64>(rows.size()) + 10, + 
.Columnar = true}; + auto batch = ReadRowBatch(schemalessRangeChunkReader, opt); + return batch; +} + +//////////////////////////////////////////////////////////////////////////////// + +ISchemalessFormatWriterPtr CreateArrowWriter(TNameTablePtr nameTable, + IOutputStream* outputStream, + const std::vector<NTableClient::TTableSchemaPtr>& schemas) +{ + auto controlAttributes = NYT::New<TControlAttributesConfig>(); + controlAttributes->EnableTableIndex = false; + controlAttributes->EnableRowIndex = false; + controlAttributes->EnableRangeIndex = false; + controlAttributes->EnableTabletIndex = false; + return CreateWriterForArrow( + nameTable, + schemas, + NConcurrency::CreateAsyncAdapter(static_cast<IOutputStream*>(outputStream)), + false, + controlAttributes, + 0); +} + +ISchemalessFormatWriterPtr CreateArrowWriterWithSystemColumns(TNameTablePtr nameTable, + IOutputStream* outputStream, + const std::vector<NTableClient::TTableSchemaPtr>& schemas) +{ + auto controlAttributes = NYT::New<TControlAttributesConfig>(); + controlAttributes->EnableTableIndex = true; + controlAttributes->EnableRowIndex = true; + controlAttributes->EnableRangeIndex = true; + controlAttributes->EnableTabletIndex = true; + return CreateWriterForArrow( + nameTable, + schemas, + NConcurrency::CreateAsyncAdapter(static_cast<IOutputStream*>(outputStream)), + false, + controlAttributes, + 0); +} + +//////////////////////////////////////////////////////////////////////////////// + +std::shared_ptr<arrow::RecordBatch> MakeBatch(const TStringStream& outputStream) +{ + auto buffer = arrow::Buffer(reinterpret_cast<const uint8_t*>(outputStream.Data()), outputStream.Size()); + arrow::io::BufferReader bufferReader(buffer); + + std::shared_ptr<arrow::ipc::RecordBatchStreamReader> batchReader = (arrow::ipc::RecordBatchStreamReader::Open(&bufferReader)).ValueOrDie(); + + auto batch = batchReader->Next().ValueOrDie(); + return batch; +} + +std::vector<std::shared_ptr<arrow::RecordBatch>> MakeAllBatch(const 
TStringStream& outputStream, int batchNumb) +{ + auto buffer = arrow::Buffer(reinterpret_cast<const uint8_t*>(outputStream.Data()), outputStream.Size()); + arrow::io::BufferReader bufferReader(buffer); + + std::shared_ptr<arrow::ipc::RecordBatchStreamReader> batchReader = (arrow::ipc::RecordBatchStreamReader::Open(&bufferReader)).ValueOrDie(); + + std::vector<std::shared_ptr<arrow::RecordBatch>> batches; + for (int i = 0; i < batchNumb; i++) { + auto batch = batchReader->Next().ValueOrDie(); + if (batch == nullptr) { + batchReader = (arrow::ipc::RecordBatchStreamReader::Open(&bufferReader)).ValueOrDie(); + batchNumb++; + } else { + batches.push_back(batch); + } + } + return batches; +} + +//////////////////////////////////////////////////////////////////////////////// + +std::vector<int64_t> ReadInterger64Array(const std::shared_ptr<arrow::Array>& array) +{ + auto int64Array = std::dynamic_pointer_cast<arrow::Int64Array>(array); + YT_VERIFY(int64Array); + return {int64Array->raw_values(), int64Array->raw_values() + array->length()}; +} + +std::vector<uint32_t> ReadInterger32Array(const std::shared_ptr<arrow::Array>& array) +{ + auto int32Array = std::dynamic_pointer_cast<arrow::UInt32Array>(array); + YT_VERIFY(int32Array); + return {int32Array->raw_values(), int32Array->raw_values() + array->length()}; +} + +std::vector<std::string> ReadStringArray(const std::shared_ptr<arrow::Array>& array) +{ + auto arraySize = array->length(); + auto binArray = std::dynamic_pointer_cast<arrow::BinaryArray>(array); + YT_VERIFY(binArray); + std::vector<std::string> stringArray; + for (int i = 0; i < arraySize; i++) { + stringArray.push_back(binArray->GetString(i)); + } + return stringArray; +} + +std::vector<bool> ReadBoolArray(const std::shared_ptr<arrow::Array>& array) +{ + auto arraySize = array->length(); + auto boolArray = std::dynamic_pointer_cast<arrow::BooleanArray>(array); + YT_VERIFY(boolArray); + std::vector<bool> result; + for (int i = 0; i < arraySize; i++) { + 
result.push_back(boolArray->Value(i)); + } + return result; +} + +std::vector<double> ReadDoubleArray(const std::shared_ptr<arrow::Array>& array) +{ + auto doubleArray = std::dynamic_pointer_cast<arrow::DoubleArray>(array); + YT_VERIFY(doubleArray); + return {doubleArray->raw_values(), doubleArray->raw_values() + array->length()}; +} + +std::vector<std::string> ReadStringArrayFromDict(const std::shared_ptr<arrow::Array>& array) +{ + auto dictAr = std::dynamic_pointer_cast<arrow::DictionaryArray>(array); + YT_VERIFY(dictAr); + auto indices = ReadInterger32Array(dictAr->indices()); + + // Get values array. + auto values = ReadStringArray(dictAr->dictionary()); + + std::vector<std::string> result; + for (size_t i = 0; i < indices.size(); i++) { + auto index = indices[i]; + auto value = values[index]; + result.push_back(value); + } + return result; +} + +std::vector<std::string> ReadAnyStringArray(const std::shared_ptr<arrow::Array>& array) +{ + if (std::dynamic_pointer_cast<arrow::BinaryArray>(array)) { + return ReadStringArray(array); + } else if (std::dynamic_pointer_cast<arrow::DictionaryArray>(array)) { + return ReadStringArrayFromDict(array); + } + YT_ABORT(); +} + +bool IsDictColumn(const std::shared_ptr<arrow::Array>& array) +{ + return std::dynamic_pointer_cast<arrow::DictionaryArray>(array) != nullptr; +} + +//////////////////////////////////////////////////////////////////////////////// + +using ColumnInteger = std::vector<int64_t>; +using ColumnString = std::vector<std::string>; +using ColumnBool = std::vector<bool>; +using ColumnDouble = std::vector<double>; + +using ColumnStringWithNulls = std::vector<std::optional<std::string>>; +using ColumnBoolWithNulls = std::vector<std::optional<bool>>; +using ColumnDoubleWithNulls = std::vector<std::optional<double>>; + +struct TOwnerRows +{ + std::vector<TUnversionedRow> Rows; + std::vector<TUnversionedOwningRowBuilder> Builders; + TNameTablePtr NameTable; + std::vector<TUnversionedOwningRow> OwningRows; +}; + 
+//////////////////////////////////////////////////////////////////////////////// + +TOwnerRows MakeUnversionedIntegerRows( + const std::vector<ColumnInteger>& column, + const std::vector<std::string>& columnNames) +{ + YT_VERIFY(column.size() > 0); + + auto nameTable = New<TNameTable>(); + + std::vector<TUnversionedOwningRowBuilder> rowsBuilders(column[0].size()); + + for (int colIdx = 0; colIdx < std::ssize(column); colIdx++) { + auto columnId = nameTable->RegisterName(columnNames[colIdx]); + for (int rowIndex = 0; rowIndex < std::ssize(column[colIdx]); rowIndex++) { + rowsBuilders[rowIndex].AddValue(MakeUnversionedInt64Value(column[colIdx][rowIndex], columnId)); + } + } + std::vector<TUnversionedRow> rows; + std::vector<TUnversionedOwningRow> owningRows; + for (int rowIndex = 0; rowIndex < std::ssize(rowsBuilders); rowIndex++) { + owningRows.push_back(rowsBuilders[rowIndex].FinishRow()); + rows.push_back(owningRows.back().Get()); + } + return {std::move(rows), std::move(rowsBuilders), std::move(nameTable), std::move(owningRows)}; +} + +TOwnerRows MakeUnversionedStringRows( + const std::vector<ColumnString>& column, + const std::vector<std::string>& columnNames) +{ + YT_VERIFY(column.size() > 0); + std::vector<TString> strings; + + auto nameTable = New<TNameTable>(); + + std::vector<TUnversionedOwningRowBuilder> rowsBuilders(column[0].size()); + + for (int colIdx = 0; colIdx < std::ssize(column); colIdx++) { + auto columnId = nameTable->RegisterName(columnNames[colIdx]); + for (int rowIndex = 0; rowIndex < std::ssize(column[colIdx]); rowIndex++) { + strings.push_back(TString(column[colIdx][rowIndex])); + rowsBuilders[rowIndex].AddValue(MakeUnversionedStringValue(strings.back(), columnId)); + } + } + std::vector<TUnversionedRow> rows; + std::vector<TUnversionedOwningRow> owningRows; + for (int rowIndex = 0; rowIndex < std::ssize(rowsBuilders); rowIndex++) { + owningRows.push_back(rowsBuilders[rowIndex].FinishRow()); + rows.push_back(owningRows.back().Get()); + } + 
return {std::move(rows), std::move(rowsBuilders), std::move(nameTable), std::move(owningRows)}; +} + +std::string MakeRandomString(size_t stringSize) +{ + std::string randomString; + randomString.reserve(stringSize); + for (size_t i = 0; i < stringSize; i++) { + randomString += ('a' + rand() % 30); + } + return randomString; +} + +//////////////////////////////////////////////////////////////////////////////// + +void CheckColumnNames( + std::shared_ptr<arrow::RecordBatch> batch, + const std::vector<std::string>& columnNames) +{ + EXPECT_EQ(batch->num_columns(), std::ssize(columnNames)); + for (size_t i = 0; i < columnNames.size(); i++) { + EXPECT_EQ(batch->column_name(i), columnNames[i]); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(Simple, JustWork) +{ + EXPECT_TRUE(true); + std::vector<TTableSchemaPtr> tableSchemas; + std::vector<std::string> columnNames = {"integer"}; + + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema(TString(columnNames[0]), EValueType::Int64), + })); + + TStringStream outputStream; + + ColumnInteger column = {42, 179179}; + + auto rows = MakeUnversionedIntegerRows({column}, columnNames); + + auto writer = CreateArrowWriter(rows.NameTable, &outputStream, tableSchemas); + + + EXPECT_TRUE(writer->Write(rows.Rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + CheckColumnNames(batch, columnNames); + EXPECT_EQ(ReadInterger64Array(batch->column(0)), column); +} + +TEST(Simple, WorkWithSystemColumns) +{ + std::vector<TTableSchemaPtr> tableSchemas; + std::vector<std::string> columnNames = {"integer"}; + + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema(TString(columnNames[0]), EValueType::Int64), + })); + + TStringStream outputStream; + + ColumnInteger column = {42, 179179}; + + auto rows = MakeUnversionedIntegerRows({column}, columnNames); + + auto writer = 
CreateArrowWriterWithSystemColumns(rows.NameTable, &outputStream, tableSchemas); + + + EXPECT_TRUE(writer->Write(rows.Rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + CheckColumnNames(batch, {"integer", "$row_index", "$range_index", "$table_index", "$tablet_index"}); + EXPECT_EQ(ReadInterger64Array(batch->column(0)), column); +} + +TEST(Simple, ColumnarBatch) +{ + std::vector<TTableSchemaPtr> tableSchemas; + std::vector<std::string> columnNames = {"integer"}; + + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema(TString(columnNames[0]), EValueType::Int64), + })); + + TStringStream outputStream; + + ColumnInteger column = {42, 179179}; + + auto rows = MakeUnversionedIntegerRows({column}, columnNames); + + auto writer = CreateArrowWriter(rows.NameTable, &outputStream, tableSchemas); + + auto columnarBatch = MakeColumnarRowBatch(rows.Rows, tableSchemas[0]); + EXPECT_TRUE(writer->WriteBatch(columnarBatch)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + CheckColumnNames(batch, columnNames); + EXPECT_EQ(ReadInterger64Array(batch->column(0)), column); +} + +TEST(Simple, RowBatch) +{ + std::vector<TTableSchemaPtr> tableSchemas; + std::vector<std::string> columnNames = {"integer"}; + + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema(TString(columnNames[0]), EValueType::Int64), + })); + + TStringStream outputStream; + + ColumnInteger column = {42, 179179}; + + auto rows = MakeUnversionedIntegerRows({column}, columnNames); + + auto writer = CreateArrowWriter(rows.NameTable, &outputStream, tableSchemas); + + auto rowBatch = CreateBatchFromUnversionedRows(MakeSharedRange(std::move(rows.Rows))); + + EXPECT_TRUE(writer->WriteBatch(rowBatch)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + CheckColumnNames(batch, columnNames); + EXPECT_EQ(ReadInterger64Array(batch->column(0)), column); +} + 
+TEST(Simple, Null) +{ + std::vector<TTableSchemaPtr> tableSchemas; + std::vector<std::string> columnNames = {"integer"}; + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema(TString(columnNames[0]), EValueType::Int64), + TColumnSchema(TString("null"), EValueType::Null), + })); + + TStringStream outputStream; + auto nameTable = New<TNameTable>(); + auto columnId = nameTable->RegisterName(columnNames[0]); + auto nullColumnId = nameTable->RegisterName("null"); + + TUnversionedRowBuilder row1, row2; + row1.AddValue(MakeUnversionedNullValue(columnId)); + row1.AddValue(MakeUnversionedNullValue(nullColumnId)); + + row2.AddValue(MakeUnversionedInt64Value(3, columnId)); + row2.AddValue(MakeUnversionedNullValue(nullColumnId)); + + std::vector<TUnversionedRow> rows = {row1.GetRow(), row2.GetRow()}; + + + auto writer = CreateArrowWriter(nameTable, &outputStream, tableSchemas); + + + EXPECT_TRUE(writer->Write(rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + CheckColumnNames(batch, columnNames); + EXPECT_EQ(ReadInterger64Array(batch->column(0))[1], 3); +} + +TEST(Simple, String) +{ + std::vector<std::string> columnNames = {"string"}; + std::vector<TTableSchemaPtr> tableSchemas; + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema(TString(columnNames[0]), EValueType::String), + })); + + TStringStream outputStream; + + ColumnString column = {"cat", "mouse"}; + + auto rows = MakeUnversionedStringRows({column}, columnNames); + + auto writer = CreateArrowWriter(rows.NameTable, &outputStream, tableSchemas); + + + EXPECT_TRUE(writer->Write(rows.Rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + + CheckColumnNames(batch, columnNames); + EXPECT_EQ(ReadAnyStringArray(batch->column(0)), column); +} + +TEST(Simple, DictionaryString) +{ + std::vector<std::string> columnNames = {"string"}; + std::vector<TTableSchemaPtr> tableSchemas; + 
tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema(TString(columnNames[0]), EValueType::String), + })); + TStringStream outputStream; + + std::string longString, longString2; + for (int i = 0; i < 20; i++) { + longString += 'a'; + longString2 += 'b'; + } + + auto rows = MakeUnversionedStringRows({{longString, longString2, longString, longString2}}, columnNames); + + auto writer = CreateArrowWriter(rows.NameTable, &outputStream, tableSchemas); + + EXPECT_TRUE(writer->Write(rows.Rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + + CheckColumnNames(batch, columnNames); + EXPECT_EQ(ReadAnyStringArray(batch->column(0))[0], longString); + EXPECT_TRUE(IsDictColumn(batch->column(0))); +} + +TEST(Simple, DictionaryAndDirectStrings) +{ + std::vector<std::string> columnNames = {"string"}; + std::vector<TTableSchemaPtr> tableSchemas; + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema(TString(columnNames[0]), EValueType::String), + })); + + TStringStream outputStream; + + std::string longString, longString2; + for (int i = 0; i < 20; i++) { + longString += 'a'; + longString2 += 'b'; + } + ColumnString firstColumn = {longString, longString2, longString, longString2}; + ColumnString secondColumn = {"cat", "dog", "mouse", "table"}; + + auto dictRows = MakeUnversionedStringRows({firstColumn}, columnNames); + auto directRows = MakeUnversionedStringRows({secondColumn}, columnNames); + + auto writer = CreateArrowWriter(dictRows.NameTable, &outputStream, tableSchemas); + + // Write first batch, that will be decoded as dictionary. + EXPECT_TRUE(writer->Write(dictRows.Rows)); + + // Write second batch, that will be decoded as direct. 
+ EXPECT_TRUE(writer->Write(directRows.Rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + + auto batches = MakeAllBatch(outputStream, 2); + + CheckColumnNames(batches[0], columnNames); + CheckColumnNames(batches[1], columnNames); + + EXPECT_EQ(ReadAnyStringArray(batches[0]->column(0)), firstColumn); + EXPECT_EQ(ReadAnyStringArray(batches[1]->column(0)), secondColumn); +} + +TEST(StressOneBatch, Integer) +{ + // Constans. + const size_t columnsCount = 100; + const size_t rowsCount = 100; + + std::vector<TTableSchemaPtr> tableSchemas; + TStringStream outputStream; + + std::vector<std::string> columnNames; + std::vector<ColumnInteger> columnsElements(columnsCount); + + for (size_t columnIndex = 0; columnIndex < columnsCount; columnIndex++) { + // Create column name. + std::string ColumnName = "integer" + std::to_string(columnIndex); + columnNames.push_back(ColumnName); + + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + columnsElements[columnIndex].push_back(rand()); + } + } + + std::vector<TColumnSchema> schemas_; + for (size_t columnIdx = 0; columnIdx < columnsCount; columnIdx++) { + schemas_.push_back(TColumnSchema(TString(columnNames[columnIdx]), EValueType::Int64)); + } + tableSchemas.push_back(New<TTableSchema>(schemas_)); + + auto rows = MakeUnversionedIntegerRows(columnsElements, columnNames); + + auto writer = CreateArrowWriter(rows.NameTable, &outputStream, tableSchemas); + + EXPECT_TRUE(writer->Write(rows.Rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + + CheckColumnNames(batch, columnNames); + + for (size_t columnIndex = 0; columnIndex < columnsCount; columnIndex++) { + EXPECT_EQ(ReadInterger64Array(batch->column(columnIndex)), columnsElements[columnIndex]); + } +} + +TEST(StressOneBatch, String) +{ + const size_t columnsCount = 10; + const size_t rowsCount = 10; + const size_t stringSize = 10; + + std::vector<TTableSchemaPtr> tableSchemas; + + TStringStream outputStream; + + 
std::vector<std::string> columnNames; + std::vector<ColumnString> columnsElements(columnsCount); + + for (size_t columnIndex = 0; columnIndex < columnsCount; columnIndex++) { + + std::string ColumnName = "string" + std::to_string(columnIndex); + columnNames.push_back(ColumnName); + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + columnsElements[columnIndex].push_back(MakeRandomString(stringSize)); + } + } + + std::vector<TColumnSchema> schemas_; + for (size_t columnIdx = 0; columnIdx < columnsCount; columnIdx++) { + schemas_.push_back(TColumnSchema(TString(columnNames[columnIdx]), EValueType::String)); + } + tableSchemas.push_back(New<TTableSchema>(schemas_)); + + auto rows = MakeUnversionedStringRows(columnsElements, columnNames); + + auto writer = CreateArrowWriter(rows.NameTable, &outputStream, tableSchemas); + + EXPECT_TRUE(writer->Write(rows.Rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + auto batch = MakeBatch(outputStream); + + CheckColumnNames(batch, columnNames); + + for (size_t columnIndex = 0; columnIndex < columnsCount; columnIndex++) { + EXPECT_EQ(ReadAnyStringArray(batch->column(columnIndex)), columnsElements[columnIndex]); + } +} + +TEST(StressOneBatch, MixTypes) +{ + // Constants. + const size_t rowsCount = 10; + const size_t stringSize = 10; + + std::vector<TTableSchemaPtr> tableSchemas; + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema("bool", EValueType::Boolean), + TColumnSchema("double", EValueType::Double), + TColumnSchema("any", EValueType::Any)})); + + TStringStream outputStream; + + auto nameTable = New<TNameTable>(); + std::vector<TUnversionedOwningRowBuilder> rowsBuilders(rowsCount); + + std::vector<std::string> columnNames; + + std::vector<bool> boolColumn; + std::vector<double> doubleColumn; + std::vector<std::string> anyColumn; + std::vector<TUnversionedRow> rows; + + // Fill bool column. 
+ std::string ColumnName = "bool"; + auto boolId = nameTable->RegisterName(ColumnName); + columnNames.push_back(ColumnName); + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + boolColumn.push_back((rand() % 2) == 0); + + rowsBuilders[rowIndex].AddValue(MakeUnversionedBooleanValue(boolColumn[rowIndex], boolId)); + } + + // Fill double column. + ColumnName = "double"; + auto columnId = nameTable->RegisterName(ColumnName); + columnNames.push_back(ColumnName); + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + doubleColumn.push_back((double)(rand() % 100) / 10.0); + rowsBuilders[rowIndex].AddValue(MakeUnversionedDoubleValue(doubleColumn[rowIndex], columnId)); + } + + // Fill any column. + ColumnName = "any"; + auto anyId = nameTable->RegisterName(ColumnName); + columnNames.push_back(ColumnName); + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + std::string randomString = MakeRandomString(stringSize); + + anyColumn.push_back(randomString); + + rowsBuilders[rowIndex].AddValue(MakeUnversionedAnyValue(randomString, anyId)); + } + + std::vector<TUnversionedOwningRow> owningRows; + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + owningRows.push_back(rowsBuilders[rowIndex].FinishRow()); + rows.push_back(owningRows.back().Get()); + } + + auto writer = CreateArrowWriter(nameTable, &outputStream, tableSchemas); + + EXPECT_TRUE(writer->Write(rows)); + + writer->Close() + .Get() + .ThrowOnError(); + + + auto batch = MakeBatch(outputStream); + + CheckColumnNames(batch, columnNames); + + EXPECT_EQ(ReadBoolArray(batch->column(0)), boolColumn); + EXPECT_EQ(ReadDoubleArray(batch->column(1)), doubleColumn); + EXPECT_EQ(ReadAnyStringArray(batch->column(2)), anyColumn); +} + +TEST(StressMultiBatch, Integer) +{ + // Constants. 
+ const size_t columnsCount = 10; + const size_t rowsCount = 10; + const size_t numbOfBatch = 10; + + std::vector<std::string> columnNames; + std::vector<TTableSchemaPtr> tableSchemas; + std::vector<TColumnSchema> schemas_; + + for (size_t columnIdx = 0; columnIdx < columnsCount; columnIdx++) { + std::string ColumnName = "integer" + std::to_string(columnIdx); + columnNames.push_back(ColumnName); + schemas_.push_back(TColumnSchema(TString(columnNames[columnIdx]), EValueType::Int64)); + } + tableSchemas.push_back(New<TTableSchema>(schemas_)); + + TStringStream outputStream; + std::vector<std::vector<ColumnInteger>> columnsElements(numbOfBatch, std::vector<ColumnInteger>(columnsCount)); + + auto nameTable = New<TNameTable>(); + for (size_t columnIndex = 0; columnIndex < columnsCount; columnIndex++) { + std::string ColumnName = "integer" + std::to_string(columnIndex); + nameTable->RegisterName(ColumnName); + } + auto writer = CreateArrowWriter(nameTable, &outputStream, tableSchemas); + + + for (size_t batchIndex = 0; batchIndex < numbOfBatch; batchIndex++) { + + for (size_t columnIndex = 0; columnIndex < columnsCount; columnIndex++) { + std::string ColumnName = "integer" + std::to_string(columnIndex); + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + columnsElements[batchIndex][columnIndex].push_back(rand()); + } + } + + auto rows = MakeUnversionedIntegerRows(columnsElements[batchIndex], columnNames); + EXPECT_TRUE(writer->Write(rows.Rows)); + } + + writer->Close() + .Get() + .ThrowOnError(); + + + auto batches = MakeAllBatch(outputStream, numbOfBatch); + + size_t batchIndex = 0; + for (auto& batch : batches) { + for (size_t columnIndex = 0; columnIndex < columnsCount; columnIndex++) { + CheckColumnNames(batch, columnNames); + EXPECT_EQ(ReadInterger64Array(batch->column(columnIndex)), columnsElements[batchIndex][columnIndex]); + } + batchIndex++; + } +} + +TEST(StressMultiBatch, MixTypes) +{ + // Сonstants. 
+ const size_t rowsCount = 10; + const size_t numbOfBatch = 10; + const size_t stringSize = 10; + + std::vector<TTableSchemaPtr> tableSchemas; + tableSchemas.push_back(New<TTableSchema>(std::vector{ + TColumnSchema("bool", EValueType::Boolean), + TColumnSchema("double", EValueType::Double), + TColumnSchema("any", EValueType::Any)})); + + TStringStream outputStream; + + auto nameTable = New<TNameTable>(); + + std::vector<std::string> columnNames = {"bool", "double", "any"}; + auto boolId = nameTable->RegisterName(columnNames[0]); + auto doubleId = nameTable->RegisterName(columnNames[1]); + auto anyId = nameTable->RegisterName(columnNames[2]); + + std::vector<ColumnBoolWithNulls> boolColumns(numbOfBatch); + std::vector<ColumnDoubleWithNulls> doubleColumns(numbOfBatch); + std::vector<ColumnStringWithNulls> anyColumns(numbOfBatch); + + auto writer = CreateArrowWriter(nameTable, &outputStream, tableSchemas); + + std::vector<TUnversionedOwningRow> owningRows; + + for (size_t batchIndex = 0; batchIndex < numbOfBatch; batchIndex++) { + std::vector<TUnversionedOwningRowBuilder> rowsBuilders(rowsCount); + std::vector<TUnversionedRow> rows; + + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + if (rand() % 2 == 0) { + boolColumns[batchIndex].push_back(std::nullopt); + doubleColumns[batchIndex].push_back(std::nullopt); + anyColumns[batchIndex].push_back(std::nullopt); + rowsBuilders[rowIndex].AddValue(MakeUnversionedNullValue(boolId)); + rowsBuilders[rowIndex].AddValue(MakeUnversionedNullValue(doubleId)); + rowsBuilders[rowIndex].AddValue(MakeUnversionedNullValue(anyId)); + } else { + boolColumns[batchIndex].push_back((rand() % 2) == 0); + rowsBuilders[rowIndex].AddValue(MakeUnversionedBooleanValue(*boolColumns[batchIndex][rowIndex], boolId)); + + doubleColumns[batchIndex].push_back((double)(rand() % 100) / 10.0); + rowsBuilders[rowIndex].AddValue(MakeUnversionedDoubleValue(*doubleColumns[batchIndex][rowIndex], doubleId)); + + std::string randomString = 
MakeRandomString(stringSize); + anyColumns[batchIndex].push_back(randomString); + rowsBuilders[rowIndex].AddValue(MakeUnversionedAnyValue(randomString, anyId)); + } + owningRows.push_back(rowsBuilders[rowIndex].FinishRow()); + rows.push_back(owningRows.back().Get()); + } + + EXPECT_TRUE(writer->Write(rows)); + } + + writer->Close() + .Get() + .ThrowOnError(); + + auto batches = MakeAllBatch(outputStream, numbOfBatch); + size_t batchIndex = 0; + for (auto& batch : batches) { + CheckColumnNames(batch, columnNames); + + auto boolAr = ReadBoolArray(batch->column(0)); + auto doubleAr = ReadDoubleArray(batch->column(1)); + auto anyAr = ReadAnyStringArray(batch->column(2)); + + for (size_t rowIndex = 0; rowIndex < rowsCount; rowIndex++) { + if (boolColumns[batchIndex][rowIndex] == std::nullopt) { + EXPECT_TRUE(batch->column(0)->IsNull(rowIndex)); + EXPECT_TRUE(batch->column(1)->IsNull(rowIndex)); + EXPECT_TRUE(batch->column(2)->IsNull(rowIndex)); + } else { + EXPECT_EQ(boolAr[rowIndex], *boolColumns[batchIndex][rowIndex]); + EXPECT_EQ(doubleAr[rowIndex], *doubleColumns[batchIndex][rowIndex]); + EXPECT_EQ(anyAr[rowIndex], *anyColumns[batchIndex][rowIndex]); + } + } + + batchIndex++; + } +} + +} // namespace +} // namespace NYT::NTableClient diff --git a/yt/yt/client/unittests/ya.make b/yt/yt/client/unittests/ya.make index ab9f547e19..da4a035bd5 100644 --- a/yt/yt/client/unittests/ya.make +++ b/yt/yt/client/unittests/ya.make @@ -9,6 +9,7 @@ PROTO_NAMESPACE(yt) SRCS( protobuf_format_ut.proto + arrow_writer_ut.cpp check_schema_compatibility_ut.cpp check_type_compatibility_ut.cpp chunk_replica_ut.cpp @@ -66,8 +67,11 @@ PEERDIR( yt/yt/client/formats yt/yt/client/unittests/mock yt/yt/library/named_value + yt/yt/ytlib yt/yt_proto/yt/formats + + contrib/libs/apache/arrow ) RESOURCE( diff --git a/yt/yt/core/misc/tls_guard.h b/yt/yt/core/misc/tls_guard.h new file mode 100644 index 0000000000..6c17b3edd0 --- /dev/null +++ b/yt/yt/core/misc/tls_guard.h @@ -0,0 +1,48 @@ +#pragma once + 
+#include <yt/yt/core/concurrency/scheduler_api.h>
+
+#include <utility>
+
+namespace NYT {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Scope guard that installs #newValue into |*ptr| on construction and swaps
+//! the previous value back on destruction.
+/*!
+ *  Derives from NConcurrency::TForbidContextSwitchGuard; the guard itself is
+ *  neither copyable nor movable.
+ */
+template <class T>
+class TTlsGuard
+    : public NConcurrency::TForbidContextSwitchGuard
+{
+public:
+    TTlsGuard(const TTlsGuard&) = delete;
+    TTlsGuard(TTlsGuard&&) = delete;
+
+    //! #ptr must be non-null (verified). #newValue is taken by value and
+    //! moved into the guard, so no extra copy of T is made.
+    TTlsGuard(T* ptr, T newValue)
+        : TlsPtr_(ptr)
+        , SavedValue_(std::move(newValue))
+    {
+        YT_VERIFY(ptr);
+        // After the swap |*TlsPtr_| holds the new value and
+        // |SavedValue_| holds the previous one.
+        std::swap(SavedValue_, *TlsPtr_);
+    }
+
+    ~TTlsGuard()
+    {
+        // Swap the previous value back in.
+        std::swap(SavedValue_, *TlsPtr_);
+    }
+
+private:
+    T* const TlsPtr_;
+    T SavedValue_;
+};
+
+// Modifies thread local variable and saves current value.
+// On scope exit restores saved value.
+// Forbids context switch.
+
+template <class T>
+auto TlsGuard(T* ptr, T newValue)
+{
+    return TTlsGuard<T>(ptr, std::move(newValue));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT
+
diff --git a/yt/yt/library/auth_server/auth_cache-inl.h b/yt/yt/library/auth_server/auth_cache-inl.h
new file mode 100644
index 0000000000..1e0f229e3e
--- /dev/null
+++ b/yt/yt/library/auth_server/auth_cache-inl.h
@@ -0,0 +1,137 @@
+#ifndef AUTH_CACHE_INL_H_
+#error "Direct inclusion of this file is not allowed, include auth_cache.h"
+// For the sake of sane code completion.
+#include "auth_cache.h"
+#endif
+
+#include "config.h"
+
+#include <yt/yt/core/profiling/timing.h>
+
+namespace NYT::NAuth {
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Returns the cached future for #key, creating the entry (and issuing DoGet)
+// on a miss. On a hit the entry's context and last access time are updated;
+// if the entry is outdated and no refresh is running, a background DoGet is
+// started while the previously cached future is still returned to the caller.
+template <class TKey, class TValue, class TContext>
+TFuture<TValue> TAuthCache<TKey, TValue, TContext>::Get(const TKey& key, const TContext& context)
+{
+    TEntryPtr entry;
+    {
+        auto guard = ReaderGuard(SpinLock_);
+        auto it = Cache_.find(key);
+        if (it != Cache_.end()) {
+            entry = it->second;
+        }
+    }
+
+    auto now = NProfiling::GetCpuInstant();
+
+    if (entry) {
+        auto guard = Guard(entry->Lock);
+        auto future = entry->Future;
+
+        entry->Context = context;
+        entry->LastAccessTime = now;
+
+        if (entry->IsOutdated(Config_->CacheTtl, Config_->ErrorTtl) && !entry->Updating) {
+            entry->LastUpdateTime = now;
+            entry->Updating = true;
+
+            // Do not call DoGet under the entry lock.
+            guard.Release();
+
+            // The entry context was just set to #context above, so the
+            // parameter is passed directly instead of a copy of the field.
+            DoGet(entry->Key, context)
+                .Subscribe(BIND([entry] (const TErrorOr<TValue>& value) {
+                    // Any failure other than InvalidCredentials is treated as transient:
+                    // the previously cached future is kept.
+                    auto transientError = !value.IsOK() && !value.FindMatching(NRpc::EErrorCode::InvalidCredentials);
+
+                    auto guard = Guard(entry->Lock);
+                    entry->Updating = false;
+
+                    if (transientError) {
+                        return;
+                    }
+
+                    entry->Future = MakeFuture(value);
+                }));
+        }
+
+        return future;
+    }
+
+    entry = New<TEntry>(key, context);
+    entry->Promise = NewPromise<TValue>();
+    entry->Future = entry->Promise.ToFuture();
+    entry->LastUpdateTime = now;
+
+    bool inserted = false;
+
+    {
+        auto writerGuard = WriterGuard(SpinLock_);
+        auto it = Cache_.find(key);
+        if (it == Cache_.end()) {
+            inserted = true;
+            Cache_[key] = entry;
+        } else {
+            // Somebody else has inserted an entry in between; use theirs.
+            entry = it->second;
+        }
+    }
+
+    if (inserted) {
+        entry->EraseCookie = NConcurrency::TDelayedExecutor::Submit(
+            BIND(&TAuthCache::TryErase, MakeWeak(this), MakeWeak(entry)),
+            Config_->OptimisticCacheTtl);
+
+        // The entry is already visible to other threads, which may overwrite
+        // entry->Context under entry->Lock; use the caller's #context rather
+        // than reading the field without the lock.
+        entry->Promise.SetFrom(DoGet(entry->Key, context).ToUncancelable());
+    }
+
+    auto guard = Guard(entry->Lock);
+    return entry->Future;
+}
+
+// Delayed callback: removes the entry from the cache if it has not been
+// accessed for OptimisticCacheTtl, otherwise reschedules itself.
+template <class TKey, class TValue, class TContext>
+void TAuthCache<TKey, TValue, TContext>::TryErase(const TWeakPtr<TEntry>& weakEntry)
+{
+    auto entry = weakEntry.Lock();
+    if (!entry) {
+        return;
+    }
+
+    auto guard = Guard(entry->Lock);
+    if (entry->IsExpired(Config_->OptimisticCacheTtl)) {
+        auto writerGuard = WriterGuard(SpinLock_);
+        auto it = Cache_.find(entry->Key);
+        // Erase only if the cache still maps the key to this very entry.
+        if (it != Cache_.end() && it->second == entry) {
+            Cache_.erase(it);
+        }
+    } else {
+        entry->EraseCookie = NConcurrency::TDelayedExecutor::Submit(
+            BIND(&TAuthCache::TryErase, MakeWeak(this), MakeWeak(entry)),
+            Config_->OptimisticCacheTtl);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// An entry whose future is already set with an error is outdated after
+// #errorTtl since the last update; any other entry after #ttl.
+template <class TKey, class TValue, class TContext>
+bool TAuthCache<TKey, TValue, TContext>::TEntry::IsOutdated(TDuration ttl, TDuration errorTtl)
+{
+    auto now = NProfiling::GetCpuInstant();
+
+    auto value = Future.TryGet();
+    if (value && !value->IsOK()) {
+        return now > LastUpdateTime + NProfiling::DurationToCpuDuration(errorTtl);
+    } else {
+        return now > LastUpdateTime + NProfiling::DurationToCpuDuration(ttl);
+    }
+}
+
+// An entry is expired when it has not been accessed for longer than #ttl.
+template <class TKey, class TValue, class TContext>
+bool TAuthCache<TKey, TValue, TContext>::TEntry::IsExpired(TDuration ttl)
+{
+    auto now = NProfiling::GetCpuInstant();
+    return now > LastAccessTime + NProfiling::DurationToCpuDuration(ttl);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
diff --git a/yt/yt/library/auth_server/auth_cache.h b/yt/yt/library/auth_server/auth_cache.h
new file mode 100644
index 0000000000..24dfa12864
--- /dev/null
+++ b/yt/yt/library/auth_server/auth_cache.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include "public.h"
+
+#include <yt/yt/core/profiling/public.h>
+
+#include <yt/yt/library/profiling/sensor.h>
+
+namespace NYT::NAuth {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Base class for caches of asynchronous authentication results keyed by TKey.
+/*!
+ *  Derived classes implement #DoGet, which performs the actual (uncached) request.
+ *  TContext is extra per-request data forwarded to #DoGet; the most recent
+ *  context seen for a key is stored in its entry.
+ */
+template <class TKey, class TValue, class TContext>
+class TAuthCache
+    : public virtual TRefCounted
+{
+public:
+    TAuthCache(
+        TAuthCacheConfigPtr config,
+        NProfiling::TProfiler profiler = {})
+        : Config_(std::move(config))
+        , Profiler_(std::move(profiler))
+    { }
+
+    TFuture<TValue> Get(const TKey& key, const TContext& context);
+
+private:
+    struct TEntry
+        : public TRefCounted
+    {
+        const TKey Key;
+
+        YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, Lock);
+        TContext Context;
+        TFuture<TValue> Future;
+        TPromise<TValue> Promise;
+
+        NConcurrency::TDelayedExecutorCookie EraseCookie;
+        NProfiling::TCpuInstant LastAccessTime;
+
+        NProfiling::TCpuInstant LastUpdateTime;
+        // Set while a background refresh started by Get is in flight.
+        bool Updating = false;
+
+        bool IsOutdated(TDuration ttl, TDuration errorTtl);
+        bool IsExpired(TDuration ttl);
+
+        TEntry(const TKey& key, const TContext& context)
+            : Key(key)
+            , Context(context)
+            , LastAccessTime(GetCpuInstant())
+            , LastUpdateTime(GetCpuInstant())
+        { }
+    };
+    using TEntryPtr = TIntrusivePtr<TEntry>;
+
+    const TAuthCacheConfigPtr Config_;
+    const NProfiling::TProfiler Profiler_;
+
+    // Guards Cache_.
+    YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, SpinLock_);
+    THashMap<TKey, TEntryPtr> Cache_;
+
+    virtual TFuture<TValue> DoGet(const TKey& key, const TContext& context) noexcept = 0;
+    void TryErase(const TWeakPtr<TEntry>& weakEntry);
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
+
+#define AUTH_CACHE_INL_H_
+#include "auth_cache-inl.h"
+#undef AUTH_CACHE_INL_H_
diff --git a/yt/yt/library/auth_server/authentication_manager.cpp b/yt/yt/library/auth_server/authentication_manager.cpp
new file mode 100644
index 0000000000..eb346622bc
--- /dev/null
+++ b/yt/yt/library/auth_server/authentication_manager.cpp
@@ -0,0 +1,255 @@
+#include "authentication_manager.h"
+
+#include "blackbox_cookie_authenticator.h"
+#include "blackbox_service.h"
+#include "config.h"
+#include "cookie_authenticator.h"
+#include "cypress_cookie_manager.h"
+#include "cypress_token_authenticator.h"
+#include "cypress_user_manager.h"
+#include "oauth_cookie_authenticator.h"
+#include "oauth_token_authenticator.h"
+#include "oauth_service.h"
+#include "private.h"
+#include "ticket_authenticator.h"
+#include "token_authenticator.h"
+
+#include <yt/yt/core/rpc/authenticator.h>
+
+#include <yt/yt/library/tvm/service/tvm_service.h>
+
+namespace NYT::NAuth {
+
+using namespace NApi;
+using namespace NConcurrency;
+using namespace NHttp;
+using namespace NProfiling;
+using namespace NRpc;
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Assembles the RPC, token, cookie and ticket authenticators described by the
+// configuration and exposes them through IAuthenticationManager.
+class TAuthenticationManager
+    : public IAuthenticationManager
+{
+public:
+    // Each authenticator is created only when its config section is present
+    // and the services it depends on were created. The order of statements
+    // below matters: see the notes on the composite authenticators.
+    TAuthenticationManager(
+        TAuthenticationManagerConfigPtr config,
+        IPollerPtr poller,
+        NApi::IClientPtr client)
+    {
+        std::vector<NRpc::IAuthenticatorPtr> rpcAuthenticators;
+        std::vector<ITokenAuthenticatorPtr> tokenAuthenticators;
+        std::vector<ICookieAuthenticatorPtr> cookieAuthenticators;
+
+        // The TVM, Blackbox and OAuth services are skipped when no poller is given.
+        if (config->TvmService && poller) {
+            TvmService_ = CreateTvmService(
+                config->TvmService,
+                AuthProfiler.WithPrefix("/tvm/remote"));
+        }
+
+        IBlackboxServicePtr blackboxService;
+        if (config->BlackboxService && poller) {
+            blackboxService = CreateBlackboxService(
+                config->BlackboxService,
+                TvmService_,
+                poller,
+                AuthProfiler.WithPrefix("/blackbox"));
+        }
+
+        IOAuthServicePtr oauthService;
+        if (config->OAuthService && poller) {
+            oauthService = CreateOAuthService(
+                config->OAuthService,
+                poller,
+                AuthProfiler.WithPrefix("/oauth"));
+        }
+
+        // NOTE(review): unlike the Cypress token authenticator below, the two
+        // Cypress managers are created without checking that |client| is set - confirm.
+        if (config->CypressCookieManager) {
+            CypressCookieManager_ = CreateCypressCookieManager(
+                config->CypressCookieManager,
+                client,
+                AuthProfiler);
+            cookieAuthenticators.push_back(CypressCookieManager_->GetCookieAuthenticator());
+        }
+
+        if (config->CypressUserManager) {
+            CypressUserManager_ = CreateCachingCypressUserManager(
+                config->CypressUserManager,
+                CreateCypressUserManager(
+                    config->CypressUserManager,
+                    client),
+                AuthProfiler.WithPrefix("/cypress_user_manager/cache"));
+        }
+
+        if (config->BlackboxTokenAuthenticator && blackboxService) {
+            // COMPAT(gritukan): Set proper values in proxy configs and remove this code.
+            if (!TvmService_) {
+                config->BlackboxTokenAuthenticator->GetUserTicket = false;
+            }
+
+            tokenAuthenticators.push_back(
+                CreateCachingTokenAuthenticator(
+                    config->BlackboxTokenAuthenticator,
+                    CreateBlackboxTokenAuthenticator(
+                        config->BlackboxTokenAuthenticator,
+                        blackboxService,
+                        AuthProfiler.WithPrefix("/blackbox_token_authenticator/remote")),
+                    AuthProfiler.WithPrefix("/blackbox_token_authenticator/cache")));
+        }
+
+        // Both the legacy and the current Cypress token authenticators are
+        // registered, each behind its own cache but sharing one config section.
+        if (config->CypressTokenAuthenticator && client) {
+            tokenAuthenticators.push_back(
+                CreateCachingTokenAuthenticator(
+                    config->CypressTokenAuthenticator,
+                    CreateLegacyCypressTokenAuthenticator(
+                        config->CypressTokenAuthenticator,
+                        client),
+                    AuthProfiler.WithPrefix("/legacy_cypress_token_authenticator/cache")));
+
+            tokenAuthenticators.push_back(
+                CreateCachingTokenAuthenticator(
+                    config->CypressTokenAuthenticator,
+                    CreateCypressTokenAuthenticator(client),
+                    AuthProfiler.WithPrefix("/cypress_token_authenticator/cache")));
+        }
+
+        if (config->OAuthTokenAuthenticator && oauthService && CypressUserManager_) {
+            tokenAuthenticators.push_back(
+                CreateCachingTokenAuthenticator(
+                    config->OAuthTokenAuthenticator,
+                    CreateOAuthTokenAuthenticator(
+                        config->OAuthTokenAuthenticator,
+                        oauthService,
+                        CypressUserManager_),
+                    AuthProfiler.WithPrefix("/oauth_token_authenticator/cache")));
+        }
+
+        if (config->BlackboxCookieAuthenticator && blackboxService) {
+            // COMPAT(gritukan): Set proper values in proxy configs and remove this code.
+            if (!TvmService_) {
+                config->BlackboxCookieAuthenticator->GetUserTicket = false;
+            }
+
+            cookieAuthenticators.push_back(CreateCachingCookieAuthenticator(
+                config->BlackboxCookieAuthenticator,
+                CreateBlackboxCookieAuthenticator(
+                    config->BlackboxCookieAuthenticator,
+                    blackboxService),
+                AuthProfiler.WithPrefix("/blackbox_cookie_authenticator/cache")));
+        }
+
+        if (config->OAuthCookieAuthenticator && oauthService && CypressUserManager_) {
+            cookieAuthenticators.push_back(CreateCachingCookieAuthenticator(
+                config->OAuthCookieAuthenticator,
+                CreateOAuthCookieAuthenticator(
+                    config->OAuthCookieAuthenticator,
+                    oauthService,
+                    CypressUserManager_),
+                AuthProfiler.WithPrefix("/oauth_cookie_authenticator/cache")));
+        }
+
+        if (blackboxService && config->BlackboxTicketAuthenticator) {
+            TicketAuthenticator_ = CreateBlackboxTicketAuthenticator(
+                config->BlackboxTicketAuthenticator,
+                blackboxService,
+                TvmService_);
+            rpcAuthenticators.push_back(
+                CreateTicketAuthenticatorWrapper(TicketAuthenticator_));
+        }
+
+        // The token composite handed to the RPC authenticator is built before
+        // the no-op token authenticator is appended below, so it never contains it.
+        if (!tokenAuthenticators.empty()) {
+            rpcAuthenticators.push_back(CreateTokenAuthenticatorWrapper(
+                CreateCompositeTokenAuthenticator(tokenAuthenticators)));
+        }
+
+        if (!config->RequireAuthentication) {
+            tokenAuthenticators.push_back(CreateNoopTokenAuthenticator());
+        }
+        TokenAuthenticator_ = CreateCompositeTokenAuthenticator(tokenAuthenticators);
+
+        // The cookie composite is always created and registered for RPC,
+        // even when |cookieAuthenticators| is empty.
+        CookieAuthenticator_ = CreateCompositeCookieAuthenticator(
+            std::move(cookieAuthenticators));
+        rpcAuthenticators.push_back(CreateCookieAuthenticatorWrapper(CookieAuthenticator_));
+
+        // The RPC no-op authenticator is appended last.
+        if (!config->RequireAuthentication) {
+            rpcAuthenticators.push_back(NRpc::CreateNoopAuthenticator());
+        }
+        RpcAuthenticator_ = CreateCompositeAuthenticator(std::move(rpcAuthenticators));
+    }
+
+    // Starts the Cypress cookie manager, if one was configured.
+    void Start() override
+    {
+        if (CypressCookieManager_) {
+            CypressCookieManager_->Start();
+        }
+    }
+
+    // Stops the Cypress cookie manager, if one was configured.
+    void Stop() override
+    {
+        if (CypressCookieManager_) {
+            CypressCookieManager_->Stop();
+        }
+    }
+
+    // The getters below return references to the members set up by the
+    // constructor; TicketAuthenticator_, TvmService_ and the Cypress managers
+    // stay null when the corresponding component was not created.
+    const NRpc::IAuthenticatorPtr& GetRpcAuthenticator() const override
+    {
+        return RpcAuthenticator_;
+    }
+
+    const ITokenAuthenticatorPtr& GetTokenAuthenticator() const override
+    {
+        return TokenAuthenticator_;
+    }
+
+    const ICookieAuthenticatorPtr& GetCookieAuthenticator() const override
+    {
+        return CookieAuthenticator_;
+    }
+
+    const ITicketAuthenticatorPtr& GetTicketAuthenticator() const override
+    {
+        return TicketAuthenticator_;
+    }
+
+    const ITvmServicePtr& GetTvmService() const override
+    {
+        return TvmService_;
+    }
+
+    const ICypressCookieManagerPtr& GetCypressCookieManager() const override
+    {
+        return CypressCookieManager_;
+    }
+
+    const ICypressUserManagerPtr& GetCypressUserManager() const override
+    {
+        return CypressUserManager_;
+    }
+
+private:
+    ITvmServicePtr TvmService_;
+    NRpc::IAuthenticatorPtr RpcAuthenticator_;
+    ITokenAuthenticatorPtr TokenAuthenticator_;
+    ICookieAuthenticatorPtr CookieAuthenticator_;
+    ITicketAuthenticatorPtr TicketAuthenticator_;
+
+    ICypressCookieManagerPtr CypressCookieManager_;
+    ICypressUserManagerPtr CypressUserManager_;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Creates the default IAuthenticationManager implementation.
+IAuthenticationManagerPtr CreateAuthenticationManager(
+    TAuthenticationManagerConfigPtr config,
+    IPollerPtr poller,
+    NApi::IClientPtr client)
+{
+    return New<TAuthenticationManager>(
+        std::move(config),
+        std::move(poller),
+        std::move(client));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
diff --git a/yt/yt/library/auth_server/authentication_manager.h b/yt/yt/library/auth_server/authentication_manager.h
new file mode 100644
index 0000000000..bfb94e2237
--- /dev/null
+++ b/yt/yt/library/auth_server/authentication_manager.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include "public.h"
+
+#include <yt/yt/client/api/public.h>
+
+#include <yt/yt/core/rpc/public.h>
+
+#include <yt/yt/core/actions/public.h>
+
+#include <yt/yt/core/concurrency/public.h>
+
+namespace NYT::NAuth {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Owns the set of authenticators and auxiliary services built from an
+//! authentication manager config.
+struct IAuthenticationManager
+    : public TRefCounted
+{
+    virtual void Start() = 0;
+    virtual void Stop() = 0;
+
+    virtual const NRpc::IAuthenticatorPtr& GetRpcAuthenticator() const = 0;
+    virtual const ITokenAuthenticatorPtr& GetTokenAuthenticator() const = 0;
+    virtual const ICookieAuthenticatorPtr& GetCookieAuthenticator() const = 0;
+    virtual const ITicketAuthenticatorPtr& GetTicketAuthenticator() const = 0;
+
+    virtual const ITvmServicePtr& GetTvmService() const = 0;
+
+    virtual const ICypressCookieManagerPtr& GetCypressCookieManager() const = 0;
+    virtual const ICypressUserManagerPtr& GetCypressUserManager() const = 0;
+};
+
+DEFINE_REFCOUNTED_TYPE(IAuthenticationManager)
+
+////////////////////////////////////////////////////////////////////////////////
+
+IAuthenticationManagerPtr CreateAuthenticationManager(
+    TAuthenticationManagerConfigPtr config,
+    NConcurrency::IPollerPtr poller,
+    NApi::IClientPtr client);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
diff --git a/yt/yt/library/auth_server/batching_secret_vault_service.cpp b/yt/yt/library/auth_server/batching_secret_vault_service.cpp
new file mode 100644
index 0000000000..a5435eecfe
--- /dev/null
+++ b/yt/yt/library/auth_server/batching_secret_vault_service.cpp
@@ -0,0 +1,154 @@
+#include "batching_secret_vault_service.h"
+#include "secret_vault_service.h"
+#include "config.h"
+#include "private.h"
+
+#include <yt/yt/core/concurrency/periodic_executor.h>
+#include <yt/yt/core/concurrency/throughput_throttler.h>
+
+#include <yt/yt/core/rpc/dispatcher.h>
+
+#include <queue>
+#include <utility>
+
+namespace NYT::NAuth {
+
+using namespace NConcurrency;
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Queues secret subrequests and periodically forwards them to the underlying
+// service in batches of at most Config_->MaxSubrequestsPerRequest, subject to
+// a request throttler.
+class TBatchingSecretVaultService
+    : public ISecretVaultService
+{
+public:
+    TBatchingSecretVaultService(
+        TBatchingSecretVaultServiceConfigPtr config,
+        ISecretVaultServicePtr underlying,
+        NProfiling::TProfiler profiler)
+        : Config_(std::move(config))
+        , Underlying_(std::move(underlying))
+        , TickExecutor_(New<TPeriodicExecutor>(
+            NRpc::TDispatcher::Get()->GetHeavyInvoker(),
+            BIND(&TBatchingSecretVaultService::OnTick, MakeWeak(this)),
+            Config_->BatchDelay))
+        , RequestThrottler_(CreateReconfigurableThroughputThrottler(
+            Config_->RequestsThrottler,
+            NLogging::TLogger(),
+            profiler.WithPrefix("/request_throttler")))
+        , BatchingLatencyTimer_(profiler.Timer("/batching_latency"))
+    {
+        TickExecutor_->Start();
+    }
+
+    // Enqueues every subrequest and returns a future that is set once all of
+    // them have been answered (successfully or not).
+    TFuture<std::vector<TErrorOrSecretSubresponse>> GetSecrets(const std::vector<TSecretSubrequest>& subrequests) override
+    {
+        std::vector<TFuture<TSecretSubresponse>> asyncResults;
+        asyncResults.reserve(subrequests.size());
+        auto guard = Guard(SpinLock_);
+        for (const auto& subrequest : subrequests) {
+            asyncResults.push_back(DoGetSecret(subrequest, guard));
+        }
+        return AllSet(asyncResults);
+    }
+
+    // Delegation token requests are not batched.
+    TFuture<TString> GetDelegationToken(TDelegationTokenRequest request) override
+    {
+        return Underlying_->GetDelegationToken(std::move(request));
+    }
+
+private:
+    const TBatchingSecretVaultServiceConfigPtr Config_;
+    const ISecretVaultServicePtr Underlying_;
+
+    const TPeriodicExecutorPtr TickExecutor_;
+    const IThroughputThrottlerPtr RequestThrottler_;
+
+    // Guards SubrequestQueue_.
+    YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, SpinLock_);
+
+    struct TQueueItem
+    {
+        TSecretSubrequest Subrequest;
+        TPromise<TSecretSubresponse> Promise;
+        TInstant EnqueueTime;
+    };
+    std::queue<TQueueItem> SubrequestQueue_;
+
+    NProfiling::TEventTimer BatchingLatencyTimer_;
+
+    // Must be called with SpinLock_ held; the unused guard parameter makes
+    // the caller prove it.
+    TFuture<TSecretSubresponse> DoGetSecret(const TSecretSubrequest& subrequest, TGuard<NThreading::TSpinLock>& /*guard*/)
+    {
+        auto promise = NewPromise<TSecretSubresponse>();
+        SubrequestQueue_.push(TQueueItem{
+            subrequest,
+            promise,
+            TInstant::Now()
+        });
+        return promise.ToFuture();
+    }
+
+    // Drains the queue batch by batch until it is empty or the throttler
+    // refuses another request; leftovers wait for the next tick.
+    void OnTick()
+    {
+        while (true) {
+            {
+                auto guard = Guard(SpinLock_);
+                if (SubrequestQueue_.empty()) {
+                    break;
+                }
+            }
+
+            if (!RequestThrottler_->TryAcquire(1)) {
+                break;
+            }
+
+            std::vector<TQueueItem> items;
+            {
+                auto guard = Guard(SpinLock_);
+                while (!SubrequestQueue_.empty() && static_cast<int>(items.size()) < Config_->MaxSubrequestsPerRequest) {
+                    // The front element is popped right away, so move it out
+                    // instead of copying it while the lock is held.
+                    items.push_back(std::move(SubrequestQueue_.front()));
+                    SubrequestQueue_.pop();
+                }
+            }
+
+            if (items.empty()) {
+                break;
+            }
+
+            std::vector<TSecretSubrequest> subrequests;
+            subrequests.reserve(items.size());
+            auto now = TInstant::Now();
+            for (const auto& item : items) {
+                subrequests.push_back(item.Subrequest);
+                BatchingLatencyTimer_.Record(now - item.EnqueueTime);
+            }
+
+            // The callback only needs the batch items (for their promises).
+            Underlying_->GetSecrets(subrequests).Subscribe(
+                BIND([items = std::move(items)] (const TErrorOr<std::vector<TErrorOrSecretSubresponse>>& result) mutable {
+                    if (result.IsOK()) {
+                        // Subresponses are matched to subrequests by position.
+                        const auto& subresponses = result.Value();
+                        for (size_t index = 0; index < items.size(); ++index) {
+                            auto& item = items[index];
+                            item.Promise.Set(subresponses[index]);
+                        }
+                    } else {
+                        // The whole batch failed: propagate the error to every waiter.
+                        for (auto& item : items) {
+                            item.Promise.Set(TError(result));
+                        }
+                    }
+                }));
+        }
+    }
+};
+
+ISecretVaultServicePtr CreateBatchingSecretVaultService(
+    TBatchingSecretVaultServiceConfigPtr config,
+    ISecretVaultServicePtr underlying,
+    NProfiling::TProfiler profiler)
+{
+    return New<TBatchingSecretVaultService>(
+        std::move(config),
+        std::move(underlying),
+        std::move(profiler));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
diff --git a/yt/yt/library/auth_server/batching_secret_vault_service.h b/yt/yt/library/auth_server/batching_secret_vault_service.h
new file mode 100644
index 0000000000..e5c75da6b9
--- /dev/null
+++ b/yt/yt/library/auth_server/batching_secret_vault_service.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include "public.h"
+
+#include <yt/yt/library/profiling/sensor.h>
+
+namespace NYT::NAuth {
+
+//////////////////////////////////////////////////////////////////////////////// + +ISecretVaultServicePtr CreateBatchingSecretVaultService( + TBatchingSecretVaultServiceConfigPtr config, + ISecretVaultServicePtr underlying, + NProfiling::TProfiler profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/blackbox_cookie_authenticator.cpp b/yt/yt/library/auth_server/blackbox_cookie_authenticator.cpp new file mode 100644 index 0000000000..b7777db475 --- /dev/null +++ b/yt/yt/library/auth_server/blackbox_cookie_authenticator.cpp @@ -0,0 +1,169 @@ +#include "blackbox_cookie_authenticator.h" + +#include "blackbox_service.h" +#include "config.h" +#include "cookie_authenticator.h" +#include "helpers.h" +#include "private.h" + +#include <yt/yt/core/crypto/crypto.h> + +#include <util/string/split.h> + +namespace NYT::NAuth { + +using namespace NYTree; +using namespace NYPath; +using namespace NCrypto; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +// TODO(sandello): Indicate to end-user that cookie must be re-signed. 
+class TBlackboxCookieAuthenticator + : public ICookieAuthenticator +{ +public: + TBlackboxCookieAuthenticator( + TBlackboxCookieAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService) + : Config_(std::move(config)) + , BlackboxService_(std::move(blackboxService)) + { } + + const std::vector<TStringBuf>& GetCookieNames() const override + { + static const std::vector<TStringBuf> cookieNames{ + BlackboxSessionIdCookieName, + BlackboxSslSessionIdCookieName, + }; + return cookieNames; + } + + bool CanAuthenticate(const TCookieCredentials& credentials) const override + { + return credentials.Cookies.contains(BlackboxSessionIdCookieName); + } + + TFuture<TAuthenticationResult> Authenticate( + const TCookieCredentials& credentials) override + { + const auto& cookies = credentials.Cookies; + auto sessionId = GetOrCrash(cookies, BlackboxSessionIdCookieName); + + std::optional<TString> sslSessionId; + auto cookieIt = cookies.find(BlackboxSslSessionIdCookieName); + if (cookieIt != cookies.end()) { + sslSessionId = cookieIt->second; + } + + auto sessionIdMD5 = GetMD5HexDigestUpperCase(sessionId); + auto sslSessionIdMD5 = GetMD5HexDigestUpperCase(sslSessionId.value_or("")); + auto userIP = FormatUserIP(credentials.UserIP); + + YT_LOG_DEBUG( + "Authenticating user via session cookie (SessionIdMD5: %v, SslSessionIdMD5: %v, UserIP: %v)", + sessionIdMD5, + sslSessionIdMD5, + userIP); + + THashMap<TString, TString> params{ + {"sessionid", sessionId}, + {"host", Config_->Domain}, + {"userip", userIP}, + }; + + if (Config_->GetUserTicket) { + params["get_user_ticket"] = "yes"; + } + + if (sslSessionId) { + params["sslsessionid"] = *sslSessionId; + } + + return BlackboxService_->Call("sessionid", params) + .Apply(BIND( + &TBlackboxCookieAuthenticator::OnCallResult, + MakeStrong(this), + std::move(sessionIdMD5), + std::move(sslSessionIdMD5))); + } + +private: + const TBlackboxCookieAuthenticatorConfigPtr Config_; + const IBlackboxServicePtr BlackboxService_; + +private: + 
TFuture<TAuthenticationResult> OnCallResult( + const TString& sessionIdMD5, + const TString& sslSessionIdMD5, + const INodePtr& data) + { + auto result = OnCallResultImpl(data); + if (!result.IsOK()) { + YT_LOG_DEBUG(result, "Authentication failed (SessionIdMD5: %v, SslSessionIdMD5: %v)", sessionIdMD5, sslSessionIdMD5); + result.MutableAttributes()->Set("sessionid_md5", sessionIdMD5); + result.MutableAttributes()->Set("sslsessionid_md5", sslSessionIdMD5); + } else { + YT_LOG_DEBUG( + "Authentication successful (SessionIdMD5: %v, SslSessionIdMD5: %v, Login: %v, Realm: %v)", + sessionIdMD5, + sslSessionIdMD5, + result.Value().Login, + result.Value().Realm); + } + return MakeFuture(result); + } + + TErrorOr<TAuthenticationResult> OnCallResultImpl(const INodePtr& data) + { + auto statusId = GetByYPath<i64>(data, "/status/id"); + if (!statusId.IsOK()) { + return TError("Blackbox returned invalid response"); + } + + auto status = static_cast<EBlackboxStatus>(statusId.Value()); + if (status != EBlackboxStatus::Valid && status != EBlackboxStatus::NeedReset) { + auto error = GetByYPath<TString>(data, "/error"); + auto reason = error.IsOK() ? error.Value() : "unknown"; + return TError(NRpc::EErrorCode::InvalidCredentials, "Blackbox rejected session cookie") + << TErrorAttribute("reason", reason); + } + + auto login = BlackboxService_->GetLogin(data); + + // Sanity checks. 
+ if (!login.IsOK()) { + return TError("Blackbox returned invalid response") + << login; + } + + TAuthenticationResult result; + result.Login = login.Value(); + result.Realm = "blackbox:cookie"; + auto userTicket = GetByYPath<TString>(data, "/user_ticket"); + if (userTicket.IsOK()) { + result.UserTicket = userTicket.Value(); + } else if (Config_->GetUserTicket) { + return TError("Failed to retrieve user ticket"); + } + return result; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +ICookieAuthenticatorPtr CreateBlackboxCookieAuthenticator( + TBlackboxCookieAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService) +{ + return New<TBlackboxCookieAuthenticator>(std::move(config), std::move(blackboxService)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/blackbox_cookie_authenticator.h b/yt/yt/library/auth_server/blackbox_cookie_authenticator.h new file mode 100644 index 0000000000..88ee1ce02d --- /dev/null +++ b/yt/yt/library/auth_server/blackbox_cookie_authenticator.h @@ -0,0 +1,15 @@ +#pragma once + +#include "public.h" + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +ICookieAuthenticatorPtr CreateBlackboxCookieAuthenticator( + TBlackboxCookieAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/blackbox_service.cpp b/yt/yt/library/auth_server/blackbox_service.cpp new file mode 100644 index 0000000000..ddcb375c04 --- /dev/null +++ b/yt/yt/library/auth_server/blackbox_service.cpp @@ -0,0 +1,297 @@ +#include "blackbox_service.h" + +#include "config.h" +#include "helpers.h" +#include "private.h" + +#include <yt/yt/core/concurrency/delayed_executor.h> + +#include 
<yt/yt/core/http/client.h> +#include <yt/yt/core/http/http.h> + +#include <yt/yt/core/https/client.h> +#include <yt/yt/core/https/config.h> + +#include <yt/yt/core/json/json_parser.h> + +#include <yt/yt/core/rpc/dispatcher.h> + +#include <yt/yt/core/profiling/timing.h> + +#include <yt/yt/library/tvm/service/tvm_service.h> + +namespace NYT::NAuth { + +using namespace NConcurrency; +using namespace NHttp; +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TBlackboxService + : public IBlackboxService +{ +public: + TBlackboxService( + TBlackboxServiceConfigPtr config, + ITvmServicePtr tvmService, + IPollerPtr poller, + NProfiling::TProfiler profiler) + : Config_(std::move(config)) + , TvmService_(std::move(tvmService)) + , HttpClient_(Config_->Secure + ? NHttps::CreateClient(Config_->HttpClient, std::move(poller)) + : NHttp::CreateClient(Config_->HttpClient, std::move(poller))) + , BlackboxCalls_(profiler.Counter("/blackbox_calls")) + , BlackboxCallErrors_(profiler.Counter("/blackbox_call_errors")) + , BlackboxCallFatalErrors_(profiler.Counter("/blackbox_call_fatal_errors")) + , BlackboxCallTime_(profiler.Timer("/blackbox_call_time")) + { } + + TFuture<INodePtr> Call( + const TString& method, + const THashMap<TString, TString>& params) override + { + return BIND(&TBlackboxService::DoCall, MakeStrong(this), method, params) + .AsyncVia(NRpc::TDispatcher::Get()->GetLightInvoker()) + .Run(); + } + + TErrorOr<TString> GetLogin(const NYTree::INodePtr& reply) const override + { + if (Config_->UseLowercaseLogin) { + return GetByYPath<TString>(reply, "/attributes/1008"); + } else { + return GetByYPath<TString>(reply, "/login"); + } + } + +private: + const TBlackboxServiceConfigPtr Config_; + const ITvmServicePtr TvmService_; + + const NHttp::IClientPtr HttpClient_; + + 
NProfiling::TCounter BlackboxCalls_; + NProfiling::TCounter BlackboxCallErrors_; + NProfiling::TCounter BlackboxCallFatalErrors_; + NProfiling::TEventTimer BlackboxCallTime_; + +private: + INodePtr DoCall( + const TString& method, + const THashMap<TString, TString>& params) + { + auto deadline = TInstant::Now() + Config_->RequestTimeout; + auto callId = TGuid::Create(); + + TSafeUrlBuilder builder; + builder.AppendString(Format("%v://%v:%v/blackbox?", + Config_->Secure ? "https" : "http", + Config_->Host, + Config_->Port)); + builder.AppendParam(TStringBuf("method"), method); + for (const auto& param : params) { + builder.AppendChar('&'); + builder.AppendParam(param.first, param.second); + } + builder.AppendChar('&'); + builder.AppendParam("attributes", "1008"); + builder.AppendChar('&'); + builder.AppendParam("format", "json"); + + auto realUrl = builder.FlushRealUrl(); + auto safeUrl = builder.FlushSafeUrl(); + + auto httpHeaders = New<THeaders>(); + if (TvmService_) { + httpHeaders->Add("X-Ya-Service-Ticket", + TvmService_->GetServiceTicket(Config_->BlackboxServiceId)); + } + + std::vector<TError> accumulatedErrors; + + for (int attempt = 1; TInstant::Now() < deadline || attempt == 1; ++attempt) { + INodePtr result; + try { + BlackboxCalls_.Increment(); + result = DoCallOnce( + callId, + attempt, + realUrl, + safeUrl, + httpHeaders, + deadline); + } catch (const std::exception& ex) { + BlackboxCallErrors_.Increment(); + YT_LOG_WARNING( + ex, + "Blackbox call attempt failed, backing off (CallId: %v, Attempt: %v)", + callId, + attempt); + auto error = TError("Blackbox call attempt %v failed", attempt) + << ex + << TErrorAttribute("call_id", callId) + << TErrorAttribute("attempt", attempt); + accumulatedErrors.push_back(std::move(error)); + } + + // Check for known exceptions to retry. 
+ if (result) { + auto exceptionNode = result->AsMap()->FindChild("exception"); + if (!exceptionNode || exceptionNode->GetType() != ENodeType::Map) { + // No exception information, go as-is. + return result; + } + + auto exceptionIdNode = exceptionNode->AsMap()->FindChild("id"); + if (!exceptionIdNode || exceptionIdNode->GetType() != ENodeType::Int64) { + // No exception information, go as-is. + return result; + } + + auto errorNode = result->AsMap()->FindChild("error"); + auto blackboxError = + errorNode && errorNode->GetType() == ENodeType::String + ? TError(errorNode->GetValue<TString>()) + : TError("Blackbox did not provide any human-readable error details"); + + switch (static_cast<EBlackboxException>(exceptionIdNode->GetValue<i64>())) { + case EBlackboxException::Ok: + return result; + case EBlackboxException::DBFetchFailed: + case EBlackboxException::DBException: + YT_LOG_WARNING(blackboxError, + "Blackbox has raised an exception, backing off (CallId: %v, Attempt: %v)", + callId, + attempt); + break; + default: + YT_LOG_WARNING(blackboxError, + "Blackbox has raised an exception (CallId: %v, Attempt: %v)", + callId, + attempt); + BlackboxCallFatalErrors_.Increment(); + THROW_ERROR_EXCEPTION("Blackbox has raised an exception") + << TErrorAttribute("call_id", callId) + << TErrorAttribute("attempt", attempt) + << blackboxError; + } + } + + auto now = TInstant::Now(); + if (now > deadline) { + break; + } + + TDelayedExecutor::WaitForDuration(std::min(Config_->BackoffTimeout, deadline - now)); + } + + BlackboxCallFatalErrors_.Increment(); + THROW_ERROR_EXCEPTION("Blackbox call failed") + << std::move(accumulatedErrors) + << TErrorAttribute("call_id", callId); + } + + static NJson::TJsonFormatConfigPtr MakeJsonFormatConfig() + { + auto config = New<NJson::TJsonFormatConfig>(); + config->EncodeUtf8 = false; // Hipsters use real Utf8. 
+ return config; + } + + INodePtr DoCallOnce( + TGuid callId, + int attempt, + const TString& realUrl, + const TString& safeUrl, + const THeadersPtr& headers, + TInstant deadline) + { + auto onError = [&] (TError error) { + error.MutableAttributes()->Set("call_id", callId); + YT_LOG_DEBUG(error); + THROW_ERROR(error); + }; + + NProfiling::TWallTimer timer; + auto timeout = std::min(deadline - TInstant::Now(), Config_->AttemptTimeout); + + YT_LOG_DEBUG("Calling Blackbox (Url: %v, CallId: %v, Attempt: %v, Timeout: %v)", + safeUrl, + callId, + attempt, + timeout); + + auto rspOrError = WaitFor(HttpClient_->Get(realUrl, headers).WithTimeout(timeout)); + if (!rspOrError.IsOK()) { + onError(TError("Blackbox call failed") + << rspOrError); + } + + const auto& rsp = rspOrError.Value(); + if (rsp->GetStatusCode() != EStatusCode::OK) { + onError(TError("Blackbox call returned HTTP status code %v", + static_cast<int>(rsp->GetStatusCode()))); + } + + INodePtr rootNode; + try { + + YT_LOG_DEBUG("Started reading response body from Blackbox (CallId: %v, Attempt: %v)", + callId, + attempt); + + auto body = rsp->ReadAll(); + + YT_LOG_DEBUG("Finished reading response body from Blackbox (CallId: %v, Attempt: %v)", + callId, + attempt); + + TMemoryInput stream(body.Begin(), body.Size()); + auto factory = NYTree::CreateEphemeralNodeFactory(); + auto builder = NYTree::CreateBuilderFromFactory(factory.get()); + static const auto Config = MakeJsonFormatConfig(); + NJson::ParseJson(&stream, builder.get(), Config); + rootNode = builder->EndTree(); + + BlackboxCallTime_.Record(timer.GetElapsedTime()); + YT_LOG_DEBUG("Parsed Blackbox daemon reply (CallId: %v, Attempt: %v)", + callId, + attempt); + } catch (const std::exception& ex) { + onError(TError( + "Error parsing Blackbox response") + << ex); + } + + if (rootNode->GetType() != ENodeType::Map) { + THROW_ERROR_EXCEPTION("Blackbox has returned an improper result") + << TErrorAttribute("expected_result_type", ENodeType::Map) + << 
TErrorAttribute("actual_result_type", rootNode->GetType()); + } + + return rootNode; + } +}; + +IBlackboxServicePtr CreateBlackboxService( + TBlackboxServiceConfigPtr config, + ITvmServicePtr tvmService, + IPollerPtr poller, + NProfiling::TProfiler profiler) +{ + return New<TBlackboxService>( + std::move(config), + std::move(tvmService), + std::move(poller), + std::move(profiler)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/blackbox_service.h b/yt/yt/library/auth_server/blackbox_service.h new file mode 100644 index 0000000000..1f4bb42c0e --- /dev/null +++ b/yt/yt/library/auth_server/blackbox_service.h @@ -0,0 +1,38 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/profiling/sensor.h> + +#include <yt/yt/core/actions/future.h> + +#include <yt/yt/core/ytree/public.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +//! Abstracts away Blackbox. +//! See https://doc.yandex-team.ru/blackbox/ for API reference. 
+struct IBlackboxService + : public virtual TRefCounted +{ + virtual TFuture<NYTree::INodePtr> Call( + const TString& method, + const THashMap<TString, TString>& params) = 0; + virtual TErrorOr<TString> GetLogin(const NYTree::INodePtr& reply) const = 0; +}; + +DEFINE_REFCOUNTED_TYPE(IBlackboxService) + +//////////////////////////////////////////////////////////////////////////////// + +IBlackboxServicePtr CreateBlackboxService( + TBlackboxServiceConfigPtr config, + ITvmServicePtr tvmService, + NConcurrency::IPollerPtr poller, + NProfiling::TProfiler profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/caching_secret_vault_service.cpp b/yt/yt/library/auth_server/caching_secret_vault_service.cpp new file mode 100644 index 0000000000..a5e4d34271 --- /dev/null +++ b/yt/yt/library/auth_server/caching_secret_vault_service.cpp @@ -0,0 +1,102 @@ +#include "caching_secret_vault_service.h" +#include "secret_vault_service.h" +#include "config.h" +#include "private.h" + +#include <yt/yt/core/misc/async_expiring_cache.h> + +namespace NYT::NAuth { + +using namespace NConcurrency; + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TCachingSecretVaultService + : public ISecretVaultService + , public TAsyncExpiringCache< + ISecretVaultService::TSecretSubrequest, + ISecretVaultService::TSecretSubresponse + > +{ +public: + TCachingSecretVaultService( + TCachingSecretVaultServiceConfigPtr config, + ISecretVaultServicePtr underlying, + NProfiling::TProfiler profiler) + : TAsyncExpiringCache( + config->Cache, + AuthLogger.WithTag("Cache: SecretVault"), + std::move(profiler)) + , Underlying_(std::move(underlying)) + { } + + TFuture<std::vector<TErrorOrSecretSubresponse>> GetSecrets(const std::vector<TSecretSubrequest>& subrequests) override + { + std::vector<TFuture<TSecretSubresponse>> 
asyncResults; + THashMap<TSecretSubrequest, TFuture<TSecretSubresponse>> subrequestToAsyncResult; + for (const auto& subrequest : subrequests) { + auto it = subrequestToAsyncResult.find(subrequest); + if (it == subrequestToAsyncResult.end()) { + auto asyncResult = Get(subrequest); + YT_VERIFY(subrequestToAsyncResult.emplace(subrequest, asyncResult).second); + asyncResults.push_back(std::move(asyncResult)); + } else { + asyncResults.push_back(it->second); + } + } + return AllSet(asyncResults); + } + + TFuture<TString> GetDelegationToken(TDelegationTokenRequest request) override + { + return Underlying_->GetDelegationToken(std::move(request)); + } + +protected: + //! Called under write lock. + void OnAdded(const ISecretVaultService::TSecretSubrequest& subrequest) noexcept override + { + YT_LOG_DEBUG("Secret added to cache (SecretId: %v, SecretVersion: %v)", + subrequest.SecretId, + subrequest.SecretVersion); + } + + //! Called under write lock. + void OnRemoved(const ISecretVaultService::TSecretSubrequest& subrequest) noexcept override + { + YT_LOG_DEBUG("Secret removed from cache (SecretId: %v, SecretVersion: %v)", + subrequest.SecretId, + subrequest.SecretVersion); + } + +private: + const ISecretVaultServicePtr Underlying_; + + TFuture<TSecretSubresponse> DoGet( + const TSecretSubrequest& subrequest, + bool /*isPeriodicUpdate*/) noexcept override + { + return Underlying_->GetSecrets({subrequest}) + .Apply(BIND([] (const std::vector<TErrorOrSecretSubresponse>& result) { + YT_VERIFY(result.size() == 1); + return result[0].ValueOrThrow(); + })); + } +}; + +ISecretVaultServicePtr CreateCachingSecretVaultService( + TCachingSecretVaultServiceConfigPtr config, + ISecretVaultServicePtr underlying, + NProfiling::TProfiler profiler) +{ + return New<TCachingSecretVaultService>( + std::move(config), + std::move(underlying), + profiler); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git 
a/yt/yt/library/auth_server/caching_secret_vault_service.h b/yt/yt/library/auth_server/caching_secret_vault_service.h new file mode 100644 index 0000000000..5191e2fde1 --- /dev/null +++ b/yt/yt/library/auth_server/caching_secret_vault_service.h @@ -0,0 +1,20 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/concurrency/public.h> + +#include <yt/yt/library/profiling/sensor.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +ISecretVaultServicePtr CreateCachingSecretVaultService( + TCachingSecretVaultServiceConfigPtr config, + ISecretVaultServicePtr underlying, + NProfiling::TProfiler profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/config.cpp b/yt/yt/library/auth_server/config.cpp new file mode 100644 index 0000000000..8c68441029 --- /dev/null +++ b/yt/yt/library/auth_server/config.cpp @@ -0,0 +1,350 @@ +#include "config.h" + +#include <yt/yt/core/concurrency/config.h> + +#include <yt/yt/core/https/config.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +void TAuthCacheConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cache_ttl", &TThis::CacheTtl) + .Default(TDuration::Minutes(5)); + registrar.Parameter("optimistic_cache_ttl", &TThis::OptimisticCacheTtl) + .Default(TDuration::Hours(1)); + registrar.Parameter("error_ttl", &TThis::ErrorTtl) + .Default(TDuration::Seconds(15)); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TBlackboxServiceConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("http_client", &TThis::HttpClient) + .DefaultNew(); + registrar.Parameter("host", &TThis::Host) + .Default("blackbox.yandex-team.ru"); + registrar.Parameter("port", &TThis::Port) + .Default(443); + registrar.Parameter("secure", &TThis::Secure) + 
.Default(true); + registrar.Parameter("blackbox_service_id", &TThis::BlackboxServiceId) + .Default("blackbox"); + registrar.Parameter("request_timeout", &TThis::RequestTimeout) + .Default(TDuration::Seconds(15)); + registrar.Parameter("attempt_timeout", &TThis::AttemptTimeout) + .Default(TDuration::Seconds(10)); + registrar.Parameter("backoff_timeout", &TThis::BackoffTimeout) + .Default(TDuration::Seconds(1)); + registrar.Parameter("use_lowercase_login", &TThis::UseLowercaseLogin) + .Default(true); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TBlackboxTokenAuthenticatorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("scope", &TThis::Scope); + registrar.Parameter("enable_scope_check", &TThis::EnableScopeCheck) + .Default(true); + registrar.Parameter("get_user_ticket", &TThis::GetUserTicket) + .Default(true); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TBlackboxTicketAuthenticatorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("scopes", &TThis::Scopes) + .Optional(); + registrar.Parameter("enable_scope_check", &TThis::EnableScopeCheck) + .Default(false); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCachingTokenAuthenticatorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cache", &TThis::Cache) + .DefaultNew(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCachingBlackboxTokenAuthenticatorConfig::Register(TRegistrar /*registrar*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +void TCypressTokenAuthenticatorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("root_path", &TThis::RootPath) + .Default("//sys/tokens"); + registrar.Parameter("realm", &TThis::Realm) + .Default("cypress"); + + registrar.Parameter("secure", &TThis::Secure) + .Default(false); +} + 
+//////////////////////////////////////////////////////////////////////////////// + +void TCachingCypressTokenAuthenticatorConfig::Register(TRegistrar /*registrar*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +void TOAuthTokenAuthenticatorConfig::Register(TRegistrar /*registrar*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +void TCachingOAuthTokenAuthenticatorConfig::Register(TRegistrar /*registrar*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +void TBlackboxCookieAuthenticatorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("domain", &TThis::Domain) + .Default("yt.yandex-team.ru"); + + registrar.Parameter("csrf_secret", &TThis::CsrfSecret) + .Default(); + registrar.Parameter("csrf_token_ttl", &TThis::CsrfTokenTtl) + .Default(DefaultCsrfTokenTtl); + + registrar.Parameter("get_user_ticket", &TThis::GetUserTicket) + .Default(true); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TOAuthCookieAuthenticatorConfig::Register(TRegistrar /*registrar*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +void TOAuthServiceConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("retrying_client", &TThis::RetryingClient) + .DefaultNew(); + registrar.Parameter("http_client", &TThis::HttpClient) + .DefaultNew(); + + registrar.Parameter("host", &TThis::Host) + .NonEmpty(); + registrar.Parameter("port", &TThis::Port) + .Default(80); + registrar.Parameter("secure", &TThis::Secure) + .Default(false); + + registrar.Parameter("authorization_header_prefix", &TThis::AuthorizationHeaderPrefix) + .Default("Bearer"); + registrar.Parameter("user_info_endpoint", &TThis::UserInfoEndpoint) + .Default("user/info"); + registrar.Parameter("user_info_login_field", &TThis::UserInfoLoginField) + .Default("nickname"); + 
registrar.Parameter("user_info_subject_field", &TThis::UserInfoSubjectField) + .Optional(); + registrar.Parameter("user_info_error_field", &TThis::UserInfoErrorField) + .Optional(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCypressUserManagerConfig::Register(TRegistrar /*registrar*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +void TCachingCypressUserManagerConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cache", &TThis::Cache) + .DefaultNew(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCachingBlackboxCookieAuthenticatorConfig::Register(TRegistrar /*registrar*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +void TCachingOAuthCookieAuthenticatorConfig::Register(TRegistrar /*registrar*/) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +void TCachingCookieAuthenticatorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cache", &TThis::Cache) + .DefaultNew(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TDefaultSecretVaultServiceConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("host", &TThis::Host) + .Default("vault-api.passport.yandex.net"); + registrar.Parameter("port", &TThis::Port) + .Default(443); + registrar.Parameter("secure", &TThis::Secure) + .Default(true); + registrar.Parameter("http_client", &TThis::HttpClient) + .DefaultNew(); + registrar.Parameter("request_timeout", &TThis::RequestTimeout) + .Default(TDuration::Seconds(3)); + registrar.Parameter("vault_service_id", &TThis::VaultServiceId) + .Default("yav"); + registrar.Parameter("consumer", &TThis::Consumer) + .Optional(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TBatchingSecretVaultServiceConfig::Register(TRegistrar 
registrar) +{ + registrar.Parameter("batch_delay", &TThis::BatchDelay) + .Default(TDuration::MilliSeconds(100)); + registrar.Parameter("max_subrequests_per_request", &TThis::MaxSubrequestsPerRequest) + .Default(100) + .GreaterThan(0); + registrar.Parameter("requests_throttler", &TThis::RequestsThrottler) + .DefaultNew(); + + registrar.Preprocessor([] (TThis* config) { + config->RequestsThrottler->Limit = 100; + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCachingSecretVaultServiceConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cache", &TThis::Cache) + .DefaultNew(); + + registrar.Preprocessor([] (TThis* config) { + config->Cache->RefreshTime = std::nullopt; + config->Cache->ExpireAfterAccessTime = TDuration::Seconds(60); + config->Cache->ExpireAfterSuccessfulUpdateTime = TDuration::Seconds(60); + config->Cache->ExpireAfterFailedUpdateTime = TDuration::Seconds(60); + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCypressCookieStoreConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("full_fetch_period", &TThis::FullFetchPeriod) + .Default(TDuration::Minutes(5)); + + registrar.Parameter("error_eviction_time", &TThis::ErrorEvictionTime) + .Default(TDuration::Minutes(1)); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCypressCookieGeneratorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cookie_expiration_timeout", &TThis::CookieExpirationTimeout) + .Default(TDuration::Days(90)); + registrar.Parameter("cookie_renewal_period", &TThis::CookieRenewalPeriod) + .Default(TDuration::Days(30)); + + registrar.Parameter("secure", &TThis::Secure) + .Default(true); + + registrar.Parameter("http_only", &TThis::HttpOnly) + .Default(true); + + registrar.Parameter("domain", &TThis::Domain) + .Default(); + registrar.Parameter("path", &TThis::Path) + .Default("/"); + + 
registrar.Parameter("redirect_url", &TThis::RedirectUrl) + .Default(); + + registrar.Postprocessor([] (TThis* config) { + if (config->CookieRenewalPeriod > config->CookieExpirationTimeout) { + THROW_ERROR_EXCEPTION( + "\"cookie_renewal_period\" cannot be greater than \"cookie_expiration_timeout\""); + } + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TCypressCookieManagerConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cookie_store", &TThis::CookieStore) + .DefaultNew(); + registrar.Parameter("cookie_generator", &TThis::CookieGenerator) + .DefaultNew(); + registrar.Parameter("cookie_authenticator", &TThis::CookieAuthenticator) + .DefaultNew(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TString TAuthenticationManagerConfig::GetCsrfSecret() const +{ + if (BlackboxCookieAuthenticator && + BlackboxCookieAuthenticator->CsrfSecret) + { + return *BlackboxCookieAuthenticator->CsrfSecret; + } + + return TString(); +} + +TInstant TAuthenticationManagerConfig::GetCsrfTokenExpirationTime() const +{ + if (BlackboxCookieAuthenticator) { + return TInstant::Now() - BlackboxCookieAuthenticator->CsrfTokenTtl; + } + + return TInstant::Now() - DefaultCsrfTokenTtl; +} + +void TAuthenticationManagerConfig::Register(TRegistrar registrar) +{ + // COMPAT(prime@) + registrar.Parameter("require_authentication", &TThis::RequireAuthentication) + .Alias("enable_authentication") + .Default(true); + registrar.Parameter("blackbox_token_authenticator", &TThis::BlackboxTokenAuthenticator) + .Alias("token_authenticator") + .Optional(); + registrar.Parameter("blackbox_cookie_authenticator", &TThis::BlackboxCookieAuthenticator) + .Alias("cookie_authenticator") + .DefaultNew(); + registrar.Parameter("blackbox_service", &TThis::BlackboxService) + .Alias("blackbox") + .DefaultNew(); + registrar.Parameter("cypress_token_authenticator", &TThis::CypressTokenAuthenticator) + .Optional(); + 
registrar.Parameter("tvm_service", &TThis::TvmService) + .Optional(); + registrar.Parameter("blackbox_ticket_authenticator", &TThis::BlackboxTicketAuthenticator) + .Optional(); + registrar.Parameter("cypress_cookie_manager", &TThis::CypressCookieManager) + .Default(); + registrar.Parameter("oauth_cookie_authenticator", &TThis::OAuthCookieAuthenticator) + .Optional(); + registrar.Parameter("oauth_token_authenticator", &TThis::OAuthTokenAuthenticator) + .Optional(); + registrar.Parameter("oauth_service", &TThis::OAuthService) + .Optional(); + registrar.Parameter("cypress_user_manager", &TThis::CypressUserManager) + .Optional(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/config.h b/yt/yt/library/auth_server/config.h new file mode 100644 index 0000000000..7142cbfbce --- /dev/null +++ b/yt/yt/library/auth_server/config.h @@ -0,0 +1,466 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/http/public.h> + +#include <yt/yt/core/https/public.h> + +#include <yt/yt/core/misc/cache_config.h> + +#include <yt/yt/library/tvm/service/config.h> + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +class TAuthCacheConfig + : public virtual NYTree::TYsonStruct +{ +public: + TDuration CacheTtl; + TDuration OptimisticCacheTtl; + TDuration ErrorTtl; + + REGISTER_YSON_STRUCT(TAuthCacheConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TAuthCacheConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TBlackboxServiceConfig + : public virtual NYTree::TYsonStruct +{ +public: + NHttps::TClientConfigPtr HttpClient; + TString Host; + int Port; + bool Secure; + TString BlackboxServiceId; + + TDuration RequestTimeout; + TDuration AttemptTimeout; + TDuration BackoffTimeout; + bool 
UseLowercaseLogin;
+
+    REGISTER_YSON_STRUCT(TBlackboxServiceConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TBlackboxServiceConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Settings of the Blackbox-based token authenticator.
+//! NOTE(review): the meaning of the fields is assumed from their names; Register()
+//! is only declared here, so confirm names/defaults against its definition.
+class TBlackboxTokenAuthenticatorConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    TString Scope;
+    bool EnableScopeCheck;
+    bool GetUserTicket;
+
+    REGISTER_YSON_STRUCT(TBlackboxTokenAuthenticatorConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TBlackboxTokenAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Settings of the Blackbox-based ticket authenticator.
+//! Unlike the token variant it holds a set of scopes rather than a single one.
+class TBlackboxTicketAuthenticatorConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    THashSet<TString> Scopes;
+    bool EnableScopeCheck;
+
+    REGISTER_YSON_STRUCT(TBlackboxTicketAuthenticatorConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TBlackboxTicketAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Cache settings shared by all caching token authenticator configs below.
+class TCachingTokenAuthenticatorConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    TAuthCacheConfigPtr Cache;
+
+    REGISTER_YSON_STRUCT(TCachingTokenAuthenticatorConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingTokenAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Blackbox token authenticator settings combined with cache settings.
+//! Both bases inherit NYTree::TYsonStruct virtually, so there is a single
+//! TYsonStruct subobject.
+class TCachingBlackboxTokenAuthenticatorConfig
+    : public TBlackboxTokenAuthenticatorConfig
+    , public TCachingTokenAuthenticatorConfig
+{
+    REGISTER_YSON_STRUCT(TCachingBlackboxTokenAuthenticatorConfig);
+
+    static void Register(TRegistrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingBlackboxTokenAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Settings of the Cypress-based token authenticator.
+//! NOTE(review): RootPath is presumably the Cypress node under which tokens
+//! are stored; confirm against the authenticator implementation.
+class TCypressTokenAuthenticatorConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    NYPath::TYPath RootPath;
+    TString Realm;
+
+    bool Secure;
+
+    REGISTER_YSON_STRUCT(TCypressTokenAuthenticatorConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCypressTokenAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Cypress token authenticator settings combined with cache settings.
+class TCachingCypressTokenAuthenticatorConfig
+    : public TCachingTokenAuthenticatorConfig
+    , public TCypressTokenAuthenticatorConfig
+{
+    REGISTER_YSON_STRUCT(TCachingCypressTokenAuthenticatorConfig);
+
+    static void Register(TRegistrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingCypressTokenAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Settings of the OAuth token authenticator; declares no fields of its own.
+class TOAuthTokenAuthenticatorConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    REGISTER_YSON_STRUCT(TOAuthTokenAuthenticatorConfig);
+
+    static void Register(TRegistrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TOAuthTokenAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! OAuth token authenticator settings combined with cache settings.
+class TCachingOAuthTokenAuthenticatorConfig
+    : public TOAuthTokenAuthenticatorConfig
+    , public TCachingTokenAuthenticatorConfig
+{
+    REGISTER_YSON_STRUCT(TCachingOAuthTokenAuthenticatorConfig);
+
+    static void Register(TRegistrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingOAuthTokenAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+// NOTE(review): presumably the default for TBlackboxCookieAuthenticatorConfig::CsrfTokenTtl;
+// confirm in Register(). Being `static` at namespace scope, the constant has internal
+// linkage, so each translation unit that sees this declaration gets its own copy.
+static const auto DefaultCsrfTokenTtl = TDuration::Days(7);
+
+//! Settings of the Blackbox-based cookie authenticator, including CSRF token options.
+class TBlackboxCookieAuthenticatorConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    TString Domain;
+
+    //! Optional: may be left unset.
+    std::optional<TString> CsrfSecret;
+    TDuration CsrfTokenTtl;
+
+    bool GetUserTicket;
+
+    REGISTER_YSON_STRUCT(TBlackboxCookieAuthenticatorConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TBlackboxCookieAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Settings of the OAuth cookie authenticator; declares no fields of its own.
+class TOAuthCookieAuthenticatorConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    REGISTER_YSON_STRUCT(TOAuthCookieAuthenticatorConfig);
+
+    static void Register(TRegistrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TOAuthCookieAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Connection settings and response field names for the OAuth service.
+//! NOTE(review): the UserInfo* fields presumably name the user info endpoint and
+//! the fields of its response; confirm against the OAuth service client.
+class TOAuthServiceConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    NHttp::TRetryingClientConfigPtr RetryingClient;
+    NHttps::TClientConfigPtr HttpClient;
+
+    TString Host;
+    int Port;
+    bool Secure;
+
+    TString AuthorizationHeaderPrefix;
+    TString UserInfoEndpoint;
+    TString UserInfoLoginField;
+    std::optional<TString> UserInfoSubjectField;
+    std::optional<TString> UserInfoErrorField;
+
+    REGISTER_YSON_STRUCT(TOAuthServiceConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TOAuthServiceConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Settings of the Cypress user manager; declares no fields of its own.
+class TCypressUserManagerConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    REGISTER_YSON_STRUCT(TCypressUserManagerConfig);
+
+    static void Register(TRegistrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCypressUserManagerConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Cache settings shared by all caching cookie authenticator configs below.
+class TCachingCookieAuthenticatorConfig
+    : public virtual NYTree::TYsonStruct
+{
+public:
+    TAuthCacheConfigPtr Cache;
+
+    REGISTER_YSON_STRUCT(TCachingCookieAuthenticatorConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingCookieAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Cypress user manager settings extended with cache settings.
+//! Unlike the authenticator configs, the cache field is declared directly here
+//! instead of being mixed in from a separate caching base.
+class TCachingCypressUserManagerConfig
+    : public TCypressUserManagerConfig
+{
+public:
+    TAuthCacheConfigPtr Cache;
+
+    REGISTER_YSON_STRUCT(TCachingCypressUserManagerConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingCypressUserManagerConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Blackbox cookie authenticator settings combined with cache settings.
+class TCachingBlackboxCookieAuthenticatorConfig
+    : public TBlackboxCookieAuthenticatorConfig
+    , public TCachingCookieAuthenticatorConfig
+{
+    REGISTER_YSON_STRUCT(TCachingBlackboxCookieAuthenticatorConfig);
+
+    static void Register(TRegistrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingBlackboxCookieAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! OAuth cookie authenticator settings combined with cache settings.
+class TCachingOAuthCookieAuthenticatorConfig
+    : public TOAuthCookieAuthenticatorConfig
+    , public TCachingCookieAuthenticatorConfig
+{
+    REGISTER_YSON_STRUCT(TCachingOAuthCookieAuthenticatorConfig);
+
+    static void Register(TRegistrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingOAuthCookieAuthenticatorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Connection settings of the default secret vault service.
+class TDefaultSecretVaultServiceConfig
+    : public virtual NYT::NYTree::TYsonStruct
+{
+public:
+    TString Host;
+    int Port;
+    bool Secure;
+    NHttps::TClientConfigPtr HttpClient;
+    TDuration RequestTimeout;
+    TString VaultServiceId;
+    TString Consumer;
+
+    REGISTER_YSON_STRUCT(TDefaultSecretVaultServiceConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TDefaultSecretVaultServiceConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Batching settings of the secret vault service.
+//! NOTE(review): presumably BatchDelay is how long subrequests are accumulated and
+//! MaxSubrequestsPerRequest caps a single batch; confirm against the implementation.
+class TBatchingSecretVaultServiceConfig
+    : public virtual NYT::NYTree::TYsonStruct
+{
+public:
+    TDuration BatchDelay;
+    int MaxSubrequestsPerRequest;
+    NConcurrency::TThroughputThrottlerConfigPtr RequestsThrottler;
+
+    REGISTER_YSON_STRUCT(TBatchingSecretVaultServiceConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TBatchingSecretVaultServiceConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Cache settings of the caching secret vault service.
+//! NOTE(review): the class both derives from TAsyncExpiringCacheConfig and holds
+//! a Cache member of the same kind; confirm which of the two is actually consumed.
+class TCachingSecretVaultServiceConfig
+    : public TAsyncExpiringCacheConfig
+{
+public:
+    TAsyncExpiringCacheConfigPtr Cache;
+
+    REGISTER_YSON_STRUCT(TCachingSecretVaultServiceConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCachingSecretVaultServiceConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Settings of the Cypress cookie store.
+struct TCypressCookieStoreConfig
+    : public NYTree::TYsonStruct
+{
+    //! Store will renew cookie list with this frequency.
+    TDuration FullFetchPeriod;
+
+    //! Errors are cached for this period of time.
+    TDuration ErrorEvictionTime;
+
+    REGISTER_YSON_STRUCT(TCypressCookieStoreConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCypressCookieStoreConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Controls lifetime and attributes of generated Cypress cookies.
+struct TCypressCookieGeneratorConfig
+    : public NYTree::TYsonStruct
+{
+    //! Used to form ExpiresAt parameter.
+    TDuration CookieExpirationTimeout;
+
+    //! If cookie will expire within this period,
+    //! authenticator will try to renew it.
+    TDuration CookieRenewalPeriod;
+
+    //! Controls Secure parameter of a cookie.
+    //! If true, cookie will be used by user only
+    //! in https requests which prevents cookie
+    //! stealing because of unsecured connection,
+    //! so this field should be set to true in production
+    //! environments.
+    bool Secure;
+
+    //! Controls HttpOnly parameter of a cookie.
+    bool HttpOnly;
+
+    //! Domain parameter of generated cookies.
+    std::optional<TString> Domain;
+
+    //! Path parameter of generated cookies.
+    TString Path;
+
+    //! If set and if cookie is generated via login page,
+    //! will redirect user to this page.
+    std::optional<TString> RedirectUrl;
+
+    REGISTER_YSON_STRUCT(TCypressCookieGeneratorConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCypressCookieGeneratorConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Aggregates the configs of the Cypress cookie subsystem: store, generator
+//! and the caching authenticator.
+struct TCypressCookieManagerConfig
+    : public NYTree::TYsonStruct
+{
+    TCypressCookieStoreConfigPtr CookieStore;
+    TCypressCookieGeneratorConfigPtr CookieGenerator;
+    //! Typed as the caching Blackbox cookie authenticator config although it
+    //! configures the Cypress cookie authenticator cache.
+    TCachingBlackboxCookieAuthenticatorConfigPtr CookieAuthenticator;
+
+    REGISTER_YSON_STRUCT(TCypressCookieManagerConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCypressCookieManagerConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Top-level authentication config: one (pointer to a) sub-config per
+//! supported authenticator or service.
+class TAuthenticationManagerConfig
+    : public virtual NYT::NYTree::TYsonStruct
+{
+public:
+    bool RequireAuthentication;
+    TCachingBlackboxTokenAuthenticatorConfigPtr BlackboxTokenAuthenticator;
+    TCachingBlackboxCookieAuthenticatorConfigPtr BlackboxCookieAuthenticator;
+    TBlackboxServiceConfigPtr BlackboxService;
+    TCachingCypressTokenAuthenticatorConfigPtr CypressTokenAuthenticator;
+    TTvmServiceConfigPtr TvmService;
+    TBlackboxTicketAuthenticatorConfigPtr BlackboxTicketAuthenticator;
+    TCachingOAuthCookieAuthenticatorConfigPtr OAuthCookieAuthenticator;
+    TCachingOAuthTokenAuthenticatorConfigPtr OAuthTokenAuthenticator;
+    TOAuthServiceConfigPtr OAuthService;
+
+    TCypressCookieManagerConfigPtr CypressCookieManager;
+    TCachingCypressUserManagerConfigPtr CypressUserManager;
+
+    // The two accessors below are only declared here; their definitions
+    // are not part of this declaration.
+    TString GetCsrfSecret() const;
+
+    TInstant GetCsrfTokenExpirationTime() const;
+
+    REGISTER_YSON_STRUCT(TAuthenticationManagerConfig);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TAuthenticationManagerConfig)
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
diff --git a/yt/yt/library/auth_server/cookie_authenticator.cpp
b/yt/yt/library/auth_server/cookie_authenticator.cpp
new file mode 100644
index 0000000000..691a29e0bf
--- /dev/null
+++ b/yt/yt/library/auth_server/cookie_authenticator.cpp
@@ -0,0 +1,223 @@
+#include "cookie_authenticator.h"
+
+#include "config.h"
+#include "helpers.h"
+#include "private.h"
+#include "auth_cache.h"
+
+#include <yt/yt/core/rpc/authenticator.h>
+
+namespace NYT::NAuth {
+
+using namespace NYTree;
+using namespace NYPath;
+using namespace NConcurrency;
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Cache key of the caching cookie authenticator.
+//! Only the cookie set takes part in hashing and comparison; UserIP is ignored.
+struct TCookieAuthenticatorCacheKey
+{
+    TCookieCredentials Credentials;
+
+    //! Hash of the key, exposed as an implicit conversion to size_t
+    //! (presumably this is what the cache's hasher relies on).
+    operator size_t() const
+    {
+        size_t result = 0;
+
+        // Cookies are copied and sorted so that the hash does not depend
+        // on the iteration order of the cookie container.
+        std::vector<std::pair<TString, TString>> cookies(
+            Credentials.Cookies.begin(),
+            Credentials.Cookies.end());
+        std::sort(cookies.begin(), cookies.end());
+        for (const auto& cookie : cookies) {
+            HashCombine(result, cookie.first);
+            HashCombine(result, cookie.second);
+        }
+
+        return result;
+    }
+
+    bool operator == (const TCookieAuthenticatorCacheKey& other) const
+    {
+        return Credentials.Cookies == other.Credentials.Cookies;
+    }
+};
+
+//! Cookie authenticator that routes Authenticate() through TAuthCache and
+//! forwards everything else to the underlying authenticator.
+class TCachingCookieAuthenticator
+    : public ICookieAuthenticator
+    , private TAuthCache<TCookieAuthenticatorCacheKey, TAuthenticationResult, NNet::TNetworkAddress>
+{
+public:
+    TCachingCookieAuthenticator(
+        TCachingCookieAuthenticatorConfigPtr config,
+        ICookieAuthenticatorPtr underlying,
+        NProfiling::TProfiler profiler)
+        : TAuthCache(config->Cache, std::move(profiler))
+        , UnderlyingAuthenticator_(std::move(underlying))
+    { }
+
+    const std::vector<TStringBuf>& GetCookieNames() const override
+    {
+        return UnderlyingAuthenticator_->GetCookieNames();
+    }
+
+    bool CanAuthenticate(const TCookieCredentials& credentials) const override
+    {
+        return UnderlyingAuthenticator_->CanAuthenticate(credentials);
+    }
+
+    TFuture<TAuthenticationResult> Authenticate(
+        const TCookieCredentials& credentials) override
+    {
+        // The user IP is passed separately from the key since it is not part of key equality.
+        return Get(TCookieAuthenticatorCacheKey{credentials}, credentials.UserIP);
+    }
+
+private:
+    const ICookieAuthenticatorPtr UnderlyingAuthenticator_;
+
+    //! TAuthCache hook (presumably invoked on cache miss): performs the actual
+    //! authentication using the key's credentials with the supplied user IP.
+    TFuture<TAuthenticationResult> DoGet(
+        const TCookieAuthenticatorCacheKey& key,
+        const NNet::TNetworkAddress& userIP) noexcept override
+    {
+        auto credentials = key.Credentials;
+        credentials.UserIP = userIP;
+        return UnderlyingAuthenticator_->Authenticate(credentials);
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Wraps #authenticator into a caching authenticator configured by #config.
+ICookieAuthenticatorPtr CreateCachingCookieAuthenticator(
+    TCachingCookieAuthenticatorConfigPtr config,
+    ICookieAuthenticatorPtr authenticator,
+    NProfiling::TProfiler profiler)
+{
+    return New<TCachingCookieAuthenticator>(
+        std::move(config),
+        std::move(authenticator),
+        std::move(profiler));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Tries a list of cookie authenticators in order and uses the first one
+//! that reports it can authenticate the given credentials.
+class TCompositeCookieAuthenticator
+    : public ICookieAuthenticator
+{
+public:
+    explicit TCompositeCookieAuthenticator(std::vector<ICookieAuthenticatorPtr> authenticators)
+        : Authenticators_(std::move(authenticators))
+    {
+        // The composite cookie name list is the concatenation of the underlying lists.
+        for (const auto& authenticator : Authenticators_) {
+            const auto& cookieNames = authenticator->GetCookieNames();
+            CookieNames_.insert(CookieNames_.end(), cookieNames.begin(), cookieNames.end());
+        }
+    }
+
+    const std::vector<TStringBuf>& GetCookieNames() const override
+    {
+        return CookieNames_;
+    }
+
+    bool CanAuthenticate(const TCookieCredentials& credentials) const override
+    {
+        for (const auto& authenticator : Authenticators_) {
+            if (authenticator->CanAuthenticate(credentials)) {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
+    //! Delegates to the first suitable authenticator, passing it only the
+    //! cookies it declares in GetCookieNames().
+    //! Aborts if no authenticator is suitable, so callers are expected
+    //! to check CanAuthenticate() first.
+    TFuture<TAuthenticationResult> Authenticate(
+        const TCookieCredentials& credentials) override
+    {
+        for (const auto& authenticator : Authenticators_) {
+            if (authenticator->CanAuthenticate(credentials)) {
+                TCookieCredentials filteredCredentials{
+                    .UserIP = credentials.UserIP,
+                };
+                const auto& cookies = credentials.Cookies;
+                for (const auto& cookie : authenticator->GetCookieNames()) {
+                    auto cookieIt = cookies.find(cookie);
+                    if (cookieIt != cookies.end()) {
+                        // NOTE(review): presumably crashes if the same cookie name
+                        // is listed twice by the authenticator; confirm.
+                        EmplaceOrCrash(filteredCredentials.Cookies, cookie, cookieIt->second);
+                    }
+                }
+
+                return authenticator->Authenticate(filteredCredentials);
+            }
+        }
+
+        YT_ABORT();
+    }
+
+private:
+    const std::vector<ICookieAuthenticatorPtr> Authenticators_;
+
+    std::vector<TStringBuf> CookieNames_;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Creates an authenticator that dispatches to the first suitable one of #authenticators.
+ICookieAuthenticatorPtr CreateCompositeCookieAuthenticator(
+    std::vector<ICookieAuthenticatorPtr> authenticators)
+{
+    return New<TCompositeCookieAuthenticator>(std::move(authenticators));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Adapts a cookie authenticator to the NRpc::IAuthenticator interface using
+//! session ids from the credentials extension of the request header.
+class TCookieAuthenticatorWrapper
+    : public NRpc::IAuthenticator
+{
+public:
+    explicit TCookieAuthenticatorWrapper(ICookieAuthenticatorPtr underlying)
+        : Underlying_(std::move(underlying))
+    {
+        YT_VERIFY(Underlying_);
+    }
+
+    //! Requires the credentials extension with at least one of session_id /
+    //! ssl_session_id and an IPv4 or IPv6 user address.
+    bool CanAuthenticate(const NRpc::TAuthenticationContext& context) override
+    {
+        if (!context.Header->HasExtension(NRpc::NProto::TCredentialsExt::credentials_ext)) {
+            return false;
+        }
+
+        const auto& ext = context.Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext);
+        if (!ext.has_session_id() && !ext.has_ssl_session_id()) {
+            return false;
+        }
+
+        return context.UserIP.IsIP4() || context.UserIP.IsIP6();
+    }
+
+    TFuture<NRpc::TAuthenticationResult> AsyncAuthenticate(
+        const NRpc::TAuthenticationContext& context) override
+    {
+        YT_ASSERT(CanAuthenticate(context));
+        const auto& ext = context.Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext);
+        TCookieCredentials credentials;
+        auto& cookies = credentials.Cookies;
+        // NB: Both cookies are always stored, even when only one of the ids is present
+        // in the extension (the absent one presumably ends up as an empty string).
+        cookies[BlackboxSessionIdCookieName] = ext.session_id();
+        cookies[BlackboxSslSessionIdCookieName] = ext.ssl_session_id();
+        credentials.UserIP = context.UserIP;
+        // Convert the auth-library result into its RPC counterpart.
+        return Underlying_->Authenticate(credentials).Apply(
+            BIND([=] (const TAuthenticationResult& authResult) {
+                NRpc::TAuthenticationResult rpcResult;
+                rpcResult.User = authResult.Login;
+                rpcResult.Realm = authResult.Realm;
+                rpcResult.UserTicket = authResult.UserTicket;
+                return rpcResult;
+            }));
+    }
+private:
+    const ICookieAuthenticatorPtr Underlying_;
+};
+
+//! Exposes #underlying as an RPC authenticator.
+NRpc::IAuthenticatorPtr CreateCookieAuthenticatorWrapper(ICookieAuthenticatorPtr underlying)
+{
+    return New<TCookieAuthenticatorWrapper>(std::move(underlying));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
diff --git a/yt/yt/library/auth_server/cookie_authenticator.h b/yt/yt/library/auth_server/cookie_authenticator.h
new file mode 100644
index 0000000000..42a69bfb38
--- /dev/null
+++ b/yt/yt/library/auth_server/cookie_authenticator.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include "public.h"
+
+#include <yt/yt/library/profiling/sensor.h>
+
+#include <yt/yt/core/actions/future.h>
+
+#include <yt/yt/core/rpc/public.h>
+
+namespace NYT::NAuth {
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ICookieAuthenticator
+    : public virtual TRefCounted
+{
+    //! Returns list of cookie names which are used for authentication.
+    virtual const std::vector<TStringBuf>& GetCookieNames() const = 0;
+
+    //! Returns true if user provided enough cookies to perform authentication.
+    virtual bool CanAuthenticate(const TCookieCredentials& credentials) const = 0;
+
+    //! Asynchronously authenticates the user by the given cookies.
+    virtual TFuture<TAuthenticationResult> Authenticate(
+        const TCookieCredentials& credentials) = 0;
+};
+
+DEFINE_REFCOUNTED_TYPE(ICookieAuthenticator)
+
+////////////////////////////////////////////////////////////////////////////////
+
+ICookieAuthenticatorPtr CreateCachingCookieAuthenticator(
+    TCachingCookieAuthenticatorConfigPtr config,
+    ICookieAuthenticatorPtr authenticator,
+    NProfiling::TProfiler profiler = {});
+
+ICookieAuthenticatorPtr CreateCompositeCookieAuthenticator(
+    std::vector<ICookieAuthenticatorPtr> authenticators);
+
+////////////////////////////////////////////////////////////////////////////////
+
+NRpc::IAuthenticatorPtr CreateCookieAuthenticatorWrapper(
+    ICookieAuthenticatorPtr underlying);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
diff --git a/yt/yt/library/auth_server/cypress_cookie.cpp b/yt/yt/library/auth_server/cypress_cookie.cpp
new file mode 100644
index 0000000000..e277e15372
--- /dev/null
+++ b/yt/yt/library/auth_server/cypress_cookie.cpp
@@ -0,0 +1,55 @@
+#include "cypress_cookie.h"
+
+#include "config.h"
+
+#include <yt/yt/core/crypto/crypto.h>
+
+#include <util/string/hex.h>
+
+namespace NYT::NAuth {
+
+using namespace NCrypto;
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Produces "<name>=<value>; Expires=<date>[; Secure][; HttpOnly][; Domain=<domain>]; Path=<path>".
+TString TCypressCookie::ToHeader(const TCypressCookieGeneratorConfigPtr& config) const
+{
+    auto header = Format("%v=%v; Expires=%v",
+        CypressCookieName,
+        Value,
+        ExpiresAt.ToRfc822String());
+    if (config->Secure) {
+        header += "; Secure";
+    }
+    if (config->HttpOnly) {
+        header += "; HttpOnly";
+    }
+    // Domain is optional and is emitted only when set.
+    if (const auto& domain = config->Domain) {
+        header += Format("; Domain=%v", domain);
+    }
+    header += Format("; Path=%v", config->Path);
+
+    return header;
+}
+
+void TCypressCookie::Register(TRegistrar registrar)
+{
+    registrar.Parameter("value", &TThis::Value);
+    registrar.Parameter("user", &TThis::User);
+    registrar.Parameter("password_revision", &TThis::PasswordRevision);
+    registrar.Parameter("expires_at", &TThis::ExpiresAt);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// The cookie value is the hex encoding of ValueSize random bytes.
+TString GenerateCookieValue()
+{
+    constexpr int ValueSize = 32;
+
+    auto rawCookie = GenerateCryptoStrongRandomString(ValueSize);
+    return HexEncode(rawCookie.data(), rawCookie.size());
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NAuth
diff --git a/yt/yt/library/auth_server/cypress_cookie.h b/yt/yt/library/auth_server/cypress_cookie.h
new file mode 100644
index 0000000000..b7e52d020e
--- /dev/null
+++ b/yt/yt/library/auth_server/cypress_cookie.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "public.h"
+
+#include <yt/yt/client/api/public.h>
+
+#include <yt/yt/core/ytree/yson_struct.h>
+
+namespace NYT::NAuth {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Cookie issued for a user; serializable as a YSON struct.
+struct TCypressCookie
+    : public NYTree::TYsonStruct
+{
+    //! Value of the cookie.
+    TString Value;
+
+    //! User for which cookie is issued.
+    TString User;
+
+    //! Revision of password in the moment of cookie issue.
+    ui64 PasswordRevision;
+
+    //! Cookie expiration instant.
+    TInstant ExpiresAt;
+
+    //! Returns text representation of a cookie for SetCookie header.
+    TString ToHeader(const TCypressCookieGeneratorConfigPtr& config) const;
+
+    REGISTER_YSON_STRUCT(TCypressCookie);
+
+    static void Register(TRegistrar registrar);
+};
+
+DEFINE_REFCOUNTED_TYPE(TCypressCookie)
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Generates new cookie value using cryptographically strong generator.
+// NB: May throw on RNG failure.
+TString GenerateCookieValue(); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_cookie_authenticator.cpp b/yt/yt/library/auth_server/cypress_cookie_authenticator.cpp new file mode 100644 index 0000000000..f4ced1221a --- /dev/null +++ b/yt/yt/library/auth_server/cypress_cookie_authenticator.cpp @@ -0,0 +1,202 @@ +#include "cypress_cookie_authenticator.h" + +#include "config.h" +#include "cypress_cookie_store.h" + +#include <yt/yt/client/api/client.h> + +#include <yt/yt/library/auth_server/cookie_authenticator.h> +#include <yt/yt/library/auth_server/helpers.h> +#include <yt/yt/library/auth_server/private.h> + +#include <yt/yt/core/crypto/crypto.h> + +namespace NYT::NAuth { + +using namespace NApi; +using namespace NConcurrency; +using namespace NCrypto; +using namespace NYPath; +using namespace NYTree; +using namespace NYson; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TCypressCookieAuthenticator + : public ICookieAuthenticator +{ +public: + TCypressCookieAuthenticator( + TCypressCookieGeneratorConfigPtr config, + ICypressCookieStorePtr cookieStore, + IClientPtr client) + : Config_(std::move(config)) + , CookieStore_(std::move(cookieStore)) + , Client_(std::move(client)) + { } + + const std::vector<TStringBuf>& GetCookieNames() const override + { + static const std::vector<TStringBuf> cookieNames{ + CypressCookieName, + }; + return cookieNames; + } + + bool CanAuthenticate(const TCookieCredentials& credentials) const override + { + return credentials.Cookies.contains(CypressCookieName); + } + + TFuture<TAuthenticationResult> Authenticate( + const TCookieCredentials& credentials) override + { + const auto& cookieValue = GetOrCrash(credentials.Cookies, CypressCookieName); + + 
YT_LOG_DEBUG( + "Authenticating user via native cookie (CookieMD5: %v, UserIP: %v)", + GetMD5HexDigestUpperCase(cookieValue), + FormatUserIP(credentials.UserIP)); + + return CookieStore_->GetCookie(cookieValue) + .Apply(BIND(&TCypressCookieAuthenticator::OnGotCookie, MakeStrong(this))) + .Apply(BIND([] (const TErrorOr<TAuthenticationResult>& resultOrError) -> TErrorOr<TAuthenticationResult> { + if (resultOrError.FindMatching(NYTree::EErrorCode::ResolveError)) { + return TError( + NRpc::EErrorCode::InvalidCredentials, + "Unknown credentials") + << resultOrError; + } + + return resultOrError; + })); + } + +private: + const TCypressCookieGeneratorConfigPtr Config_; + + const ICypressCookieStorePtr CookieStore_; + + const IClientPtr Client_; + + TFuture<ui64> GetUserPasswordRevision(const TString& user) + { + auto path = Format("//sys/users/%v", ToYPathLiteral(user)); + + constexpr TStringBuf PasswordRevisionAttribute = "password_revision"; + + TGetNodeOptions options; + options.Attributes = std::vector<TString>({ + TString{PasswordRevisionAttribute}, + }); + + return Client_->GetNode(path, options) + .Apply(BIND([=] (const TYsonString& rsp) { + auto rspNode = ConvertToNode(rsp); + return rspNode->Attributes().Get<ui64>(PasswordRevisionAttribute); + })); + } + + TFuture<TAuthenticationResult> OnGotCookie(const TCypressCookiePtr& cookie) + { + return GetUserPasswordRevision(cookie->User) + .Apply(BIND(&TCypressCookieAuthenticator::OnGotPasswordRevision, MakeStrong(this), cookie)); + } + + TAuthenticationResult OnGotPasswordRevision( + const TCypressCookiePtr& cookie, + ui64 passwordRevision) + { + if (cookie->PasswordRevision != passwordRevision) { + THROW_ERROR_EXCEPTION(NRpc::EErrorCode::InvalidCredentials, + "Native cookie was issued for previous password revision") + << TErrorAttribute("cookie_password_revision", cookie->PasswordRevision) + << TErrorAttribute("password_revision", passwordRevision); + } + + auto now = TInstant::Now(); + if (cookie->ExpiresAt < now) { 
+ THROW_ERROR_EXCEPTION(NRpc::EErrorCode::InvalidCredentials, + "Native cookie expired") + << TErrorAttribute("cookie_expiration_time", cookie->ExpiresAt); + } + + const auto& user = cookie->User; + TAuthenticationResult result{ + .Login = user, + }; + + if (cookie->ExpiresAt < now + Config_->CookieRenewalPeriod) { + auto latestCookie = CookieStore_->GetLastCookieForUser(user); + + // Very unlikely, but might happen during cookie duration reconfiguration. + if (latestCookie && latestCookie->PasswordRevision != passwordRevision) { + CookieStore_->RemoveLastCookieForUser(user); + latestCookie.Reset(); + } + + if (latestCookie && latestCookie->ExpiresAt > now + Config_->CookieRenewalPeriod) { + result.SetCookie = latestCookie->ToHeader(Config_); + } else { + auto newCookie = New<TCypressCookie>(); + newCookie->Value = GenerateCookieValue(); + newCookie->User = user; + newCookie->PasswordRevision = passwordRevision; + newCookie->ExpiresAt = TInstant::Now() + Config_->CookieExpirationTimeout; + + YT_LOG_DEBUG("Issuing new cookie for renewal " + "(User: %v, CookieMD5: %v, PasswordRevision: %v, ExpiresAt: %v)", + user, + GetMD5HexDigestUpperCase(newCookie->Value), + passwordRevision, + newCookie->ExpiresAt); + + auto error = WaitFor(CookieStore_->RegisterCookie(newCookie)); + if (error.IsOK()) { + YT_LOG_DEBUG("Issued new cookie for renewal (User: %v, CookieMD5: %v)", + user, + GetMD5HexDigestUpperCase(newCookie->Value)); + result.SetCookie = newCookie->ToHeader(Config_); + } else { + // NB: Cookie creation failure should not lead to authentication error. 
+ YT_LOG_DEBUG(error, "Failed to issue new cookie for renewal (User: %v, CookieMD5: %v)", + user, + GetMD5HexDigestUpperCase(newCookie->Value)); + } + } + } + + std::optional<TString> setCookieMD5; + if (auto setCookie = result.SetCookie) { + setCookieMD5 = GetMD5HexDigestUpperCase(*setCookie); + } + + YT_LOG_DEBUG("User authenticated (User: %v, CookieMD5: %v, SetCookieMD5: %v)", + user, + GetMD5HexDigestUpperCase(cookie->Value), + setCookieMD5); + + return result; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +ICookieAuthenticatorPtr CreateCypressCookieAuthenticator( + TCypressCookieGeneratorConfigPtr config, + ICypressCookieStorePtr cookieStore, + IClientPtr client) +{ + return New<TCypressCookieAuthenticator>( + std::move(config), + std::move(cookieStore), + std::move(client)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_cookie_authenticator.h b/yt/yt/library/auth_server/cypress_cookie_authenticator.h new file mode 100644 index 0000000000..e911141cd6 --- /dev/null +++ b/yt/yt/library/auth_server/cypress_cookie_authenticator.h @@ -0,0 +1,18 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/api/public.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +ICookieAuthenticatorPtr CreateCypressCookieAuthenticator( + TCypressCookieGeneratorConfigPtr config, + ICypressCookieStorePtr cookieStore, + NApi::IClientPtr client); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_cookie_login.cpp b/yt/yt/library/auth_server/cypress_cookie_login.cpp new file mode 100644 index 0000000000..6bdb6b4dca --- /dev/null +++ b/yt/yt/library/auth_server/cypress_cookie_login.cpp @@ -0,0 +1,222 @@ +#include "cypress_cookie_login.h" + 
+#include "config.h" +#include "cypress_cookie_store.h" + +#include <yt/yt/client/api/client.h> + +#include <yt/yt/library/auth_server/private.h> + +#include <yt/yt/core/crypto/crypto.h> + +#include <yt/yt/core/http/helpers.h> + +#include <library/cpp/string_utils/base64/base64.h> + +namespace NYT::NAuth { + +using namespace NApi; +using namespace NConcurrency; +using namespace NCrypto; +using namespace NHttp; +using namespace NYPath; +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TCypressCookieLoginHandler + : public IHttpHandler +{ +public: + TCypressCookieLoginHandler( + TCypressCookieGeneratorConfigPtr config, + NApi::IClientPtr client, + ICypressCookieStorePtr cookieStore) + : Config_(std::move(config)) + , Client_(std::move(client)) + , CookieStore_(std::move(cookieStore)) + { } + + void HandleRequest( + const IRequestPtr& req, + const IResponseWriterPtr& rsp) override + { + if (auto header = req->GetHeaders()->Find(AuthorizationHeader)) { + HandleLoginRequest(*header, req, rsp); + } else { + HandleRegularRequest(rsp); + } + + WaitFor(rsp->Close()) + .ThrowOnError(); + } + +private: + const TCypressCookieGeneratorConfigPtr Config_; + + const NApi::IClientPtr Client_; + + const ICypressCookieStorePtr CookieStore_; + + constexpr static TStringBuf AuthorizationHeader = "Authorization"; + constexpr static TStringBuf SetCookieHedaer = "Set-Cookie"; + constexpr static TStringBuf BasicAuthorizationMethod = "Basic"; + + struct TUserInfo + { + TString HashedPassword; + TString PasswordSalt; + ui64 PasswordRevision; + }; + + void HandleLoginRequest( + TStringBuf authorizationHeader, + const IRequestPtr& req, + const IResponseWriterPtr& rsp) + { + auto replyAndLogError = [&] (const TError& error, const std::optional<TString>& user = {}) { + ReplyError(rsp, error); + 
YT_LOG_DEBUG(error, "Failed to login user using password (ConnectionId: %v, User: %v)", + req->GetConnectionId(), + user); + }; + + TStringBuf authorizationMethod; + TStringBuf encodedCredentials; + if (!authorizationHeader.TrySplit(' ', authorizationMethod, encodedCredentials)) { + rsp->SetStatus(EStatusCode::BadRequest); + + auto error = TError("Malformed \"Authorization\" header: failed to parse authorization method"); + replyAndLogError(error); + return; + } + + if (authorizationMethod != BasicAuthorizationMethod) { + rsp->SetStatus(EStatusCode::BadRequest); + + auto error = TError("Unsupported authorization method %Qlv", authorizationMethod); + replyAndLogError(error); + return; + } + + auto credentials = Base64StrictDecode(encodedCredentials); + TStringBuf user; + TStringBuf password; + if (!TStringBuf{credentials}.TrySplit(':', user, password)) { + rsp->SetStatus(EStatusCode::BadRequest); + + auto error = TError("Failed to parse user credentials"); + replyAndLogError(error); + return; + } + + TUserInfo userInfo; + try { + userInfo = FetchUserInfo(TString{user}); + } catch (const std::exception& ex) { + auto error = TError(ex); + if (error.FindMatching(NYTree::EErrorCode::ResolveError)) { + HandleRegularRequest(rsp); + + error = TError("No such user %Qlv or user has no password set", user) << error; + replyAndLogError(error, TString{user}); + return; + } + + // Unknown error, reply 500. 
+ error = TError("Failed to fetch info for user %Qlv during logging", user) << error; + replyAndLogError(error, TString{user}); + throw; + } + + if (HashPassword(TString{password}, userInfo.PasswordSalt) != userInfo.HashedPassword) { + HandleRegularRequest(rsp); + + auto error = TError("Invalid password"); + replyAndLogError(error, TString{user}); + return; + } + + auto cookie = New<TCypressCookie>(); + cookie->Value = GenerateCookieValue(); + cookie->User = user; + cookie->PasswordRevision = userInfo.PasswordRevision; + cookie->ExpiresAt = TInstant::Now() + Config_->CookieExpirationTimeout; + + auto error = WaitFor(CookieStore_->RegisterCookie(cookie)); + if (!error.IsOK()) { + error = TError("Failed to register cookie in cookie store"); + replyAndLogError(error, TString{user}); + // Will return 500. + error.ThrowOnError(); + } + + YT_LOG_DEBUG("Issued new cookie for user (User: %v, CookieMD5: %v)", + user, + GetMD5HexDigestUpperCase(cookie->Value)); + + if (const auto& redirectUrl = Config_->RedirectUrl) { + rsp->SetStatus(EStatusCode::PermanentRedirect); + rsp->GetHeaders()->Add("Location", *redirectUrl); + } else { + rsp->SetStatus(EStatusCode::OK); + } + + rsp->GetHeaders()->Add(TString{SetCookieHedaer}, cookie->ToHeader(Config_)); + } + + void HandleRegularRequest(const IResponseWriterPtr& rsp) + { + rsp->SetStatus(EStatusCode::Unauthorized); + + rsp->GetHeaders()->Add("WWW-Authenticate", "Basic"); + } + + TUserInfo FetchUserInfo(const TString& user) + { + auto path = Format("//sys/users/%v", ToYPathLiteral(user)); + + constexpr TStringBuf HashedPasswordAttribute = "hashed_password"; + constexpr TStringBuf PasswordSaltAttribute = "password_salt"; + constexpr TStringBuf PasswordRevisionAttribute = "password_revision"; + + TGetNodeOptions options; + options.Attributes = std::vector<TString>({ + TString{HashedPasswordAttribute}, + TString{PasswordSaltAttribute}, + TString{PasswordRevisionAttribute}, + }); + + auto rsp = WaitFor(Client_->GetNode(path, options)) + 
.ValueOrThrow(); + auto rspNode = ConvertToNode(rsp); + const auto& attributes = rspNode->Attributes(); + + return TUserInfo{ + .HashedPassword = attributes.Get<TString>(HashedPasswordAttribute), + .PasswordSalt = attributes.Get<TString>(PasswordSaltAttribute), + .PasswordRevision = attributes.Get<ui64>(PasswordRevisionAttribute), + }; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +IHttpHandlerPtr CreateCypressCookieLoginHandler( + TCypressCookieGeneratorConfigPtr config, + NApi::IClientPtr client, + ICypressCookieStorePtr cookieStore) +{ + return New<TCypressCookieLoginHandler>( + std::move(config), + std::move(client), + std::move(cookieStore)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_cookie_login.h b/yt/yt/library/auth_server/cypress_cookie_login.h new file mode 100644 index 0000000000..eaba6ee0b2 --- /dev/null +++ b/yt/yt/library/auth_server/cypress_cookie_login.h @@ -0,0 +1,20 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/api/public.h> + +#include <yt/yt/core/http/http.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +NHttp::IHttpHandlerPtr CreateCypressCookieLoginHandler( + TCypressCookieGeneratorConfigPtr config, + NApi::IClientPtr client, + ICypressCookieStorePtr cookieStore); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_cookie_manager.cpp b/yt/yt/library/auth_server/cypress_cookie_manager.cpp new file mode 100644 index 0000000000..c1b9fe862d --- /dev/null +++ b/yt/yt/library/auth_server/cypress_cookie_manager.cpp @@ -0,0 +1,81 @@ +#include "cypress_cookie_manager.h" + +#include "config.h" +#include "cookie_authenticator.h" +#include "cypress_cookie_authenticator.h" +#include 
"cypress_cookie_store.h" + +#include <yt/yt/core/rpc/dispatcher.h> + +namespace NYT::NAuth { + +using namespace NApi; +using namespace NProfiling; + +//////////////////////////////////////////////////////////////////////////////// + +class TCypressCookieManager + : public ICypressCookieManager +{ +public: + TCypressCookieManager( + TCypressCookieManagerConfigPtr config, + IClientPtr client, + TProfiler profiler) + : CookieStore_(CreateCypressCookieStore( + config->CookieStore, + client, + NRpc::TDispatcher::Get()->GetHeavyInvoker())) + , CookieAuthenticator_(CreateCypressCookieAuthenticator( + config->CookieGenerator, + CookieStore_, + client)) + , CachingCookieAuthenticator_(CreateCachingCookieAuthenticator( + config->CookieAuthenticator, + CookieAuthenticator_, + profiler.WithPrefix("/cypress_cookie_authenticator/cache"))) + { } + + void Start() + { + CookieStore_->Start(); + } + + void Stop() + { + CookieStore_->Stop(); + } + + const ICypressCookieStorePtr& GetCookieStore() const + { + return CookieStore_; + } + + const ICookieAuthenticatorPtr& GetCookieAuthenticator() const + { + return CachingCookieAuthenticator_; + } + +private: + const ICypressCookieStorePtr CookieStore_; + + const ICookieAuthenticatorPtr CookieAuthenticator_; + const ICookieAuthenticatorPtr CachingCookieAuthenticator_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +ICypressCookieManagerPtr CreateCypressCookieManager( + TCypressCookieManagerConfigPtr config, + IClientPtr client, + TProfiler profiler) +{ + return New<TCypressCookieManager>( + std::move(config), + std::move(client), + std::move(profiler)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_cookie_manager.h b/yt/yt/library/auth_server/cypress_cookie_manager.h new file mode 100644 index 0000000000..d24d068b5b --- /dev/null +++ 
b/yt/yt/library/auth_server/cypress_cookie_manager.h @@ -0,0 +1,37 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/api/public.h> + +#include <yt/yt/library/profiling/sensor.h> + +#include <yt/yt/core/http/public.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +struct ICypressCookieManager + : public TRefCounted +{ + virtual void Start() = 0; + virtual void Stop() = 0; + + virtual const ICypressCookieStorePtr& GetCookieStore() const = 0; + + virtual const ICookieAuthenticatorPtr& GetCookieAuthenticator() const = 0; +}; + +DEFINE_REFCOUNTED_TYPE(ICypressCookieManager) + +//////////////////////////////////////////////////////////////////////////////// + +ICypressCookieManagerPtr CreateCypressCookieManager( + TCypressCookieManagerConfigPtr config, + NApi::IClientPtr client, + NProfiling::TProfiler profiler); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_cookie_store.cpp b/yt/yt/library/auth_server/cypress_cookie_store.cpp new file mode 100644 index 0000000000..3027ff22fa --- /dev/null +++ b/yt/yt/library/auth_server/cypress_cookie_store.cpp @@ -0,0 +1,300 @@ +#include "cypress_cookie_store.h" + +#include "config.h" + +#include <yt/yt/library/auth_server/private.h> + +#include <yt/yt/client/api/client.h> + +#include <yt/yt/core/concurrency/periodic_executor.h> + +namespace NYT::NAuth { + +using namespace NApi; +using namespace NConcurrency; +using namespace NObjectClient; +using namespace NThreading; +using namespace NYson; +using namespace NYPath; +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TCypressCookieStore + : public ICypressCookieStore +{ +public: + TCypressCookieStore( + 
TCypressCookieStoreConfigPtr config, + IClientPtr client, + IInvokerPtr invoker) + : Config_(std::move(config)) + , Client_(std::move(client)) + , UpdateExecutor_(New<TPeriodicExecutor>( + std::move(invoker), + BIND(&TCypressCookieStore::DoFetchAllCookies, MakeWeak(this)), + Config_->FullFetchPeriod)) + { } + + void Start() override + { + UpdateExecutor_->Start(); + + YT_LOG_DEBUG("Starting periodic updates in native cookie store"); + } + + void Stop() override + { + YT_UNUSED_FUTURE(UpdateExecutor_->Stop()); + + YT_LOG_DEBUG("Stopping periodic updates in native cookie store"); + } + + TFuture<TCypressCookiePtr> GetCookie(const TString& value) override + { + { + auto guard = ReaderGuard(CookiesLock_); + auto cookieIt = Cookies_.find(value); + if (cookieIt != Cookies_.end()) { + auto entry = cookieIt->second; + if (IsEntryActual(entry)) { + return entry->CookieFuture; + } + } + } + + { + auto guard = WriterGuard(CookiesLock_); + auto cookieIt = Cookies_.find(value); + if (cookieIt != Cookies_.end()) { + auto entry = cookieIt->second; + // Double check. 
+ if (IsEntryActual(entry)) { + return entry->CookieFuture; + } else { + Cookies_.erase(cookieIt); + } + } + + auto cookieFuture = DoFetchCookie(value); + auto entry = New<TEntry>(); + + cookieFuture = cookieFuture + .Apply(BIND([=, this, this_ = MakeStrong(this)] (const TErrorOr<TCypressCookiePtr>& cookieOrError) { + entry->FetchTime = TInstant::Now(); + + if (cookieOrError.IsOK()) { + DoRegisterCookie(cookieOrError.Value()); + } + + return cookieOrError; + })).ToUncancelable(); + entry->CookieFuture = cookieFuture; + + EmplaceOrCrash(Cookies_, value, std::move(entry)); + + return cookieFuture; + } + } + + TCypressCookiePtr GetLastCookieForUser(const TString& user) override + { + auto guard = ReaderGuard(UserToLastCookieLock_); + auto userIt = UserToLastCookie_.find(user); + if (userIt == UserToLastCookie_.end()) { + return nullptr; + } else { + return userIt->second; + } + } + + void RemoveLastCookieForUser(const TString& user) override + { + auto guard = WriterGuard(UserToLastCookieLock_); + UserToLastCookie_.erase(user); + } + + TFuture<void> RegisterCookie(const TCypressCookiePtr& cookie) override + { + auto attributes = CreateEphemeralAttributes(); + attributes->Set("value", ConvertToYsonString(cookie)); + attributes->Set("expiration_time", ConvertToYsonString(cookie->ExpiresAt)); + + TCreateNodeOptions createOptions; + createOptions.Attributes = std::move(attributes); + + auto future = Client_->CreateNode( + GetCookiePath(cookie->Value), + EObjectType::Document, + createOptions); + return future.AsVoid().Apply(BIND([=, this, this_ = MakeStrong(this)] (const TError& error) { + if (error.IsOK()) { + auto entry = New<TEntry>(); + entry->CookieFuture = MakeFuture<TCypressCookiePtr>(cookie); + entry->FetchTime = TInstant::Now(); + + { + // NB: Technically it is possible that cookie is already in |Cookies_|. 
+ auto writeGuard = WriterGuard(CookiesLock_); + Cookies_[cookie->Value] = std::move(entry); + } + + DoRegisterCookie(cookie); + } + })); + } + +private: + const TCypressCookieStoreConfigPtr Config_; + + const IClientPtr Client_; + + const TPeriodicExecutorPtr UpdateExecutor_; + + struct TEntry final + { + TFuture<TCypressCookiePtr> CookieFuture; + + //! Time when this entry was fetched. + TInstant FetchTime; + }; + using TEntryPtr = TIntrusivePtr<TEntry>; + + THashMap<TString, TEntryPtr> Cookies_; + YT_DECLARE_SPIN_LOCK(TReaderWriterSpinLock, CookiesLock_); + + THashMap<TString, TCypressCookiePtr> UserToLastCookie_; + YT_DECLARE_SPIN_LOCK(TReaderWriterSpinLock, UserToLastCookieLock_); + + bool IsEntryActual(const TEntryPtr& entry) + { + // Cookie info is not fetched yet, so information cannot be stale. + if (!entry->CookieFuture.IsSet()) { + return true; + } + + // Successes are stored forever and errors are cached for a configured time. + return + entry->CookieFuture.Get().IsOK() || + entry->FetchTime + Config_->ErrorEvictionTime > TInstant::Now(); + } + + TFuture<TCypressCookiePtr> DoFetchCookie(const TString& value) + { + YT_LOG_DEBUG("Fetching cookie from Cypress (Cookie: %v)", + value); + + return Client_->GetNode(GetCookiePath(value)) + .Apply(BIND([=, this_ = MakeStrong(this)] (const TYsonString& value) { + return ConvertTo<TCypressCookiePtr>(value); + })); + } + + void DoFetchAllCookies() + { + try { + GuardedFetchAllCookies(); + } catch (const std::exception& ex) { + YT_LOG_WARNING(ex, "Failed to fetch native cookies from Cypress"); + } + } + + void GuardedFetchAllCookies() + { + YT_LOG_DEBUG("Started fetching native cookies"); + + constexpr TStringBuf ValueAttribute = "value"; + + TListNodeOptions listOptions{ + .Attributes = std::vector<TString>({TString{ValueAttribute}}), + }; + listOptions.ReadFrom = EMasterChannelKind::Cache; + + auto rawListResult = WaitFor(Client_->ListNode( + "//sys/cypress_cookies", + listOptions)) + .ValueOrThrow(); + auto 
listResult = ConvertTo<IListNodePtr>(rawListResult); + + YT_LOG_DEBUG("Native cookies fetched from Cypress (CookieCount: %v)", + listResult->GetChildCount()); + + THashMap<TString, TEntryPtr> newCookies; + THashMap<TString, TCypressCookiePtr> newUserToLastCookie; + for (const auto& child : listResult->GetChildren()) { + try { + auto cookie = child->Attributes().Get<TCypressCookiePtr>(ValueAttribute); + + auto entry = New<TEntry>(); + entry->CookieFuture = MakeFuture<TCypressCookiePtr>(cookie); + entry->FetchTime = TInstant::Now(); + EmplaceOrCrash(newCookies, cookie->Value, entry); + + const auto& user = cookie->User; + auto userIt = newUserToLastCookie.find(user); + if (userIt == newUserToLastCookie.end()) { + newUserToLastCookie[user] = cookie; + } else { + const auto& lastCookie = userIt->second; + if (cookie->ExpiresAt > lastCookie->ExpiresAt) { + newUserToLastCookie[user] = cookie; + } + } + } catch (const std::exception& ex) { + YT_LOG_WARNING(ex, "Failed to parse cookie (Cookie: %v)", + child->GetValue<TString>()); + } + } + + { + auto guard = WriterGuard(CookiesLock_); + Cookies_ = std::move(newCookies); + } + { + auto guard = WriterGuard(UserToLastCookieLock_); + UserToLastCookie_ = std::move(newUserToLastCookie); + } + } + + void DoRegisterCookie(const TCypressCookiePtr& cookie) + { + auto guard = WriterGuard(UserToLastCookieLock_); + + const auto& user = cookie->User; + auto userIt = UserToLastCookie_.find(user); + if (userIt == UserToLastCookie_.end()) { + UserToLastCookie_[user] = cookie; + } else { + const auto& lastCookie = userIt->second; + if (cookie->ExpiresAt > lastCookie->ExpiresAt) { + UserToLastCookie_[user] = cookie; + } + } + } + + static TYPath GetCookiePath(const TString& value) + { + return Format("//sys/cypress_cookies/%v", ToYPathLiteral(value)); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +ICypressCookieStorePtr CreateCypressCookieStore( + TCypressCookieStoreConfigPtr config, + 
IClientPtr client, + IInvokerPtr invoker) +{ + return New<TCypressCookieStore>( + std::move(config), + std::move(client), + std::move(invoker)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_cookie_store.h b/yt/yt/library/auth_server/cypress_cookie_store.h new file mode 100644 index 0000000000..a959bf3599 --- /dev/null +++ b/yt/yt/library/auth_server/cypress_cookie_store.h @@ -0,0 +1,50 @@ +#pragma once + +#include "cypress_cookie.h" + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +//! This class stores user session cookies locally and periodically +//! synchronizes them with Cypress. +/* + * Thread affinity: any + */ +struct ICypressCookieStore + : public TRefCounted +{ + //! Starts periodic cookie fetch. + virtual void Start() = 0; + + //! Stops periodic cookie fetch. + virtual void Stop() = 0; + + //! Finds cookie description by value. If cookie with given value is known, + //! returns its description. Otherwise, tries to fetch cookie from Cypress. + virtual TFuture<TCypressCookiePtr> GetCookie(const TString& value) = 0; + + //! Returns known cookie for given user with maximum |ExpiresAt|. + //! If no cookies for user are known, returns |nullptr|. + virtual TCypressCookiePtr GetLastCookieForUser(const TString& user) = 0; + + //! Invalidates last cookie for user. + virtual void RemoveLastCookieForUser(const TString& user) = 0; + + //! Registers cookie in Cypress. If registration is successful, also stores + //! cookie locally. 
+ virtual TFuture<void> RegisterCookie(const TCypressCookiePtr& cookie) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(ICypressCookieStore) + +//////////////////////////////////////////////////////////////////////////////// + +ICypressCookieStorePtr CreateCypressCookieStore( + TCypressCookieStoreConfigPtr config, + NApi::IClientPtr client, + IInvokerPtr invoker); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_token_authenticator.cpp b/yt/yt/library/auth_server/cypress_token_authenticator.cpp new file mode 100644 index 0000000000..364928427f --- /dev/null +++ b/yt/yt/library/auth_server/cypress_token_authenticator.cpp @@ -0,0 +1,88 @@ +#include "cypress_token_authenticator.h" + +#include "token_authenticator.h" +#include "private.h" + +#include <yt/yt/client/api/client.h> + +#include <yt/yt/core/crypto/crypto.h> + +namespace NYT::NAuth { + +using namespace NApi; +using namespace NConcurrency; +using namespace NCrypto; +using namespace NYPath; +using namespace NYTree; +using namespace NYson; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TCypressTokenAuthenticator + : public ITokenAuthenticator +{ +public: + explicit TCypressTokenAuthenticator(IClientPtr client) + : Client_(std::move(client)) + { } + + TFuture<TAuthenticationResult> Authenticate( + const TTokenCredentials& credentials) override + { + const auto& token = credentials.Token; + const auto& userIP = credentials.UserIP; + auto tokenHash = GetSha256HexDigestLowerCase(token); + YT_LOG_DEBUG("Authenticating user with Cypress token (TokenHash: %v, UserIP: %v)", + tokenHash, + userIP); + + auto path = Format("//sys/cypress_tokens/%v/@user", ToYPathLiteral(tokenHash)); + return Client_->GetNode(path, /*options*/ {}) + .Apply(BIND( + 
&TCypressTokenAuthenticator::OnCallResult, + MakeStrong(this), + std::move(tokenHash))); + } + +private: + const IClientPtr Client_; + + TAuthenticationResult OnCallResult(const TString& tokenHash, const TErrorOr<TYsonString>& rspOrError) + { + if (!rspOrError.IsOK()) { + if (rspOrError.FindMatching(NYTree::EErrorCode::ResolveError)) { + YT_LOG_DEBUG(rspOrError, "Token is missing in Cypress (TokenHash: %v)", + tokenHash); + THROW_ERROR_EXCEPTION("Token is missing in Cypress") + << TErrorAttribute("token_hash", tokenHash) + << rspOrError; + } else { + YT_LOG_DEBUG(rspOrError, "Cypress authentication failed (TokenHash: %v)", + tokenHash); + THROW_ERROR_EXCEPTION("Cypress authentication failed") + << TErrorAttribute("token_hash", tokenHash) + << rspOrError; + } + } + + const auto& rsp = rspOrError.Value(); + return TAuthenticationResult{ + .Login = ConvertTo<TString>(rsp), + }; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +ITokenAuthenticatorPtr CreateCypressTokenAuthenticator(IClientPtr client) +{ + return New<TCypressTokenAuthenticator>(std::move(client)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_token_authenticator.h b/yt/yt/library/auth_server/cypress_token_authenticator.h new file mode 100644 index 0000000000..0ab5f09920 --- /dev/null +++ b/yt/yt/library/auth_server/cypress_token_authenticator.h @@ -0,0 +1,15 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/api/public.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +ITokenAuthenticatorPtr CreateCypressTokenAuthenticator(NApi::IClientPtr client); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_user_manager.cpp 
b/yt/yt/library/auth_server/cypress_user_manager.cpp new file mode 100644 index 0000000000..6ede6ee179 --- /dev/null +++ b/yt/yt/library/auth_server/cypress_user_manager.cpp @@ -0,0 +1,108 @@ +#include "cypress_user_manager.h" + +#include "auth_cache.h" +#include "private.h" + +#include <yt/yt/client/api/client.h> + +#include <library/cpp/yt/logging/logger.h> + +namespace NYT::NAuth { + +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TCypressUserManager + : public ICypressUserManager +{ +public: + TCypressUserManager( + TCypressUserManagerConfigPtr config, + NApi::IClientPtr client) + : Config_(std::move(config)) + , Client_(std::move(client)) + { } + + TFuture<NObjectClient::TObjectId> CreateUser(const TString& name) override + { + YT_LOG_DEBUG("Creating user object (Name: %v)", name); + NApi::TCreateObjectOptions options; + options.IgnoreExisting = true; + + auto attributes = CreateEphemeralAttributes(); + attributes->Set("name", name); + options.Attributes = std::move(attributes); + + return Client_->CreateObject( + NObjectClient::EObjectType::User, + options); + } + +private: + const TCypressUserManagerConfigPtr Config_; + const NApi::IClientPtr Client_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +ICypressUserManagerPtr CreateCypressUserManager( + TCypressUserManagerConfigPtr config, + NApi::IClientPtr client) +{ + return New<TCypressUserManager>( + std::move(config), + std::move(client)); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TCachingCypressUserManager + : public ICypressUserManager + , public TAuthCache<TString, NObjectClient::TObjectId, std::monostate> +{ +public: + TCachingCypressUserManager( + TCachingCypressUserManagerConfigPtr config, + 
ICypressUserManagerPtr CypressUserManager, + NProfiling::TProfiler profiler) + : TAuthCache(config->Cache, std::move(profiler)) + , CypressUserManager_(std::move(CypressUserManager)) + { } + + TFuture<NObjectClient::TObjectId> CreateUser(const TString& name) override + { + return Get(name, std::monostate{}); + } + +private: + const ICypressUserManagerPtr CypressUserManager_; + + TFuture<NObjectClient::TObjectId> DoGet( + const TString& name, + const std::monostate& /*context*/) noexcept override + { + return CypressUserManager_->CreateUser(name); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +ICypressUserManagerPtr CreateCachingCypressUserManager( + TCachingCypressUserManagerConfigPtr config, + ICypressUserManagerPtr userManager, + NProfiling::TProfiler profiler) +{ + return New<TCachingCypressUserManager>( + std::move(config), + std::move(userManager), + std::move(profiler)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/cypress_user_manager.h b/yt/yt/library/auth_server/cypress_user_manager.h new file mode 100644 index 0000000000..79fc4cbb7b --- /dev/null +++ b/yt/yt/library/auth_server/cypress_user_manager.h @@ -0,0 +1,36 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/api/public.h> + +#include <yt/yt/library/profiling/sensor.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +struct ICypressUserManager + : public virtual TRefCounted +{ + virtual TFuture<NObjectClient::TObjectId> CreateUser(const TString& name) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(ICypressUserManager) + +//////////////////////////////////////////////////////////////////////////////// + +ICypressUserManagerPtr CreateCypressUserManager( + TCypressUserManagerConfigPtr config, + NApi::IClientPtr client); + 
+//////////////////////////////////////////////////////////////////////////////// + +ICypressUserManagerPtr CreateCachingCypressUserManager( + TCachingCypressUserManagerConfigPtr config, + ICypressUserManagerPtr userManager, + NProfiling::TProfiler profiler); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/default_secret_vault_service.cpp b/yt/yt/library/auth_server/default_secret_vault_service.cpp new file mode 100644 index 0000000000..e7b8e63228 --- /dev/null +++ b/yt/yt/library/auth_server/default_secret_vault_service.cpp @@ -0,0 +1,456 @@ +#include "default_secret_vault_service.h" +#include "secret_vault_service.h" + +#include "config.h" +#include "private.h" + +#include <yt/yt/core/http/client.h> +#include <yt/yt/core/http/helpers.h> +#include <yt/yt/core/http/http.h> + +#include <yt/yt/core/https/client.h> +#include <yt/yt/core/https/config.h> + +#include <yt/yt/core/json/json_parser.h> +#include <yt/yt/core/json/json_writer.h> + +#include <yt/yt/core/rpc/dispatcher.h> + +#include <yt/yt/core/ytree/ephemeral_node_factory.h> +#include <yt/yt/core/ytree/fluent.h> +#include <yt/yt/core/ytree/tree_builder.h> + +#include <yt/yt/core/profiling/timing.h> + +#include <yt/yt/library/tvm/service/tvm_service.h> + +#include <library/cpp/uri/encode.h> + +namespace NYT::NAuth { + +using namespace NConcurrency; +using namespace NHttp; +using namespace NJson; +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(ESecretVaultResponseStatus, + ((Unknown) (0)) + ((OK) (1)) + ((Warning) (2)) + ((Error) (3)) +); + +//////////////////////////////////////////////////////////////////////////////// + +class TDefaultSecretVaultService + : public ISecretVaultService +{ +public: + 
TDefaultSecretVaultService( + TDefaultSecretVaultServiceConfigPtr config, + ITvmServicePtr tvmService, + IPollerPtr poller, + NProfiling::TProfiler profiler) + : Config_(std::move(config)) + , TvmService_(std::move(tvmService)) + , HttpClient_(Config_->Secure + ? NHttps::CreateClient(Config_->HttpClient, std::move(poller)) + : NHttp::CreateClient(Config_->HttpClient, std::move(poller))) + , SubrequestsPerCallGauge_(profiler.Gauge("/subrequests_per_call")) + , CallCountCounter_(profiler.Counter("/call_count")) + , SubrequestCountCounter_(profiler.Counter("/subrequest_count")) + , CallTimer_(profiler.Timer("/call_time")) + , SuccessfulCallCountCounter_(profiler.Counter("/successful_call_count")) + , FailedCallCountCounter_(profiler.Counter("/failed_call_count")) + , SuccessfulSubrequestCountCounter_(profiler.Counter("/successful_subrequest_count")) + , WarningSubrequestCountCounter_(profiler.Counter("/warning_subrequest_count")) + , FailedSubrequestCountCounter_(profiler.Counter("/failed_subrequest_count")) + { } + + TFuture<std::vector<TErrorOrSecretSubresponse>> GetSecrets( + const std::vector<TSecretSubrequest>& subrequests) override + { + return BIND(&TDefaultSecretVaultService::DoGetSecrets, MakeStrong(this), subrequests) + .AsyncVia(NRpc::TDispatcher::Get()->GetLightInvoker()) + .Run(); + } + + TFuture<TString> GetDelegationToken(TDelegationTokenRequest request) override + { + if (request.Signature.empty() || request.SecretId.empty() || request.UserTicket.empty()) { + return MakeFuture<TString>(TError( + "Invalid call for delegation token with signature %Qv, secret id %Qv " + "and user ticket length %v", + request.Signature, + request.SecretId, + request.UserTicket.size())); + } + + return BIND(&TDefaultSecretVaultService::DoGetDelegationToken, + MakeStrong(this), + std::move(request)) + .AsyncVia(NRpc::TDispatcher::Get()->GetLightInvoker()) + .Run(); + } + +private: + const TDefaultSecretVaultServiceConfigPtr Config_; + const ITvmServicePtr TvmService_; + + 
const NHttp::IClientPtr HttpClient_; + + NProfiling::TGauge SubrequestsPerCallGauge_; + NProfiling::TCounter CallCountCounter_; + NProfiling::TCounter SubrequestCountCounter_; + NProfiling::TEventTimer CallTimer_; + NProfiling::TCounter SuccessfulCallCountCounter_; + NProfiling::TCounter FailedCallCountCounter_; + NProfiling::TCounter SuccessfulSubrequestCountCounter_; + NProfiling::TCounter WarningSubrequestCountCounter_; + NProfiling::TCounter FailedSubrequestCountCounter_; + +private: + std::vector<TErrorOrSecretSubresponse> DoGetSecrets( + const std::vector<TSecretSubrequest>& subrequests) + { + const auto callId = TGuid::Create(); + + YT_LOG_DEBUG("Retrieving secrets from Vault (Count: %v, CallId: %v)", + subrequests.size(), + callId); + + CallCountCounter_.Increment(); + SubrequestCountCounter_.Increment(subrequests.size()); + SubrequestsPerCallGauge_.Update(subrequests.size()); + + try { + const auto url = MakeRequestUrl("/1/tokens/", true); + const auto headers = New<THeaders>(); + headers->Add("Content-Type", "application/json"); + + const auto vaultTicket = TvmService_->GetServiceTicket(Config_->VaultServiceId); + const auto body = MakeGetSecretsRequestBody(vaultTicket, subrequests); + + const auto responseBody = HttpPost(url, body, headers); + const auto response = ParseVaultResponse(responseBody); + + auto responseStatusString = GetStatusStringFromResponse(response); + auto responseStatus = ParseStatus(responseStatusString); + if (responseStatus == ESecretVaultResponseStatus::Error) { + THROW_ERROR GetErrorFromResponse(response, responseStatusString); + } + if (responseStatus != ESecretVaultResponseStatus::OK) { + // NB! Vault API is not supposed to return other statuses (e.g. warning) at the top-level. 
+ THROW_ERROR MakeUnexpectedStatusError(responseStatusString); + } + + std::vector<TErrorOrSecretSubresponse> subresponses; + + auto secretsNode = response->GetChildOrThrow("secrets")->AsList(); + + int successCount = 0; + int warningCount = 0; + int errorCount = 0; + auto secretNodes = secretsNode->GetChildren(); + for (size_t subresponseIndex = 0; subresponseIndex < secretNodes.size(); ++subresponseIndex) { + auto secretMapNode = secretNodes[subresponseIndex]->AsMap(); + + auto subresponseStatusString = GetStatusStringFromResponse(secretMapNode); + auto subresponseStatus = ParseStatus(subresponseStatusString); + if (subresponseStatus == ESecretVaultResponseStatus::OK) { + ++successCount; + } else if (subresponseStatus == ESecretVaultResponseStatus::Warning) { + // NB! Warning status is supposed to contain valid data so we proceed parsing the response. + ++warningCount; + auto warningMessage = GetWarningMessageFromResponse(secretMapNode); + YT_LOG_DEBUG( + "Received warning status in subresponse from Vault " + "(CallId: %v, SubresponseIndex: %v, WarningMessage: %v)", + callId, + subresponseIndex, + warningMessage); + } else if (subresponseStatus == ESecretVaultResponseStatus::Error) { + subresponses.push_back(GetErrorFromResponse( + secretMapNode, + subresponseStatusString)); + ++errorCount; + continue; + } else { + subresponses.push_back(MakeUnexpectedStatusError(subresponseStatusString)); + ++errorCount; + continue; + } + + TSecretSubresponse subresponse; + auto valueNode = secretMapNode->GetChildOrThrow("value")->AsList(); + for (const auto& fieldNode : valueNode->GetChildren()) { + auto fieldMapNode = fieldNode->AsMap(); + auto encodingNode = fieldMapNode->FindChild("encoding"); + TString encoding = encodingNode ? 
encodingNode->GetValue<TString>() : ""; + subresponse.Values.emplace_back(TSecretValue{ + fieldMapNode->GetChildValueOrThrow<TString>("key"), + fieldMapNode->GetChildValueOrThrow<TString>("value"), + encoding}); + } + + subresponses.push_back(subresponse); + } + + SuccessfulCallCountCounter_.Increment(); + SuccessfulSubrequestCountCounter_.Increment(successCount); + WarningSubrequestCountCounter_.Increment(warningCount); + FailedSubrequestCountCounter_.Increment(errorCount); + + YT_LOG_DEBUG( + "Secrets retrieved from Vault " + "(CallId: %v, SuccessCount: %v, WarningCount: %v, ErrorCount: %v)", + callId, + successCount, + warningCount, + errorCount); + return subresponses; + } catch (const std::exception& ex) { + FailedCallCountCounter_.Increment(); + auto error = TError("Failed to get secrets from Vault") + << ex + << TErrorAttribute("call_id", callId); + YT_LOG_DEBUG(error); + THROW_ERROR error; + } + } + + TString DoGetDelegationToken(TDelegationTokenRequest request) + { + const auto callId = TGuid::Create(); + + YT_LOG_DEBUG( + "Retrieving delegation token from Vault " + "(SecretId: %v, Signature: %v, UserTicket: %v, CallId: %v)", + request.SecretId, + request.Signature, // signatures are not secret; tokens are + RemoveTicketSignature(request.UserTicket), + callId); + + CallCountCounter_.Increment(); + + try { + const auto url = MakeRequestUrl(Format("/1/secrets/%v/tokens/", request.SecretId), false); + const auto headers = New<THeaders>(); + const auto vaultTicket = TvmService_->GetServiceTicket(Config_->VaultServiceId); + headers->Add("Content-Type", "application/json"); + headers->Add("X-Ya-User-Ticket", request.UserTicket); + headers->Add("X-Ya-Service-Ticket", vaultTicket); + const auto body = MakeGetDelegationTokenRequestBody(request); + + const auto responseBody = HttpPost(url, body, headers); + const auto response = ParseVaultResponse(responseBody); + + auto responseStatusString = GetStatusStringFromResponse(response); + auto responseStatus = 
ParseStatus(responseStatusString); + if (responseStatus == ESecretVaultResponseStatus::Error) { + THROW_ERROR GetErrorFromResponse(response, responseStatusString); + } + if (responseStatus == ESecretVaultResponseStatus::Unknown) { + THROW_ERROR MakeUnexpectedStatusError(responseStatusString); + } + if (responseStatus == ESecretVaultResponseStatus::Warning) { + WarningSubrequestCountCounter_.Increment(); + YT_LOG_WARNING("Received warning message from Vault: %v", + GetWarningMessageFromResponse(response)); + } + + return response->GetChildValueOrThrow<TString>("token"); + } catch (const std::exception& ex) { + FailedCallCountCounter_.Increment(); + auto error = TError("Failed to get delegation token from Vault") + << ex + << TErrorAttribute("call_id", callId); + YT_LOG_DEBUG(error); + THROW_ERROR error; + } + } + + TString MakeRequestUrl(TStringBuf path, bool addConsumer) const + { + auto url = Format("%v://%v:%v%v", + Config_->Secure ? "https" : "http", + Config_->Host, + Config_->Port, + path); + if (addConsumer && !Config_->Consumer.empty()) { + url = Format("%v?consumer=%v", url, Config_->Consumer); + } + return url; + } + + TSharedRef MakeGetSecretsRequestBody( + const TString& vaultTicket, + const std::vector<TSecretSubrequest>& subrequests) + { + TString body; + TStringOutput stream(body); + auto jsonWriter = CreateJsonConsumer(&stream); + BuildYsonFluently(jsonWriter.get()) + .BeginMap() + .Item("tokenized_requests").DoListFor(subrequests, + [&] (auto fluent, const auto& subrequest) { + auto map = fluent.Item().BeginMap(); + if (!vaultTicket.empty()) { + map.Item("service_ticket").Value(vaultTicket); + } + if (!subrequest.DelegationToken.empty()) { + map.Item("token").Value(subrequest.DelegationToken); + } + if (!subrequest.Signature.empty()) { + map.Item("signature").Value(subrequest.Signature); + } + if (!subrequest.SecretId.empty()) { + map.Item("secret_uuid").Value(subrequest.SecretId); + } + if (!subrequest.SecretVersion.empty()) { + 
map.Item("secret_version").Value(subrequest.SecretVersion); + } + map.EndMap(); + }) + .EndMap(); + jsonWriter->Flush(); + return TSharedRef::FromString(std::move(body)); + } + + TSharedRef MakeGetDelegationTokenRequestBody(const TDelegationTokenRequest& request) + { + TString body; + TStringOutput stream(body); + auto jsonWriter = CreateJsonConsumer(&stream); + BuildYsonFluently(jsonWriter.get()) + .BeginMap() + .Item("signature").Value(request.Signature) + .Item("tvm_client_id").Value(TvmService_->GetSelfTvmId()) + .DoIf(!request.Comment.empty(), + [&] (auto fluent) { + fluent.Item("comment").Value(request.Comment); + }) + .EndMap(); + jsonWriter->Flush(); + return TSharedRef::FromString(std::move(body)); + } + + TSharedRef HttpPost( + const TString& url, + const TSharedRef& body, + const THeadersPtr& headers) + { + NProfiling::TWallTimer timer; + auto rspOrError = WaitFor(HttpClient_->Post(url, body, headers) + .WithTimeout(Config_->RequestTimeout)); + CallTimer_.Record(timer.GetElapsedTime()); + + THROW_ERROR_EXCEPTION_IF_FAILED(rspOrError, "Vault call failed"); + + const auto& rsp = rspOrError.Value(); + if (rsp->GetStatusCode() != EStatusCode::OK) { + THROW_ERROR_EXCEPTION("Vault call returned HTTP status code %v, response %v", + static_cast<int>(rsp->GetStatusCode()), + rsp->ReadAll()); + } + + return rsp->ReadAll(); + } + + static IMapNodePtr ParseVaultResponse(const TSharedRef& body) + { + try { + TMemoryInput stream(body.Begin(), body.Size()); + auto builder = CreateBuilderFromFactory(GetEphemeralNodeFactory()); + auto jsonConfig = New<TJsonFormatConfig>(); + jsonConfig->EncodeUtf8 = false; + ParseJson(&stream, builder.get(), jsonConfig); + return builder->EndTree()->AsMap(); + } catch (const std::exception& ex) { + THROW_ERROR TError(ESecretVaultErrorCode::MalformedResponse, + "Error parsing Vault response"); + } + } + + static TString GetStatusStringFromResponse(const IMapNodePtr& node) + { + return node->GetChildValueOrThrow<TString>("status"); + } + + 
static ESecretVaultResponseStatus ParseStatus(const TString& statusString) + { + if (statusString == "ok") { + return ESecretVaultResponseStatus::OK; + } else if (statusString == "warning") { + return ESecretVaultResponseStatus::Warning; + } else if (statusString == "error") { + return ESecretVaultResponseStatus::Error; + } else { + return ESecretVaultResponseStatus::Unknown; + } + } + + static TError GetErrorFromResponse(const IMapNodePtr& node, const TString& statusString) + { + auto codeString = node->GetChildValueOrThrow<TString>("code"); + auto code = ParseErrorCode(codeString); + + auto messageNode = node->FindChild("message"); + return TError( + code, + messageNode ? messageNode->GetValue<TString>() : "Vault error") + << TErrorAttribute("status", statusString) + << TErrorAttribute("code", codeString); + } + + static ESecretVaultErrorCode ParseErrorCode(TStringBuf codeString) + { + // https://vault-api.passport.yandex.net/docs/#api + if (codeString == "nonexistent_entity_error") { + return ESecretVaultErrorCode::NonexistentEntityError; + } else if (codeString == "delegation_access_error") { + return ESecretVaultErrorCode::DelegationAccessError; + } else if (codeString == "delegation_token_revoked") { + return ESecretVaultErrorCode::DelegationTokenRevoked; + } else { + return ESecretVaultErrorCode::UnknownError; + } + } + + static TError MakeUnexpectedStatusError(const TString& statusString) + { + return TError( + ESecretVaultErrorCode::UnexpectedStatus, + "Received unexpected status from Vault") + << TErrorAttribute("status", statusString); + } + + static TString GetWarningMessageFromResponse(const IMapNodePtr& node) + { + auto warningMessageNode = node->FindChild("warning_message"); + return warningMessageNode ? 
warningMessageNode->GetValue<TString>() : "Vault warning"; + } +}; // TDefaultSecretVaultService + +ISecretVaultServicePtr CreateDefaultSecretVaultService( + TDefaultSecretVaultServiceConfigPtr config, + ITvmServicePtr tvmService, + IPollerPtr poller, + NProfiling::TProfiler profiler) +{ + return New<TDefaultSecretVaultService>( + std::move(config), + std::move(tvmService), + std::move(poller), + std::move(profiler)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/default_secret_vault_service.h b/yt/yt/library/auth_server/default_secret_vault_service.h new file mode 100644 index 0000000000..a054ae48ee --- /dev/null +++ b/yt/yt/library/auth_server/default_secret_vault_service.h @@ -0,0 +1,21 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/concurrency/public.h> + +#include <yt/yt/library/profiling/sensor.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +ISecretVaultServicePtr CreateDefaultSecretVaultService( + TDefaultSecretVaultServiceConfigPtr config, + ITvmServicePtr tvmService, + NConcurrency::IPollerPtr poller, + NProfiling::TProfiler profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/dummy_secret_vault_service.cpp b/yt/yt/library/auth_server/dummy_secret_vault_service.cpp new file mode 100644 index 0000000000..7f12344a5b --- /dev/null +++ b/yt/yt/library/auth_server/dummy_secret_vault_service.cpp @@ -0,0 +1,35 @@ +#include "dummy_secret_vault_service.h" +#include "secret_vault_service.h" + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +class TDummySecretVaultService + : public ISecretVaultService +{ +public: + TFuture<std::vector<TErrorOrSecretSubresponse>> GetSecrets( + const 
std::vector<TSecretSubrequest>& subrequests) override + { + std::vector<TErrorOrSecretSubresponse> results; + for (size_t index = 0; index < subrequests.size(); ++index) { + results.push_back(TError("Secret Vault is not configured")); + } + return MakeFuture(std::move(results)); + } + + TFuture<TString> GetDelegationToken(TDelegationTokenRequest /*request*/) override + { + return MakeFuture<TString>(TError("Secret Vault is not configured")); + } +}; + +ISecretVaultServicePtr CreateDummySecretVaultService() +{ + return New<TDummySecretVaultService>(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/dummy_secret_vault_service.h b/yt/yt/library/auth_server/dummy_secret_vault_service.h new file mode 100644 index 0000000000..e03164d29e --- /dev/null +++ b/yt/yt/library/auth_server/dummy_secret_vault_service.h @@ -0,0 +1,13 @@ +#pragma once + +#include "public.h" + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +ISecretVaultServicePtr CreateDummySecretVaultService(); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/helpers.cpp b/yt/yt/library/auth_server/helpers.cpp new file mode 100644 index 0000000000..71eed055f0 --- /dev/null +++ b/yt/yt/library/auth_server/helpers.cpp @@ -0,0 +1,170 @@ +#include "helpers.h" + +#include <yt/yt/core/crypto/crypto.h> + +#include <yt/yt/core/ytree/fluent.h> + +#include <yt/yt_proto/yt/core/rpc/proto/rpc.pb.h> + +#include <library/cpp/string_utils/quote/quote.h> +#include <library/cpp/string_utils/url/url.h> + +#include <util/string/split.h> + +namespace NYT::NAuth { + +using namespace NCrypto; +using namespace NYson; +using namespace NYTree; +using namespace NRpc::NProto; + +//////////////////////////////////////////////////////////////////////////////// + +TString 
GetCryptoHash(TStringBuf secret) +{ + return NCrypto::TSha1Hasher() + .Append(secret) + .GetHexDigestLowerCase(); +} + +TString FormatUserIP(const NNet::TNetworkAddress& address) +{ + if (!address.IsIP()) { + // Sometimes userIP is missing (e.g. user is connecting + // from job using unix socket), but it is required by + // blackbox. Put placeholder in place of a real IP. + static const TString LocalUserIP = "127.0.0.1"; + return LocalUserIP; + } + return ToString( + address, + NNet::TNetworkAddressFormatOptions{ + .IncludePort = false, + .IncludeTcpProtocol = false + }); +} + +TString GetLoginForTvmId(TTvmId tvmId) +{ + return "tvm:" + ToString(tvmId); +} + +//////////////////////////////////////////////////////////////////////////////// + +static const THashSet<TString> PrivateUrlParams{ + "userip", + "oauth_token", + "sessionid", + "sslsessionid", + "user_ticket", +}; + +void TSafeUrlBuilder::AppendString(TStringBuf str) +{ + RealUrl_.AppendString(str); + SafeUrl_.AppendString(str); +} + +void TSafeUrlBuilder::AppendChar(char ch) +{ + RealUrl_.AppendChar(ch); + SafeUrl_.AppendChar(ch); +} + +void TSafeUrlBuilder::AppendParam(TStringBuf key, TStringBuf value) +{ + auto size = key.length() + 4 + CgiEscapeBufLen(value.length()); + + char* realBegin = RealUrl_.Preallocate(size); + char* realIt = realBegin; + memcpy(realIt, key.data(), key.length()); + realIt += key.length(); + *realIt = '='; + realIt += 1; + auto realEnd = CGIEscape(realIt, value.data(), value.length()); + RealUrl_.Advance(realEnd - realBegin); + + char* safeBegin = SafeUrl_.Preallocate(size); + char* safeEnd = safeBegin; + if (PrivateUrlParams.contains(key)) { + memcpy(safeEnd, realBegin, realIt - realBegin); + safeEnd += realIt - realBegin; + memcpy(safeEnd, "***", 3); + safeEnd += 3; + } else { + memcpy(safeEnd, realBegin, realEnd - realBegin); + safeEnd += realEnd - realBegin; + } + SafeUrl_.Advance(safeEnd - safeBegin); +} + +TString TSafeUrlBuilder::FlushRealUrl() +{ + return RealUrl_.Flush(); 
+} + +TString TSafeUrlBuilder::FlushSafeUrl() +{ + return SafeUrl_.Flush(); +} + +//////////////////////////////////////////////////////////////////////////////// + +THashedCredentials HashCredentials(const NRpc::NProto::TCredentialsExt& credentialsExt) +{ + THashedCredentials result; + if (credentialsExt.has_token()) { + result.TokenHash = GetCryptoHash(credentialsExt.token()); + } + return result; +} + +void Serialize(const THashedCredentials& hashedCredentials, IYsonConsumer* consumer) +{ + BuildYsonFluently(consumer) + .BeginMap() + .OptionalItem("token_hash", hashedCredentials.TokenHash) + .EndMap(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TString SignCsrfToken(const TString& userId, const TString& key, TInstant now) +{ + auto msg = userId + ":" + ToString(now.TimeT()); + return CreateSha256Hmac(key, msg) + ":" + ToString(now.TimeT()); +} + +TError CheckCsrfToken( + const TString& csrfToken, + const TString& userId, + const TString& key, + TInstant expirationTime) +{ + std::vector<TString> parts; + StringSplitter(csrfToken).Split(':').AddTo(&parts); + if (parts.size() != 2) { + return TError("Malformed CSRF token"); + } + + auto signTime = TInstant::Seconds(FromString<time_t>(parts[1])); + if (signTime < expirationTime) { + return TError(NRpc::EErrorCode::InvalidCsrfToken, "CSRF token expired") + << TErrorAttribute("sign_time", signTime); + } + + auto msg = userId + ":" + ToString(signTime.TimeT()); + auto expectedToken = CreateSha256Hmac(key, msg); + if (!ConstantTimeCompare(expectedToken, parts[0])) { + return TError(NRpc::EErrorCode::InvalidCsrfToken, "Invalid CSFR token signature") + << TErrorAttribute("provided_signature", parts[0]) + << TErrorAttribute("user_fingerprint", msg); + } + + return {}; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth + diff --git a/yt/yt/library/auth_server/helpers.h b/yt/yt/library/auth_server/helpers.h new 
file mode 100644 index 0000000000..925089ae15 --- /dev/null +++ b/yt/yt/library/auth_server/helpers.h @@ -0,0 +1,80 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/convert.h> +#include <yt/yt/core/ytree/ypath_client.h> + +#include <yt/yt/core/net/public.h> + +#include <yt/yt/core/rpc/public.h> + +#include <yt/yt/client/api/public.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> +TErrorOr<T> GetByYPath(const NYTree::INodePtr& node, const NYPath::TYPath& path) +{ + try { + auto child = NYTree::FindNodeByYPath(node, path); + if (!child) { + return TError("Missing %v", path); + } + return NYTree::ConvertTo<T>(std::move(child)); + } catch (const std::exception& ex) { + return TError("Unable to extract %v", path) << ex; + } +} + +TString GetCryptoHash(TStringBuf secret); +TString FormatUserIP(const NNet::TNetworkAddress& address); + +//////////////////////////////////////////////////////////////////////////////// + +class TSafeUrlBuilder +{ +public: + void AppendString(TStringBuf str); + void AppendChar(char ch); + void AppendParam(TStringBuf key, TStringBuf value); + + TString FlushRealUrl(); + TString FlushSafeUrl(); + +private: + TStringBuilder RealUrl_; + TStringBuilder SafeUrl_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct THashedCredentials +{ + std::optional<TString> TokenHash; + // TODO(max42): add remaining fields from TCredentialsExt when needed. 
+}; + +THashedCredentials HashCredentials(const NRpc::NProto::TCredentialsExt& credentialsExt); + +void Serialize(const THashedCredentials& hashedCredentials, NYson::IYsonConsumer* consumer); + +TString GetLoginForTvmId(TTvmId tvmId); + +//////////////////////////////////////////////////////////////////////////////// + +TString SignCsrfToken( + const TString& userId, + const TString& key, + TInstant now); +TError CheckCsrfToken( + const TString& csrfToken, + const TString& userId, + const TString& key, + TInstant expirationTime); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/oauth_cookie_authenticator.cpp b/yt/yt/library/auth_server/oauth_cookie_authenticator.cpp new file mode 100644 index 0000000000..e1fdee4c21 --- /dev/null +++ b/yt/yt/library/auth_server/oauth_cookie_authenticator.cpp @@ -0,0 +1,128 @@ +#include "oauth_cookie_authenticator.h" + +#include "config.h" +#include "cookie_authenticator.h" +#include "cypress_user_manager.h" +#include "helpers.h" +#include "oauth_service.h" +#include "private.h" + +#include <yt/yt/core/crypto/crypto.h> + +namespace NYT::NAuth { + +using namespace NYTree; +using namespace NYPath; +using namespace NCrypto; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TOAuthCookieAuthenticator + : public ICookieAuthenticator +{ +public: + TOAuthCookieAuthenticator( + TOAuthCookieAuthenticatorConfigPtr config, + IOAuthServicePtr oauthService, + ICypressUserManagerPtr userManager) + : Config_(std::move(config)) + , OAuthService_(std::move(oauthService)) + , UserManager_(std::move(userManager)) + { } + + const std::vector<TStringBuf>& GetCookieNames() const override + { + static const std::vector<TStringBuf> cookieNames{ + 
OAuthAccessTokenCookieName, + }; + return cookieNames; + } + + bool CanAuthenticate(const TCookieCredentials& credentials) const override + { + return credentials.Cookies.contains(OAuthAccessTokenCookieName); + } + + TFuture<TAuthenticationResult> Authenticate( + const TCookieCredentials& credentials) override + { + const auto& cookies = credentials.Cookies; + auto accessToken = GetOrCrash(cookies, OAuthAccessTokenCookieName); + auto accessTokenMD5 = GetMD5HexDigestUpperCase(accessToken); + auto userIP = FormatUserIP(credentials.UserIP); + + YT_LOG_DEBUG( + "Authenticating user via OAuth cookie (AccessTokenMD5: %v, UserIP: %v)", + accessTokenMD5, + userIP); + + return OAuthService_->GetUserInfo(accessToken) + .Apply(BIND( + &TOAuthCookieAuthenticator::OnGetUserInfo, + MakeStrong(this), + std::move(accessTokenMD5))); + } + +private: + const TOAuthCookieAuthenticatorConfigPtr Config_; + const IOAuthServicePtr OAuthService_; + const ICypressUserManagerPtr UserManager_; + + TFuture<TAuthenticationResult> OnGetUserInfo( + const TString& accessTokenMD5, + const TOAuthUserInfoResult& userInfo) + { + auto result = OnGetUserInfoImpl(userInfo); + if (result.IsOK()) { + YT_LOG_DEBUG( + "Authentication via OAuth successful (AccessTokenMD5: %v, Login: %v, Realm: %v)", + accessTokenMD5, + result.Value().Login, + result.Value().Realm); + } else { + YT_LOG_DEBUG(result, "Authentication via OAuth failed (AccessTokenMD5: %v)", accessTokenMD5); + result.MutableAttributes()->Set("access_token_md5", accessTokenMD5); + } + + return MakeFuture(std::move(result)); + } + + TErrorOr<TAuthenticationResult> OnGetUserInfoImpl(const TOAuthUserInfoResult& userInfo) + { + auto result = WaitFor(UserManager_->CreateUser(userInfo.Login)); + if (!result.IsOK()) { + auto error = TError("Failed to create user") + << TErrorAttribute("name", userInfo.Login) + << std::move(result); + YT_LOG_WARNING(error); + return error; + } + + return TAuthenticationResult{ + .Login = userInfo.Login, + .Realm = 
TString(OAuthCookieRealm) + }; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +ICookieAuthenticatorPtr CreateOAuthCookieAuthenticator( + TOAuthCookieAuthenticatorConfigPtr config, + IOAuthServicePtr oauthService, + ICypressUserManagerPtr userManager) +{ + return New<TOAuthCookieAuthenticator>( + std::move(config), + std::move(oauthService), + std::move(userManager)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/oauth_cookie_authenticator.h b/yt/yt/library/auth_server/oauth_cookie_authenticator.h new file mode 100644 index 0000000000..be8b78127b --- /dev/null +++ b/yt/yt/library/auth_server/oauth_cookie_authenticator.h @@ -0,0 +1,16 @@ +#pragma once + +#include "public.h" + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +ICookieAuthenticatorPtr CreateOAuthCookieAuthenticator( + TOAuthCookieAuthenticatorConfigPtr config, + IOAuthServicePtr oauthService, + ICypressUserManagerPtr userManager); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/oauth_service.cpp b/yt/yt/library/auth_server/oauth_service.cpp new file mode 100644 index 0000000000..7d5cdfaa22 --- /dev/null +++ b/yt/yt/library/auth_server/oauth_service.cpp @@ -0,0 +1,187 @@ +#include "oauth_service.h" + +#include "config.h" +#include "private.h" +#include "helpers.h" + +#include <yt/yt/core/http/client.h> +#include <yt/yt/core/http/helpers.h> +#include <yt/yt/core/http/http.h> +#include <yt/yt/core/http/retrying_client.h> + +#include <yt/yt/core/https/client.h> +#include <yt/yt/core/https/config.h> + +#include <yt/yt/core/ytree/convert.h> + +#include <yt/yt/core/concurrency/delayed_executor.h> +#include <yt/yt/core/concurrency/poller.h> + +#include <yt/yt/core/json/json_parser.h> + +#include 
<yt/yt/core/profiling/timing.h> + +#include <yt/yt/core/rpc/dispatcher.h> + +namespace NYT::NAuth { + +using namespace NConcurrency; +using namespace NHttp; +using namespace NYTree; +using namespace NYson; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TOAuthService + : public IOAuthService +{ +public: + TOAuthService( + TOAuthServiceConfigPtr config, + IPollerPtr poller, + NProfiling::TProfiler profiler) + : Config_(std::move(config)) + , HttpClient_( + CreateRetryingClient( + Config_->RetryingClient, + Config_->Secure + ? NHttps::CreateClient(Config_->HttpClient, poller) + : NHttp::CreateClient(Config_->HttpClient, poller), + poller->GetInvoker())) + , OAuthCalls_(profiler.Counter("/oauth_calls")) + , OAuthCallErrors_(profiler.Counter("/oauth_call_errors")) + , OAuthCallTime_(profiler.Timer("/oauth_call_time")) + { } + + TFuture<TOAuthUserInfoResult> GetUserInfo(const TString& accessToken) override + { + return BIND(&TOAuthService::DoGetUserInfo, MakeStrong(this), accessToken) + .AsyncVia(NRpc::TDispatcher::Get()->GetLightInvoker()) + .Run(); + } + +private: + const TOAuthServiceConfigPtr Config_; + const NHttp::IRetryingClientPtr HttpClient_; + + NProfiling::TCounter OAuthCalls_; + NProfiling::TCounter OAuthCallErrors_; + NProfiling::TEventTimer OAuthCallTime_; + + static NJson::TJsonFormatConfigPtr MakeJsonFormatConfig() + { + auto config = New<NJson::TJsonFormatConfig>(); + // Additional string conversion is not necessary in this case. 
+ config->EncodeUtf8 = false; + return config; + } + + TOAuthUserInfoResult DoGetUserInfo(const TString& accessToken) + { + OAuthCalls_.Increment(); + + auto callId = TGuid::Create(); + auto httpHeaders = New<THeaders>(); + httpHeaders->Add("Authorization", Format("%v %v", Config_->AuthorizationHeaderPrefix, accessToken)); + + auto jsonResponseChecker = CreateJsonResponseChecker( + BIND(&TOAuthService::DoCheckUserInfoResponse, MakeStrong(this)), + MakeJsonFormatConfig()); + + const auto url = Format("%v://%v:%v/%v", + Config_->Secure ? "https" : "http", + Config_->Host, + Config_->Port, + Config_->UserInfoEndpoint); + + YT_LOG_DEBUG("Calling OAuth get user info (Url: %v, CallId: %v)", + NHttp::SanitizeUrl(url), + callId); + + auto result = [&] { + NProfiling::TWallTimer timer; + auto result = WaitFor(HttpClient_->Get(jsonResponseChecker, url, httpHeaders)); + OAuthCallTime_.Record(timer.GetElapsedTime()); + return result; + }(); + + if (!result.IsOK()) { + OAuthCallErrors_.Increment(); + auto error = TError(NRpc::EErrorCode::InvalidCredentials, "OAuth call failed") + << result + << TErrorAttribute("call_id", callId); + YT_LOG_WARNING(error); + THROW_ERROR(error); + } + + const auto& formattedResponose = jsonResponseChecker->GetFormattedResponse()->AsMap(); + auto userInfo = TOAuthUserInfoResult{ + .Login = formattedResponose->GetChildValueOrThrow<TString>(Config_->UserInfoLoginField), + }; + + if (Config_->UserInfoSubjectField) { + userInfo.Subject = formattedResponose->GetChildValueOrThrow<TString>(*Config_->UserInfoSubjectField); + } + + return userInfo; + } + + TError DoCheckUserInfoResponse(const IResponsePtr& rsp, const NYTree::INodePtr& rspNode) const + { + if (rsp->GetStatusCode() != EStatusCode::OK) { + auto error = TError("OAuth response has non-ok status code: %v", static_cast<int>(rsp->GetStatusCode())); + + if (rspNode->GetType() == ENodeType::Map && Config_->UserInfoErrorField) { + auto errorNode = 
rspNode->AsMap()->FindChild(*Config_->UserInfoErrorField); + error = error + << TErrorAttribute("error_field_message", ConvertToYsonString(errorNode)) + << TErrorAttribute("error_field", *Config_->UserInfoErrorField); + } + + return error; + } + + if (rspNode->GetType() != ENodeType::Map) { + return TError("OAuth response content has unexpected node type") + << TErrorAttribute("expected_result_type", ENodeType::Map) + << TErrorAttribute("actual_result_type", rspNode->GetType()); + } + + auto loginNode = rspNode->AsMap()->FindChild(Config_->UserInfoLoginField); + if (!loginNode || loginNode->GetType() != ENodeType::String) { + return TError("OAuth response content has no login field or login node type is unexpected") + << TErrorAttribute("login_field", Config_->UserInfoLoginField); + } + + if (Config_->UserInfoSubjectField) { + auto subjectNode = rspNode->AsMap()->FindChild(*Config_->UserInfoSubjectField); + if (!subjectNode || subjectNode->GetType() != ENodeType::String) { + return TError("OAuth response content has no subject field or subject node type is unexpected") + << TErrorAttribute("subject_field", Config_->UserInfoSubjectField); + } + } + + return {}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +IOAuthServicePtr CreateOAuthService( + TOAuthServiceConfigPtr config, + NConcurrency::IPollerPtr poller, + NProfiling::TProfiler profiler) +{ + return New<TOAuthService>( + std::move(config), + std::move(poller), + std::move(profiler)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/oauth_service.h b/yt/yt/library/auth_server/oauth_service.h new file mode 100644 index 0000000000..54f6232e12 --- /dev/null +++ b/yt/yt/library/auth_server/oauth_service.h @@ -0,0 +1,40 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/actions/future.h> + +#include <yt/yt/core/ytree/public.h> + +#include 
<yt/yt/library/profiling/sensor.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +struct TOAuthUserInfoResult +{ + TString Subject; + TString Login; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct IOAuthService + : public virtual TRefCounted +{ + virtual TFuture<TOAuthUserInfoResult> GetUserInfo(const TString& accessToken) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(IOAuthService) + +//////////////////////////////////////////////////////////////////////////////// + +IOAuthServicePtr CreateOAuthService( + TOAuthServiceConfigPtr config, + NConcurrency::IPollerPtr poller, + NProfiling::TProfiler profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/oauth_token_authenticator.cpp b/yt/yt/library/auth_server/oauth_token_authenticator.cpp new file mode 100644 index 0000000000..023321dab9 --- /dev/null +++ b/yt/yt/library/auth_server/oauth_token_authenticator.cpp @@ -0,0 +1,115 @@ +#include "oauth_cookie_authenticator.h" + +#include "config.h" +#include "cookie_authenticator.h" +#include "cypress_user_manager.h" +#include "helpers.h" +#include "oauth_service.h" +#include "private.h" +#include "token_authenticator.h" + +#include <yt/yt/core/crypto/crypto.h> + +namespace NYT::NAuth { + +using namespace NYTree; +using namespace NYPath; +using namespace NCrypto; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TOAuthTokenAuthenticator + : public ITokenAuthenticator +{ +public: + TOAuthTokenAuthenticator( + TOAuthTokenAuthenticatorConfigPtr config, + IOAuthServicePtr oauthService, + ICypressUserManagerPtr userManager) + : Config_(std::move(config)) + , 
OAuthService_(std::move(oauthService)) + , UserManager_(std::move(userManager)) + { } + + TFuture<TAuthenticationResult> Authenticate( + const TTokenCredentials& credentials) override + { + const auto& token = credentials.Token; + auto tokenHash = GetCryptoHash(token); + auto userIP = FormatUserIP(credentials.UserIP); + + YT_LOG_DEBUG( + "Authenticating user with token via OAuth (TokenHash: %v, UserIP: %v)", + tokenHash, + userIP); + + return OAuthService_->GetUserInfo(token) + .Apply(BIND( + &TOAuthTokenAuthenticator::OnGetUserInfo, + MakeStrong(this), + std::move(tokenHash))); + } + +private: + const TOAuthTokenAuthenticatorConfigPtr Config_; + const IOAuthServicePtr OAuthService_; + const ICypressUserManagerPtr UserManager_; + + TFuture<TAuthenticationResult> OnGetUserInfo( + const TString& tokenHash, + const TOAuthUserInfoResult& userInfo) + { + auto result = OnGetUserInfoImpl(userInfo); + if (result.IsOK()) { + YT_LOG_DEBUG( + "Authentication via OAuth successful (TokenHash: %v, Login: %v, Realm: %v)", + tokenHash, + result.Value().Login, + result.Value().Realm); + } else { + YT_LOG_DEBUG(result, "Authentication via OAuth failed (TokenHash: %v)", tokenHash); + result.MutableAttributes()->Set("token_hash", tokenHash); + } + + return MakeFuture(std::move(result)); + } + + TErrorOr<TAuthenticationResult> OnGetUserInfoImpl(const TOAuthUserInfoResult& userInfo) + { + auto result = WaitFor(UserManager_->CreateUser(userInfo.Login)); + if (!result.IsOK()) { + auto error = TError("Failed to create user") + << TErrorAttribute("name", userInfo.Login) + << std::move(result); + YT_LOG_WARNING(error); + return error; + } + + return TAuthenticationResult{ + .Login = userInfo.Login, + .Realm = TString(OAuthTokenRealm), + }; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +ITokenAuthenticatorPtr CreateOAuthTokenAuthenticator( + TOAuthTokenAuthenticatorConfigPtr config, + IOAuthServicePtr oauthService, + ICypressUserManagerPtr 
userManager) +{ + return New<TOAuthTokenAuthenticator>( + std::move(config), + std::move(oauthService), + std::move(userManager)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/oauth_token_authenticator.h b/yt/yt/library/auth_server/oauth_token_authenticator.h new file mode 100644 index 0000000000..07711b06fc --- /dev/null +++ b/yt/yt/library/auth_server/oauth_token_authenticator.h @@ -0,0 +1,16 @@ +#pragma once + +#include "public.h" + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +ITokenAuthenticatorPtr CreateOAuthTokenAuthenticator( + TOAuthTokenAuthenticatorConfigPtr config, + IOAuthServicePtr oauthService, + ICypressUserManagerPtr userManager); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/private.h b/yt/yt/library/auth_server/private.h new file mode 100644 index 0000000000..92be73e94a --- /dev/null +++ b/yt/yt/library/auth_server/private.h @@ -0,0 +1,24 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/library/profiling/sensor.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger AuthLogger("Auth"); +inline const NProfiling::TProfiler AuthProfiler("/auth"); + +//////////////////////////////////////////////////////////////////////////////// + +constexpr TStringBuf OAuthCookieRealm = "oauth:cookie"; +constexpr TStringBuf OAuthTokenRealm = "oauth:token"; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth + diff --git a/yt/yt/library/auth_server/public.h b/yt/yt/library/auth_server/public.h new file mode 100644 index 0000000000..3db0437912 --- /dev/null +++ 
b/yt/yt/library/auth_server/public.h @@ -0,0 +1,167 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +#include <yt/yt/core/net/address.h> + +#include <yt/yt/library/tvm/service/public.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_CLASS(TAuthCacheConfig) +DECLARE_REFCOUNTED_CLASS(TBlackboxServiceConfig) +DECLARE_REFCOUNTED_CLASS(TBlackboxTicketAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TBlackboxTokenAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TCachingBlackboxTokenAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TCypressTokenAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TCachingTokenAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TCachingCypressTokenAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TBlackboxCookieAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TCachingCookieAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TCachingBlackboxCookieAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TDefaultSecretVaultServiceConfig) +DECLARE_REFCOUNTED_CLASS(TBatchingSecretVaultServiceConfig) +DECLARE_REFCOUNTED_CLASS(TCachingSecretVaultServiceConfig) +DECLARE_REFCOUNTED_CLASS(TAuthenticationManagerConfig) + +DECLARE_REFCOUNTED_CLASS(TOAuthCookieAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TOAuthTokenAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TCachingOAuthCookieAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TCachingOAuthTokenAuthenticatorConfig) +DECLARE_REFCOUNTED_CLASS(TOAuthServiceConfig) +DECLARE_REFCOUNTED_CLASS(TCypressUserManagerConfig) +DECLARE_REFCOUNTED_CLASS(TCachingCypressUserManagerConfig) + +DECLARE_REFCOUNTED_STRUCT(TCypressCookie) + +DECLARE_REFCOUNTED_STRUCT(TCypressCookieStoreConfig) +DECLARE_REFCOUNTED_STRUCT(TCypressCookieGeneratorConfig) +DECLARE_REFCOUNTED_STRUCT(TCypressCookieManagerConfig) + +DECLARE_REFCOUNTED_STRUCT(ICypressCookieStore) +DECLARE_REFCOUNTED_STRUCT(ICypressCookieManager) +DECLARE_REFCOUNTED_STRUCT(ICypressUserManager) + 
+DECLARE_REFCOUNTED_STRUCT(IAuthenticationManager) + +DECLARE_REFCOUNTED_STRUCT(IBlackboxService) +DECLARE_REFCOUNTED_STRUCT(IOAuthService) + +DECLARE_REFCOUNTED_STRUCT(ICookieAuthenticator) +DECLARE_REFCOUNTED_STRUCT(ITokenAuthenticator) +DECLARE_REFCOUNTED_STRUCT(ITicketAuthenticator) + +DECLARE_REFCOUNTED_STRUCT(ISecretVaultService) + +//////////////////////////////////////////////////////////////////////////////// + +// See https://doc.yandex-team.ru/blackbox/reference/method-sessionid-response-json.xml for reference. +DEFINE_ENUM_WITH_UNDERLYING_TYPE(EBlackboxStatus, i64, + ((Valid) (0)) + ((NeedReset)(1)) + ((Expired) (2)) + ((NoAuth) (3)) + ((Disabled) (4)) + ((Invalid) (5)) +); + +// See https://doc.yandex-team.ru/blackbox/concepts/blackboxErrors.xml +DEFINE_ENUM_WITH_UNDERLYING_TYPE(EBlackboxException, i64, + ((Ok) (0)) + ((Unknown) (1)) + ((InvalidParameters) (2)) + ((DBFetchFailed) (9)) + ((DBException) (10)) + ((AccessDenied) (21)) +); + +DEFINE_ENUM(ESecretVaultErrorCode, + ((UnknownError) (18000)) + ((MalformedResponse) (18001)) + ((NonexistentEntityError) (18002)) + ((DelegationAccessError) (18003)) + ((DelegationTokenRevoked) (18004)) + ((UnexpectedStatus) (18005)) +); + +//////////////////////////////////////////////////////////////////////////////// + +struct TTokenCredentials +{ + TString Token; + // NB: UserIP may be ignored for caching purposes. + NNet::TNetworkAddress UserIP; +}; + +struct TCookieCredentials +{ + // NB: Since requests are caching, pass only required + // subset of cookies here. + THashMap<TString, TString> Cookies; + + NNet::TNetworkAddress UserIP; +}; + +struct TTicketCredentials +{ + TString Ticket; +}; + +struct TServiceTicketCredentials +{ + TString Ticket; +}; + +struct TAuthenticationResult +{ + TString Login; + TString Realm; + TString UserTicket; + + //! If set, client is advised to set this cookie. 
+ std::optional<TString> SetCookie; +}; + +inline bool operator ==( + const TCookieCredentials& lhs, + const TCookieCredentials& rhs) +{ + return + std::tie(lhs.Cookies, lhs.UserIP) == + std::tie(rhs.Cookies, rhs.UserIP); +} + +//////////////////////////////////////////////////////////////////////////////// + +constexpr TStringBuf BlackboxSessionIdCookieName = "Session_id"; +constexpr TStringBuf BlackboxSslSessionIdCookieName = "sessionid2"; +constexpr TStringBuf CypressCookieName = "YTCypressCookie"; +constexpr TStringBuf OAuthAccessTokenCookieName = "access_token"; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth + +template <> +struct THash<NYT::NAuth::TCookieCredentials> +{ + inline size_t operator()(const NYT::NAuth::TCookieCredentials& credentials) const + { + size_t result = 0; + + std::vector<std::pair<TString, TString>> cookies( + credentials.Cookies.begin(), + credentials.Cookies.end()); + std::sort(cookies.begin(), cookies.end()); + for (const auto& cookie : cookies) { + NYT::HashCombine(result, cookie.first); + NYT::HashCombine(result, cookie.second); + } + + NYT::HashCombine(result, credentials.UserIP); + + return result; + } +}; diff --git a/yt/yt/library/auth_server/secret_vault_service.cpp b/yt/yt/library/auth_server/secret_vault_service.cpp new file mode 100644 index 0000000000..6b738292ac --- /dev/null +++ b/yt/yt/library/auth_server/secret_vault_service.cpp @@ -0,0 +1,27 @@ +#include "secret_vault_service.h" + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +void FormatValue( + TStringBuilder* builder, + const ISecretVaultService::TSecretSubrequest& subrequest, + TStringBuf /*spec*/) +{ + builder->AppendFormat("%v:%v:%v:%v", + subrequest.SecretId, + subrequest.SecretVersion, + subrequest.DelegationToken, + subrequest.Signature); +} + +TString ToString(const ISecretVaultService::TSecretSubrequest& subrequest) +{ + return 
ToStringViaBuilder(subrequest); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth + diff --git a/yt/yt/library/auth_server/secret_vault_service.h b/yt/yt/library/auth_server/secret_vault_service.h new file mode 100644 index 0000000000..9b762c5ab2 --- /dev/null +++ b/yt/yt/library/auth_server/secret_vault_service.h @@ -0,0 +1,81 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/actions/future.h> + +#include <library/cpp/yt/misc/hash.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +struct ISecretVaultService + : public virtual TRefCounted +{ + struct TSecretSubrequest + { + TString SecretId; + TString SecretVersion; + TString DelegationToken; + TString Signature; + + bool operator == (const TSecretSubrequest& other) const + { + return + std::tie(SecretId, SecretVersion, DelegationToken, Signature) == + std::tie(other.SecretId, other.SecretVersion, other.DelegationToken, other.Signature); + } + + operator size_t() const + { + size_t hash = 0; + HashCombine(hash, SecretId); + HashCombine(hash, SecretVersion); + HashCombine(hash, DelegationToken); + HashCombine(hash, Signature); + return hash; + } + }; + + struct TSecretValue + { + TString Key; + TString Value; + TString Encoding; + }; + + struct TSecretSubresponse + { + std::vector<TSecretValue> Values; + }; + + using TErrorOrSecretSubresponse = TErrorOr<TSecretSubresponse>; + + virtual TFuture<std::vector<TErrorOrSecretSubresponse>> GetSecrets( + const std::vector<TSecretSubrequest>& subrequests) = 0; + + struct TDelegationTokenRequest + { + TString UserTicket; + TString SecretId; + TString Signature; + TString Comment; + }; + + virtual TFuture<TString> GetDelegationToken(TDelegationTokenRequest request) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(ISecretVaultService) + +//////////////////////////////////////////////////////////////////////////////// + +void FormatValue( + 
TStringBuilder* builder, + const ISecretVaultService::TSecretSubrequest& subrequest, + TStringBuf spec); +TString ToString(const ISecretVaultService::TSecretSubrequest& subrequest); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/ticket_authenticator.cpp b/yt/yt/library/auth_server/ticket_authenticator.cpp new file mode 100644 index 0000000000..4468d8f1b0 --- /dev/null +++ b/yt/yt/library/auth_server/ticket_authenticator.cpp @@ -0,0 +1,238 @@ +#include "ticket_authenticator.h" + +#include "blackbox_service.h" +#include "config.h" +#include "helpers.h" +#include "private.h" + +#include <yt/yt/core/rpc/authenticator.h> + +#include <yt/yt/library/tvm/service/tvm_service.h> + +namespace NYT::NAuth { + +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TBlackboxTicketAuthenticator + : public ITicketAuthenticator +{ +public: + TBlackboxTicketAuthenticator( + TBlackboxTicketAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService, + ITvmServicePtr tvmService) + : Config_(std::move(config)) + , BlackboxService_(std::move(blackboxService)) + , TvmService_(std::move(tvmService)) + { } + + TFuture<TAuthenticationResult> Authenticate( + const TTicketCredentials& credentials) override + { + const auto& ticket = credentials.Ticket; + auto ticketHash = GetCryptoHash(ticket); + + if (Config_->EnableScopeCheck && TvmService_) { + auto result = CheckScope(ticket, ticketHash); + if (!result.IsOK()) { + return MakeFuture<TAuthenticationResult>(result); + } + } + + YT_LOG_DEBUG("Validating ticket via Blackbox (TicketHash: %v)", + ticketHash); + + return BlackboxService_->Call("user_ticket", {{"user_ticket", ticket}}) + .Apply(BIND( + 
&TBlackboxTicketAuthenticator::OnBlackboxCallResult, + MakeStrong(this), + ticket, + ticketHash)); + } + + TFuture<TAuthenticationResult> Authenticate( + const TServiceTicketCredentials& credentials) override + { + const auto& ticket = credentials.Ticket; + auto ticketHash = GetCryptoHash(ticket); + + YT_LOG_DEBUG("Validating service ticket (TicketHash: %v)", + ticketHash); + + try { + auto parsedTicket = TvmService_->ParseServiceTicket(ticket); + + TAuthenticationResult result; + result.Login = GetLoginForTvmId(parsedTicket.TvmId); + result.Realm = "tvm:service-ticket"; + + YT_LOG_DEBUG("Ticket authentication successful (TicketHash: %v, Login: %v, Realm: %v)", + ticketHash, + result.Login, + result.Realm); + + return MakeFuture(result); + } catch (const std::exception& ex) { + TError error(ex); + YT_LOG_DEBUG(error, "Parsing service ticket failed (TicketHash: %v)", + ticketHash); + return MakeFuture<TAuthenticationResult>(error); + } + } + +private: + const TBlackboxTicketAuthenticatorConfigPtr Config_; + const IBlackboxServicePtr BlackboxService_; + const ITvmServicePtr TvmService_; + +private: + TError CheckScope(const TString& ticket, const TString& ticketHash) + { + YT_LOG_DEBUG("Validating ticket scopes (TicketHash: %v)", + ticketHash); + try { + const auto result = TvmService_->ParseUserTicket(ticket); + const auto& scopes = result.Scopes; + YT_LOG_DEBUG("Got user ticket (Scopes: %v)", scopes); + + const auto& allowedScopes = Config_->Scopes; + for (const auto& scope : scopes) { + if (allowedScopes.contains(scope)) { + return TError(); + } + } + + return TError(NRpc::EErrorCode::InvalidCredentials, + "Ticket does not provide an allowed scope") + << TErrorAttribute("scopes", scopes) + << TErrorAttribute("allowed_scopes", allowedScopes); + } catch (const std::exception& ex) { + TError error(ex); + YT_LOG_DEBUG(error, "Parsing user ticket failed (TicketHash: %v)", + ticketHash); + return error << TErrorAttribute("ticket_hash", ticketHash); + } + } + + 
TAuthenticationResult OnBlackboxCallResult( + const TString& ticket, + const TString& ticketHash, + const INodePtr& data) + { + auto errorOrResult = OnCallResultImpl(data); + if (!errorOrResult.IsOK()) { + YT_LOG_DEBUG(errorOrResult, "Blackbox authentication failed (TicketHash: %v)", + ticketHash); + THROW_ERROR errorOrResult + << TErrorAttribute("ticket_hash", ticketHash); + } + + auto result = errorOrResult.Value(); + result.UserTicket = ticket; + + YT_LOG_DEBUG("Blackbox authentication successful (TicketHash: %v, Login: %v, Realm: %v)", + ticketHash, + result.Login, + result.Realm); + return result; + } + + TErrorOr<TAuthenticationResult> OnCallResultImpl(const INodePtr& data) + { + static const TString ErrorPath("/error"); + auto errorNode = FindNodeByYPath(data, ErrorPath); + if (errorNode) { + return TError(errorNode->GetValue<TString>()); + } + + static const TString LoginPath("/users/0/login"); + auto loginNode = GetNodeByYPath(data, LoginPath); + + TAuthenticationResult result; + result.Login = loginNode->GetValue<TString>(); + result.Realm = "blackbox:user-ticket"; + return result; + } +}; + +ITicketAuthenticatorPtr CreateBlackboxTicketAuthenticator( + TBlackboxTicketAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService, + ITvmServicePtr tvmService) +{ + return New<TBlackboxTicketAuthenticator>( + std::move(config), + std::move(blackboxService), + std::move(tvmService)); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TTicketAuthenticatorWrapper + : public NRpc::IAuthenticator +{ +public: + explicit TTicketAuthenticatorWrapper(ITicketAuthenticatorPtr underlying) + : Underlying_(std::move(underlying)) + { } + + bool CanAuthenticate(const NRpc::TAuthenticationContext& context) override + { + if (!context.Header->HasExtension(NRpc::NProto::TCredentialsExt::credentials_ext)) { + return false; + } + const auto& ext = context.Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext); + 
return ext.has_user_ticket() || ext.has_service_ticket(); + } + + TFuture<NRpc::TAuthenticationResult> AsyncAuthenticate( + const NRpc::TAuthenticationContext& context) override + { + YT_ASSERT(CanAuthenticate(context)); + const auto& ext = context.Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext); + + if (ext.has_user_ticket()) { + TTicketCredentials credentials; + credentials.Ticket = ext.user_ticket(); + return Underlying_->Authenticate(credentials).Apply( + BIND([=] (const TAuthenticationResult& authResult) { + NRpc::TAuthenticationResult rpcResult; + rpcResult.User = authResult.Login; + rpcResult.Realm = authResult.Realm; + rpcResult.UserTicket = authResult.UserTicket; + return rpcResult; + })); + } + + if (ext.has_service_ticket()) { + TServiceTicketCredentials credentials; + credentials.Ticket = ext.service_ticket(); + return Underlying_->Authenticate(credentials).Apply( + BIND([=] (const TAuthenticationResult& authResult) { + NRpc::TAuthenticationResult rpcResult; + rpcResult.User = authResult.Login; + rpcResult.Realm = authResult.Realm; + return rpcResult; + })); + } + + YT_ABORT(); + } +private: + const ITicketAuthenticatorPtr Underlying_; +}; + +NRpc::IAuthenticatorPtr CreateTicketAuthenticatorWrapper(ITicketAuthenticatorPtr underlying) +{ + return New<TTicketAuthenticatorWrapper>(std::move(underlying)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/ticket_authenticator.h b/yt/yt/library/auth_server/ticket_authenticator.h new file mode 100644 index 0000000000..ed907f6be7 --- /dev/null +++ b/yt/yt/library/auth_server/ticket_authenticator.h @@ -0,0 +1,39 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/api/public.h> + +#include <yt/yt/core/actions/public.h> + +#include <yt/yt/core/rpc/public.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +struct 
ITicketAuthenticator + : public virtual TRefCounted +{ + virtual TFuture<TAuthenticationResult> Authenticate( + const TTicketCredentials& credentials) = 0; + + virtual TFuture<TAuthenticationResult> Authenticate( + const TServiceTicketCredentials& credentials) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(ITicketAuthenticator) + +//////////////////////////////////////////////////////////////////////////////// + +ITicketAuthenticatorPtr CreateBlackboxTicketAuthenticator( + TBlackboxTicketAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService, + ITvmServicePtr tvmService); + +NRpc::IAuthenticatorPtr CreateTicketAuthenticatorWrapper( + ITicketAuthenticatorPtr underlying); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/token_authenticator.cpp b/yt/yt/library/auth_server/token_authenticator.cpp new file mode 100644 index 0000000000..bf2ea61714 --- /dev/null +++ b/yt/yt/library/auth_server/token_authenticator.cpp @@ -0,0 +1,458 @@ +#include "token_authenticator.h" +#include "blackbox_service.h" +#include "helpers.h" +#include "config.h" +#include "private.h" +#include "auth_cache.h" + +#include <yt/yt/client/api/client.h> + +#include <yt/yt/core/misc/async_expiring_cache.h> + +#include <yt/yt/core/rpc/authenticator.h> + +namespace NYT::NAuth { + +using namespace NYTree; +using namespace NYson; +using namespace NYPath; +using namespace NApi; +using namespace NProfiling; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = AuthLogger; + +//////////////////////////////////////////////////////////////////////////////// + +// TODO(babenko): use passed profiler +class TBlackboxTokenAuthenticator + : public ITokenAuthenticator +{ +public: + TBlackboxTokenAuthenticator( + TBlackboxTokenAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService, + NProfiling::TProfiler profiler) + : 
Config_(std::move(config)) + , Blackbox_(std::move(blackboxService)) + { + profiler = profiler.WithPrefix("/blackbox_token_authenticator"); + RejectedTokens_ = profiler.Counter("/rejected_tokens"); + InvalidBlackboxResponses_ = profiler.Counter("/invalid_responses"); + TokenScopeCheckErrors_ = profiler.Counter("/scope_check_errors"); + } + + TFuture<TAuthenticationResult> Authenticate( + const TTokenCredentials& credentials) override + { + const auto& token = credentials.Token; + auto userIP = FormatUserIP(credentials.UserIP); + auto tokenHash = GetCryptoHash(token); + + YT_LOG_DEBUG("Authenticating user with token via Blackbox (TokenHash: %v, UserIP: %v)", + tokenHash, + userIP); + + THashMap<TString, TString> params{ + {"oauth_token", token}, + {"userip", userIP}, + }; + + if (Config_->GetUserTicket) { + params["get_user_ticket"] = "yes"; + } + + return Blackbox_->Call("oauth", params) + .Apply(BIND( + &TBlackboxTokenAuthenticator::OnCallResult, + MakeStrong(this), + std::move(tokenHash))); + } + +private: + const TBlackboxTokenAuthenticatorConfigPtr Config_; + const IBlackboxServicePtr Blackbox_; + + TCounter RejectedTokens_; + TCounter InvalidBlackboxResponses_; + TCounter TokenScopeCheckErrors_; + +private: + TAuthenticationResult OnCallResult(const TString& tokenHash, const INodePtr& data) + { + auto result = OnCallResultImpl(data); + if (!result.IsOK()) { + YT_LOG_DEBUG(result, "Blackbox authentication failed (TokenHash: %v)", + tokenHash); + THROW_ERROR result + << TErrorAttribute("token_hash", tokenHash); + } + + YT_LOG_DEBUG("Blackbox authentication successful (TokenHash: %v, Login: %v, Realm: %v)", + tokenHash, + result.Value().Login, + result.Value().Realm); + return result.Value(); + } + + TErrorOr<TAuthenticationResult> OnCallResultImpl(const INodePtr& data) + { + // See https://doc.yandex-team.ru/blackbox/reference/method-oauth-response-json.xml for reference. 
+ auto statusId = GetByYPath<int>(data, "/status/id"); + if (!statusId.IsOK()) { + InvalidBlackboxResponses_.Increment(); + return TError("Blackbox returned invalid response"); + } + + if (EBlackboxStatus(statusId.Value()) != EBlackboxStatus::Valid) { + auto error = GetByYPath<TString>(data, "/error"); + auto reason = error.IsOK() ? error.Value() : "unknown"; + RejectedTokens_.Increment(); + return TError(NRpc::EErrorCode::InvalidCredentials, "Blackbox rejected token") + << TErrorAttribute("reason", reason); + } + + auto login = Blackbox_->GetLogin(data); + auto oauthClientId = GetByYPath<TString>(data, "/oauth/client_id"); + auto oauthClientName = GetByYPath<TString>(data, "/oauth/client_name"); + auto oauthScope = GetByYPath<TString>(data, "/oauth/scope"); + + // Sanity checks. + if (!login.IsOK() || !oauthClientId.IsOK() || !oauthClientName.IsOK() || !oauthScope.IsOK()) { + auto error = TError("Blackbox returned invalid response"); + if (!login.IsOK()) error.MutableInnerErrors()->push_back(login); + if (!oauthClientId.IsOK()) error.MutableInnerErrors()->push_back(oauthClientId); + if (!oauthClientName.IsOK()) error.MutableInnerErrors()->push_back(oauthClientName); + if (!oauthScope.IsOK()) error.MutableInnerErrors()->push_back(oauthScope); + + InvalidBlackboxResponses_.Increment(); + return error; + } + + // Check that token provides valid scope. + // `oauthScope` is space-delimited list of provided scopes. + if (Config_->EnableScopeCheck) { + bool matchedScope = false; + TStringBuf providedScopes(oauthScope.Value()); + TStringBuf providedScope; + while (providedScopes.NextTok(' ', providedScope)) { + if (providedScope == Config_->Scope) { + matchedScope = true; + } + } + if (!matchedScope) { + TokenScopeCheckErrors_.Increment(); + return TError(NRpc::EErrorCode::InvalidCredentials, "Token does not provide a valid scope") + << TErrorAttribute("scope", oauthScope.Value()); + } + } + + // Check that token was issued by a known application. 
+ TAuthenticationResult result; + result.Login = login.Value(); + result.Realm = "blackbox:token:" + oauthClientId.Value() + ":" + oauthClientName.Value(); + auto userTicket = GetByYPath<TString>(data, "/user_ticket"); + if (userTicket.IsOK()) { + result.UserTicket = userTicket.Value(); + } else if (Config_->GetUserTicket) { + return TError("Failed to retrieve user ticket"); + } + return result; + } +}; + +ITokenAuthenticatorPtr CreateBlackboxTokenAuthenticator( + TBlackboxTokenAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService, + NProfiling::TProfiler profiler) +{ + return New<TBlackboxTokenAuthenticator>( + std::move(config), + std::move(blackboxService), + std::move(profiler)); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TLegacyCypressTokenAuthenticator + : public ITokenAuthenticator +{ +public: + TLegacyCypressTokenAuthenticator( + TCypressTokenAuthenticatorConfigPtr config, + IClientPtr client) + : Config_(std::move(config)) + , Client_(std::move(client)) + { } + + TFuture<TAuthenticationResult> Authenticate( + const TTokenCredentials& credentials) override + { + const auto& token = credentials.Token; + const auto& userIP = credentials.UserIP; + auto tokenHash = GetCryptoHash(token); + YT_LOG_DEBUG("Authenticating user with token via Cypress (TokenHash: %v, UserIP: %v)", + tokenHash, + userIP); + + auto path = Config_->RootPath + "/" + ToYPathLiteral(Config_->Secure ? 
tokenHash : token); + return Client_->GetNode(path) + .Apply(BIND( + &TLegacyCypressTokenAuthenticator::OnCallResult, + MakeStrong(this), + std::move(tokenHash))); + } + +private: + const TCypressTokenAuthenticatorConfigPtr Config_; + const IClientPtr Client_; + +private: + TAuthenticationResult OnCallResult(const TString& tokenHash, const TErrorOr<TYsonString>& callResult) + { + if (!callResult.IsOK()) { + if (callResult.FindMatching(NYTree::EErrorCode::ResolveError)) { + YT_LOG_DEBUG(callResult, "Token is missing in Cypress (TokenHash: %v)", + tokenHash); + THROW_ERROR_EXCEPTION("Token is missing in Cypress"); + } else { + YT_LOG_DEBUG(callResult, "Cypress authentication failed (TokenHash: %v)", + tokenHash); + THROW_ERROR_EXCEPTION("Cypress authentication failed") + << TErrorAttribute("token_hash", tokenHash) + << callResult; + } + } + + const auto& ysonString = callResult.Value(); + try { + TAuthenticationResult authResult; + authResult.Login = ConvertTo<TString>(ysonString); + authResult.Realm = Config_->Realm; + YT_LOG_DEBUG("Cypress authentication successful (TokenHash: %v, Login: %v)", + tokenHash, + authResult.Login); + return authResult; + } catch (const std::exception& ex) { + YT_LOG_DEBUG(callResult, "Cypress contains malformed authentication entry (TokenHash: %v)", + tokenHash); + THROW_ERROR_EXCEPTION("Malformed Cypress authentication entry") + << TErrorAttribute("token_hash", tokenHash); + } + } +}; + +ITokenAuthenticatorPtr CreateLegacyCypressTokenAuthenticator( + TCypressTokenAuthenticatorConfigPtr config, + IClientPtr client) +{ + return New<TLegacyCypressTokenAuthenticator>(std::move(config), std::move(client)); +} + +//////////////////////////////////////////////////////////////////////////////// + +struct TTokenAuthenticatorCacheKey +{ + TTokenCredentials Credentials; + + operator size_t() const + { + size_t result = 0; + HashCombine(result, Credentials.Token); + return result; + } + + bool operator == (const TTokenAuthenticatorCacheKey& other) 
const + { + return + Credentials.Token == other.Credentials.Token; + } +}; + +class TCachingTokenAuthenticator + : public ITokenAuthenticator + , public TAuthCache<TString, TAuthenticationResult, NNet::TNetworkAddress> +{ +public: + TCachingTokenAuthenticator( + TCachingTokenAuthenticatorConfigPtr config, + ITokenAuthenticatorPtr tokenAuthenticator, + NProfiling::TProfiler profiler) + : TAuthCache(config->Cache, std::move(profiler)) + , TokenAuthenticator_(std::move(tokenAuthenticator)) + { } + + TFuture<TAuthenticationResult> Authenticate(const TTokenCredentials& credentials) override + { + return Get(credentials.Token, credentials.UserIP); + } + +private: + const ITokenAuthenticatorPtr TokenAuthenticator_; + + TFuture<TAuthenticationResult> DoGet( + const TString& token, + const NNet::TNetworkAddress& userIP) noexcept override + { + return TokenAuthenticator_->Authenticate(TTokenCredentials{token, userIP}); + } +}; + +ITokenAuthenticatorPtr CreateCachingTokenAuthenticator( + TCachingTokenAuthenticatorConfigPtr config, + ITokenAuthenticatorPtr authenticator, + NProfiling::TProfiler profiler) +{ + return New<TCachingTokenAuthenticator>( + std::move(config), + std::move(authenticator), + std::move(profiler)); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TCompositeTokenAuthenticator + : public ITokenAuthenticator +{ +public: + explicit TCompositeTokenAuthenticator(std::vector<ITokenAuthenticatorPtr> authenticators) + : Authenticators_(std::move(authenticators)) + { } + + TFuture<TAuthenticationResult> Authenticate( + const TTokenCredentials& credentials) override + { + return New<TAuthenticationSession>(this, credentials)->GetResult(); + } + +private: + const std::vector<ITokenAuthenticatorPtr> Authenticators_; + + class TAuthenticationSession + : public TRefCounted + { + public: + TAuthenticationSession( + TIntrusivePtr<TCompositeTokenAuthenticator> owner, + const TTokenCredentials& credentials) + : 
Owner_(std::move(owner)) + , Credentials_(credentials) + { + InvokeNext(); + } + + TFuture<TAuthenticationResult> GetResult() + { + return Promise_; + } + + private: + const TIntrusivePtr<TCompositeTokenAuthenticator> Owner_; + const TTokenCredentials Credentials_; + + TPromise<TAuthenticationResult> Promise_ = NewPromise<TAuthenticationResult>(); + std::vector<TError> Errors_; + size_t CurrentIndex_ = 0; + + private: + void InvokeNext() + { + if (CurrentIndex_ >= Owner_->Authenticators_.size()) { + Promise_.Set(TError(NSecurityClient::EErrorCode::AuthenticationError, "Authentication failed") + << Errors_); + return; + } + + const auto& authenticator = Owner_->Authenticators_[CurrentIndex_++]; + authenticator->Authenticate(Credentials_).Subscribe( + BIND([=, this, this_ = MakeStrong(this)] (const TErrorOr<TAuthenticationResult>& result) { + if (result.IsOK()) { + Promise_.Set(result.Value()); + } else { + Errors_.push_back(result); + InvokeNext(); + } + })); + } + }; +}; + +ITokenAuthenticatorPtr CreateCompositeTokenAuthenticator( + std::vector<ITokenAuthenticatorPtr> authenticators) +{ + return New<TCompositeTokenAuthenticator>(std::move(authenticators)); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TNoopTokenAuthenticator + : public ITokenAuthenticator +{ +public: + TFuture<TAuthenticationResult> Authenticate(const TTokenCredentials& /*credentials*/) override + { + static const auto Realm = TString("noop"); + static const auto UserTicket = TString(""); + TAuthenticationResult result{ + .Login = NRpc::RootUserName, + .Realm = Realm, + .UserTicket = UserTicket, + }; + return MakeFuture<TAuthenticationResult>(result); + } +}; + +ITokenAuthenticatorPtr CreateNoopTokenAuthenticator() +{ + return New<TNoopTokenAuthenticator>(); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TTokenAuthenticatorWrapper + : public NRpc::IAuthenticator +{ +public: + explicit 
TTokenAuthenticatorWrapper(ITokenAuthenticatorPtr underlying) + : Underlying_(std::move(underlying)) + { } + + bool CanAuthenticate(const NRpc::TAuthenticationContext& context) override + { + if (!context.Header->HasExtension(NRpc::NProto::TCredentialsExt::credentials_ext)) { + return false; + } + const auto& ext = context.Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext); + return ext.has_token(); + } + + TFuture<NRpc::TAuthenticationResult> AsyncAuthenticate( + const NRpc::TAuthenticationContext& context) override + { + YT_ASSERT(CanAuthenticate(context)); + const auto& ext = context.Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext); + TTokenCredentials credentials; + credentials.UserIP = context.UserIP; + credentials.Token = ext.token(); + return Underlying_->Authenticate(credentials).Apply( + BIND([=] (const TAuthenticationResult& authResult) { + NRpc::TAuthenticationResult rpcResult; + rpcResult.User = authResult.Login; + rpcResult.Realm = authResult.Realm; + rpcResult.UserTicket = authResult.UserTicket; + return rpcResult; + })); + } +private: + const ITokenAuthenticatorPtr Underlying_; +}; + +NRpc::IAuthenticatorPtr CreateTokenAuthenticatorWrapper(ITokenAuthenticatorPtr underlying) +{ + return New<TTokenAuthenticatorWrapper>(std::move(underlying)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/token_authenticator.h b/yt/yt/library/auth_server/token_authenticator.h new file mode 100644 index 0000000000..57eb7a841e --- /dev/null +++ b/yt/yt/library/auth_server/token_authenticator.h @@ -0,0 +1,54 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/api/public.h> + +#include <yt/yt/library/profiling/sensor.h> + +#include <yt/yt/core/actions/public.h> + +#include <yt/yt/core/rpc/public.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +struct 
ITokenAuthenticator + : public virtual TRefCounted +{ + virtual TFuture<TAuthenticationResult> Authenticate( + const TTokenCredentials& credentials) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(ITokenAuthenticator) + +//////////////////////////////////////////////////////////////////////////////// + +ITokenAuthenticatorPtr CreateBlackboxTokenAuthenticator( + TBlackboxTokenAuthenticatorConfigPtr config, + IBlackboxServicePtr blackboxService, + NProfiling::TProfiler profiler = {}); + +// This authenticator was created before simple authentication scheme +// and should be removed one day. +ITokenAuthenticatorPtr CreateLegacyCypressTokenAuthenticator( + TCypressTokenAuthenticatorConfigPtr config, + NApi::IClientPtr client); + +ITokenAuthenticatorPtr CreateCachingTokenAuthenticator( + TCachingTokenAuthenticatorConfigPtr config, + ITokenAuthenticatorPtr authenticator, + NProfiling::TProfiler profiler = {}); + +ITokenAuthenticatorPtr CreateCompositeTokenAuthenticator( + std::vector<ITokenAuthenticatorPtr> authenticators); + +ITokenAuthenticatorPtr CreateNoopTokenAuthenticator(); + +NRpc::IAuthenticatorPtr CreateTokenAuthenticatorWrapper( + ITokenAuthenticatorPtr underlying); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/auth_server/ya.make b/yt/yt/library/auth_server/ya.make new file mode 100644 index 0000000000..3461577697 --- /dev/null +++ b/yt/yt/library/auth_server/ya.make @@ -0,0 +1,46 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + config.cpp + authentication_manager.cpp + batching_secret_vault_service.cpp + blackbox_cookie_authenticator.cpp + caching_secret_vault_service.cpp + cookie_authenticator.cpp + blackbox_service.cpp + default_secret_vault_service.cpp + dummy_secret_vault_service.cpp + helpers.cpp + cypress_cookie.cpp + cypress_cookie_authenticator.cpp + cypress_cookie_login.cpp + cypress_cookie_manager.cpp + cypress_cookie_store.cpp + 
cypress_token_authenticator.cpp + cypress_user_manager.cpp + secret_vault_service.cpp + ticket_authenticator.cpp + token_authenticator.cpp + oauth_cookie_authenticator.cpp + oauth_token_authenticator.cpp + oauth_service.cpp +) + +PEERDIR( + # For Cypress token authenticator and Cypress cookie authenticator. + yt/yt/client + + yt/yt/core/https + yt/yt/library/auth + yt/yt/library/tvm/service + library/cpp/string_utils/quote + library/cpp/string_utils/url +) + +END() + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/backtrace_introspector/http/handler.cpp b/yt/yt/library/backtrace_introspector/http/handler.cpp new file mode 100644 index 0000000000..367e3105c0 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/handler.cpp @@ -0,0 +1,89 @@ +#include "handler.h" + +#include <yt/yt/core/http/server.h> + +#include <yt/yt/core/concurrency/action_queue.h> + +#include <yt/yt/library/backtrace_introspector/introspect.h> + +namespace NYT::NBacktraceIntrospector { + +using namespace NHttp; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +class THandlerBase + : public IHttpHandler +{ +public: + void HandleRequest(const IRequestPtr& /*req*/, const IResponseWriterPtr& rsp) override + { + try { + auto dumpFuture = BIND(&THandlerBase::Dump, MakeStrong(this)) + .AsyncVia(Queue_->GetInvoker()) + .Run(); + + auto dump = WaitFor(dumpFuture) + .ValueOrThrow(); + + WaitFor(rsp->WriteBody(TSharedRef::FromString(dump))) + .ThrowOnError(); + + WaitFor(rsp->Close()) + .ThrowOnError(); + } catch (const std::exception& ex) { + if (!rsp->AreHeadersFlushed()) { + rsp->SetStatus(EStatusCode::InternalServerError); + WaitFor(rsp->WriteBody(TSharedRef::FromString(ex.what()))) + .ThrowOnError(); + } + throw; + } + } + +private: + static inline const TActionQueuePtr Queue_ = New<TActionQueue>("BacktraceIntro"); + +protected: + virtual TString Dump() = 0; +}; + +class TThreadsHandler + : public THandlerBase +{ 
+private: + TString Dump() override + { + return FormatIntrospectionInfos(IntrospectThreads()); + } +}; + +class TFibersHandler + : public THandlerBase +{ +private: + TString Dump() override + { + return FormatIntrospectionInfos(IntrospectFibers()); + } +}; + +void Register( + const IRequestPathMatcherPtr& handlers, + const TString& prefix) +{ + handlers->Add(prefix + "/threads", New<TThreadsHandler>()); + handlers->Add(prefix + "/fibers", New<TFibersHandler>()); +} + +void Register( + const IServerPtr& server, + const TString& prefix) +{ + Register(server->GetPathMatcher(), prefix); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/http/handler.h b/yt/yt/library/backtrace_introspector/http/handler.h new file mode 100644 index 0000000000..be795b7e5d --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/handler.h @@ -0,0 +1,20 @@ +#pragma once + +#include <yt/yt/core/http/public.h> + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// + +//! Registers introspector handlers. 
+void Register( + const NHttp::IRequestPathMatcherPtr& handlers, + const TString& prefix = {}); + +void Register( + const NHttp::IServerPtr& server, + const TString& prefix = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/http/ya.make b/yt/yt/library/backtrace_introspector/http/ya.make new file mode 100644 index 0000000000..504d20a2e3 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/http/ya.make @@ -0,0 +1,16 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + handler.cpp +) + +PEERDIR( + yt/yt/core + yt/yt/core/http + + yt/yt/library/backtrace_introspector +) + +END() diff --git a/yt/yt/library/backtrace_introspector/introspect.cpp b/yt/yt/library/backtrace_introspector/introspect.cpp new file mode 100644 index 0000000000..592c232f0f --- /dev/null +++ b/yt/yt/library/backtrace_introspector/introspect.cpp @@ -0,0 +1,216 @@ +#include "introspect.h" + +#include "private.h" + +#include <yt/yt/core/misc/finally.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/concurrency/fiber.h> +#include <yt/yt/core/concurrency/scheduler_api.h> + +#include <yt/yt/core/tracing/trace_context.h> + +#include <library/cpp/yt/memory/safe_memory_reader.h> + +#include <library/cpp/yt/backtrace/backtrace.h> + +#include <library/cpp/yt/backtrace/cursors/libunwind/libunwind_cursor.h> + +#include <library/cpp/yt/backtrace/cursors/frame_pointer/frame_pointer_cursor.h> + +#include <library/cpp/yt/backtrace/cursors/interop/interop.h> + +#include <util/system/yield.h> + +namespace NYT::NBacktraceIntrospector { + +using namespace NConcurrency; +using namespace NThreading; +using namespace NTracing; +using namespace NBacktrace; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = BacktraceIntrospectorLogger; + 
+//////////////////////////////////////////////////////////////////////////////// + +std::vector<TFiberIntrospectionInfo> IntrospectFibers() +{ + YT_LOG_INFO("Fiber introspection started"); + + auto fibers = TFiber::List(); + + YT_LOG_INFO("Collecting waiting fibers backtraces"); + + std::vector<TFiberIntrospectionInfo> infos; + THashSet<TFiberId> waitingFiberIds; + THashSet<TFiberId> fiberIds; + for (const auto& fiber : fibers) { + auto fiberId = fiber->GetFiberId(); + if (fiberId == InvalidFiberId) { + continue; + } + + InsertOrCrash(fiberIds, fiberId); + + EFiberState state; + if (!fiber->TryIntrospectWaiting(state, [&] { + YT_LOG_DEBUG("Waiting fiber is successfully locked for introspection (FiberId: %x)", + fiberId); + + const auto& propagatingStorage = fiber->GetPropagatingStorage(); + const auto* traceContext = TryGetTraceContextFromPropagatingStorage(propagatingStorage); + + TFiberIntrospectionInfo info{ + .State = EFiberState::Waiting, + .FiberId = fiberId, + .WaitingSince = fiber->GetWaitingSince(), + .TraceId = traceContext ? traceContext->GetTraceId() : TTraceId(), + .TraceLoggingTag = traceContext ? 
traceContext->GetLoggingTag() : TString(), + }; + + auto optionalContext = TrySynthesizeLibunwindContextFromMachineContext(*fiber->GetMachineContext()); + if (!optionalContext) { + YT_LOG_WARNING("Failed to synthesize libunwind context (FiberId: %x)", + fiberId); + return; + } + + TLibunwindCursor cursor(*optionalContext); + while (!cursor.IsFinished()) { + info.Backtrace.push_back(cursor.GetCurrentIP()); + cursor.MoveNext(); + } + + infos.push_back(std::move(info)); + InsertOrCrash(waitingFiberIds, fiberId); + + YT_LOG_DEBUG("Fiber introspection completed (FiberId: %x)", + info.FiberId); + })) { + YT_LOG_DEBUG("Failed to lock fiber for introspection (FiberId: %x, State: %v)", + fiberId, + state); + } + } + + YT_LOG_INFO("Collecting running fibers backtraces"); + + THashSet<TFiberId> runningFiberIds; + for (auto& info : IntrospectThreads()) { + if (info.FiberId == InvalidFiberId) { + continue; + } + + if (waitingFiberIds.contains(info.FiberId)) { + continue; + } + + if (!runningFiberIds.insert(info.FiberId).second) { + continue; + } + + infos.push_back(TFiberIntrospectionInfo{ + .State = EFiberState::Running, + .FiberId = info.FiberId, + .ThreadId = info.ThreadId, + .ThreadName = std::move(info.ThreadName), + .TraceId = info.TraceId, + .TraceLoggingTag = std::move(info.TraceLoggingTag), + .Backtrace = std::move(info.Backtrace), + }); + } + + for (const auto& fiber : fibers) { + auto fiberId = fiber->GetFiberId(); + if (fiberId == InvalidFiberId) { + continue; + } + if (runningFiberIds.contains(fiberId)) { + continue; + } + if (waitingFiberIds.contains(fiberId)) { + continue; + } + + infos.push_back(TFiberIntrospectionInfo{ + .State = fiber->GetState(), + .FiberId = fiberId, + }); + } + + YT_LOG_INFO("Fiber introspection completed"); + + return infos; +} + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +void FormatBacktrace(TStringBuilder* builder, const std::vector<const void*>& backtrace) +{ + if 
(!backtrace.empty()) { + builder->AppendString("Backtrace:\n"); + SymbolizeBacktrace( + MakeRange(backtrace), + [&] (TStringBuf str) { + builder->AppendFormat(" %v", str); + }); + } +} + +} // namespace + +TString FormatIntrospectionInfos(const std::vector<TThreadIntrospectionInfo>& infos) +{ + TStringBuilder builder; + for (const auto& info : infos) { + builder.AppendFormat("Thread id: %v\n", info.ThreadId); + builder.AppendFormat("Thread name: %v\n", info.ThreadName); + if (info.FiberId != InvalidFiberId) { + builder.AppendFormat("Fiber id: %x\n", info.FiberId); + } + if (info.TraceId) { + builder.AppendFormat("Trace id: %v\n", info.TraceId); + } + if (info.TraceLoggingTag) { + builder.AppendFormat("Trace logging tag: %v\n", info.TraceLoggingTag); + } + FormatBacktrace(&builder, info.Backtrace); + builder.AppendString("\n"); + } + return builder.Flush(); +} + +TString FormatIntrospectionInfos(const std::vector<TFiberIntrospectionInfo>& infos) +{ + TStringBuilder builder; + for (const auto& info : infos) { + builder.AppendFormat("Fiber id: %x\n", info.FiberId); + builder.AppendFormat("State: %v\n", info.State); + if (info.WaitingSince) { + builder.AppendFormat("Waiting since: %v\n", info.WaitingSince); + } + if (info.ThreadId != InvalidThreadId) { + builder.AppendFormat("Thread id: %v\n", info.ThreadId); + } + if (!info.ThreadName.empty()) { + builder.AppendFormat("Thread name: %v\n", info.ThreadName); + } + if (info.TraceId) { + builder.AppendFormat("Trace id: %v\n", info.TraceId); + } + if (info.TraceLoggingTag) { + builder.AppendFormat("Trace logging tag: %v\n", info.TraceLoggingTag); + } + FormatBacktrace(&builder, info.Backtrace); + builder.AppendString("\n"); + } + return builder.Flush(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/introspect.h b/yt/yt/library/backtrace_introspector/introspect.h new file mode 100644 index 
0000000000..2be09d2ec8 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/introspect.h @@ -0,0 +1,57 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/concurrency/public.h> + +#include <yt/yt/core/threading/public.h> + +#include <yt/yt/core/tracing/public.h> + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// +// Thread introspection API + +struct TThreadIntrospectionInfo +{ + NThreading::TThreadId ThreadId; + NConcurrency::TFiberId FiberId; + TString ThreadName; + NTracing::TTraceId TraceId; + //! Empty if no trace context is known. + TString TraceLoggingTag; + std::vector<const void*> Backtrace; +}; + +std::vector<TThreadIntrospectionInfo> IntrospectThreads(); + +//////////////////////////////////////////////////////////////////////////////// +// Fiber introspection API + +struct TFiberIntrospectionInfo +{ + NConcurrency::EFiberState State; + NConcurrency::TFiberId FiberId; + //! Zero if fiber is not waiting. + TInstant WaitingSince; + //! |InvalidThreadId| is fiber is not running. + NThreading::TThreadId ThreadId; + //! Empty if fiber is not running. + TString ThreadName; + NTracing::TTraceId TraceId; + //! Empty if no trace context is known. 
+ TString TraceLoggingTag; + std::vector<const void*> Backtrace; +}; + +std::vector<TFiberIntrospectionInfo> IntrospectFibers(); + +//////////////////////////////////////////////////////////////////////////////// + +TString FormatIntrospectionInfos(const std::vector<TThreadIntrospectionInfo>& infos); +TString FormatIntrospectionInfos(const std::vector<TFiberIntrospectionInfo>& infos); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/introspect_linux.cpp b/yt/yt/library/backtrace_introspector/introspect_linux.cpp new file mode 100644 index 0000000000..3fc1a077f6 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/introspect_linux.cpp @@ -0,0 +1,211 @@ +#include "introspect.h" + +#include "private.h" + +#include <yt/yt/core/misc/finally.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/concurrency/fiber.h> +#include <yt/yt/core/concurrency/scheduler_api.h> + +#include <yt/yt/core/tracing/trace_context.h> + +#include <library/cpp/yt/memory/safe_memory_reader.h> + +#include <library/cpp/yt/backtrace/backtrace.h> + +#include <library/cpp/yt/backtrace/cursors/libunwind/libunwind_cursor.h> + +#include <library/cpp/yt/backtrace/cursors/frame_pointer/frame_pointer_cursor.h> + +#include <library/cpp/yt/backtrace/cursors/interop/interop.h> + +#include <library/cpp/yt/misc/thread_name.h> + +#include <util/system/yield.h> + +#include <sys/syscall.h> + +namespace NYT::NBacktraceIntrospector { + +using namespace NConcurrency; +using namespace NTracing; +using namespace NBacktrace; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = BacktraceIntrospectorLogger; + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +struct TStaticString +{ + TStaticString() = default; + + explicit TStaticString(TStringBuf str) + { + Length = 
std::min(std::ssize(str), std::ssize(Buffer)); + std::copy(str.data(), str.data() + Length, Buffer.data()); + } + + operator TString() const + { + return TString(Buffer.data(), static_cast<size_t>(Length)); + } + + std::array<char, 256> Buffer; + int Length = 0; +}; + +struct TStaticBacktrace +{ + operator std::vector<const void*>() const + { + return std::vector<const void*>(Frames.data(), Frames.data() + FrameCount); + } + + std::array<const void*, 100> Frames; + int FrameCount = 0; +}; + +struct TSignalHandlerContext +{ + TSignalHandlerContext(); + ~TSignalHandlerContext(); + + std::atomic<bool> Finished = false; + + TFiberId FiberId = {}; + TTraceId TraceId = {}; + TStaticString TraceLoggingTag; + TStaticBacktrace Backtrace; + TThreadName ThreadName = {}; + + TSafeMemoryReader* MemoryReader = Singleton<TSafeMemoryReader>(); + + void SetFinished() + { + Finished.store(true); + } + + void WaitUntilFinished() + { + while (!Finished.load()) { + ThreadYield(); + } + } +}; + +static TSignalHandlerContext* SignalHandlerContext; + +TSignalHandlerContext::TSignalHandlerContext() +{ + YT_VERIFY(!SignalHandlerContext); + SignalHandlerContext = this; +} + +TSignalHandlerContext::~TSignalHandlerContext() +{ + YT_VERIFY(SignalHandlerContext == this); + SignalHandlerContext = nullptr; +} + +void SignalHandler(int sig, siginfo_t* /*info*/, void* threadContext) +{ + YT_VERIFY(sig == SIGUSR1); + + SignalHandlerContext->FiberId = GetCurrentFiberId(); + SignalHandlerContext->ThreadName = GetCurrentThreadName(); + if (const auto* traceContext = TryGetCurrentTraceContext()) { + SignalHandlerContext->TraceId = traceContext->GetTraceId(); + SignalHandlerContext->TraceLoggingTag = TStaticString(traceContext->GetLoggingTag()); + } + + auto cursorContext = FramePointerCursorContextFromUcontext(*static_cast<const ucontext_t*>(threadContext)); + TFramePointerCursor cursor(SignalHandlerContext->MemoryReader, cursorContext); + while (!cursor.IsFinished() && 
SignalHandlerContext->Backtrace.FrameCount < std::ssize(SignalHandlerContext->Backtrace.Frames)) { + SignalHandlerContext->Backtrace.Frames[SignalHandlerContext->Backtrace.FrameCount++] = cursor.GetCurrentIP(); + cursor.MoveNext(); + } + + SignalHandlerContext->SetFinished(); +} + +} // namespace + +std::vector<TThreadIntrospectionInfo> IntrospectThreads() +{ + static std::atomic<bool> IntrospectionLock; + + if (IntrospectionLock.exchange(true)) { + THROW_ERROR_EXCEPTION("Thread introspection is already in progress"); + } + + auto introspectionLockGuard = Finally([] { + YT_VERIFY(IntrospectionLock.exchange(false)); + }); + + YT_LOG_INFO("Thread introspection started"); + + { + struct sigaction action; + action.sa_flags = SA_SIGINFO | SA_RESTART; + ::sigemptyset(&action.sa_mask); + action.sa_sigaction = SignalHandler; + + if (::sigaction(SIGUSR1, &action, nullptr) != 0) { + THROW_ERROR_EXCEPTION("Failed to install signal handler") + << TError::FromSystem(); + } + } + + std::vector<TThreadIntrospectionInfo> infos; + for (auto threadId : GetCurrentProcessThreadIds()) { + TSignalHandlerContext signalHandlerContext; + if (::syscall(SYS_tkill, threadId, SIGUSR1) != 0) { + YT_LOG_DEBUG(TError::FromSystem(), "Failed to signal to thread (ThreadId: %v)", + threadId); + continue; + } + + YT_LOG_DEBUG("Sent signal to thread (ThreadId: %v)", + threadId); + + signalHandlerContext.WaitUntilFinished(); + + YT_LOG_DEBUG("Signal handler finished (ThreadId: %v, FiberId: %x)", + threadId, + signalHandlerContext.FiberId); + + infos.push_back(TThreadIntrospectionInfo{ + .ThreadId = threadId, + .FiberId = signalHandlerContext.FiberId, + .ThreadName = TString(signalHandlerContext.ThreadName.Buffer.data(), static_cast<size_t>(signalHandlerContext.ThreadName.Length)), + .TraceId = signalHandlerContext.TraceId, + .TraceLoggingTag = signalHandlerContext.TraceLoggingTag, + .Backtrace = signalHandlerContext.Backtrace, + }); + } + + { + struct sigaction action; + action.sa_flags = SA_RESTART; + 
::sigemptyset(&action.sa_mask); + action.sa_handler = SIG_IGN; + + if (::sigaction(SIGUSR1, &action, nullptr) != 0) { + THROW_ERROR_EXCEPTION("Failed to de-install signal handler") + << TError::FromSystem(); + } + } + + YT_LOG_INFO("Thread introspection completed"); + + return infos; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/private.h b/yt/yt/library/backtrace_introspector/private.h new file mode 100644 index 0000000000..59f25e6023 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/private.h @@ -0,0 +1,16 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger BacktraceIntrospectorLogger("BacktraceIntrospector"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector + diff --git a/yt/yt/library/backtrace_introspector/public.h b/yt/yt/library/backtrace_introspector/public.h new file mode 100644 index 0000000000..54a8bd06ed --- /dev/null +++ b/yt/yt/library/backtrace_introspector/public.h @@ -0,0 +1,12 @@ +#pragma once + +namespace NYT::NBacktraceIntrospector { + +//////////////////////////////////////////////////////////////////////////////// + +struct TThreadIntrospectionInfo; +struct TFiberIntrospectionInfo; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBacktraceIntrospector diff --git a/yt/yt/library/backtrace_introspector/ya.make b/yt/yt/library/backtrace_introspector/ya.make new file mode 100644 index 0000000000..884b8fb562 --- /dev/null +++ b/yt/yt/library/backtrace_introspector/ya.make @@ -0,0 +1,31 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + introspect.cpp +) +IF 
(OS_LINUX) + SRCS(introspect_linux.cpp) +ELSE() + SRCS(introspect_dummy.cpp) +ENDIF() + +PEERDIR( + yt/yt/core + + library/cpp/yt/backtrace/cursors/interop + library/cpp/yt/backtrace/cursors/libunwind + library/cpp/yt/backtrace/cursors/frame_pointer + library/cpp/yt/misc +) + +END() + +RECURSE( + http +) + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/column_converters/boolean_column_converter.cpp b/yt/yt/library/column_converters/boolean_column_converter.cpp new file mode 100644 index 0000000000..37e27bc56c --- /dev/null +++ b/yt/yt/library/column_converters/boolean_column_converter.cpp @@ -0,0 +1,100 @@ +#include "boolean_column_converter.h" + +#include "helpers.h" + +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +void FillColumnarBooleanValues( + TBatchColumn* column, + i64 startIndex, + i64 valueCount, + TRef bitmap) +{ + column->StartIndex = startIndex; + column->ValueCount = valueCount; + + auto& values = column->Values.emplace(); + values.BitWidth = 1; + values.Data = bitmap; +} + +//////////////////////////////////////////////////////////////////////////////// + +class TBooleanColumnConverter + : public IColumnConverter +{ +public: + TBooleanColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema) + : ColumnIndex_(columnIndex) + , ColumnSchema_(columnSchema) + { } + + TConvertedColumn Convert(const std::vector<TUnversionedRowValues>& rowsValues) override + { + Reset(); + AddValues(rowsValues); + + auto column = std::make_shared<TBatchColumn>(); + auto nullBitmapRef = NullBitmap_.Flush<TConverterTag>(); + auto valuesRef = Values_.Flush<TConverterTag>(); + + FillColumnarBooleanValues(column.get(), 0, rowsValues.size(), valuesRef); + FillColumnarNullBitmap(column.get(), 0, rowsValues.size(), nullBitmapRef); + + column->Type = 
ColumnSchema_.LogicalType(); + column->Id = ColumnIndex_; + + TOwningColumn owner = { + .Column = std::move(column), + .NullBitmap = std::move(nullBitmapRef), + .ValueBuffer = std::move(valuesRef), + }; + + return {{owner}, owner.Column.get()}; + } + + +private: + const int ColumnIndex_; + const NTableClient::TColumnSchema ColumnSchema_; + + TBitmapOutput Values_; + TBitmapOutput NullBitmap_; + + void Reset() + { + Values_.Flush<TConverterTag>(); + NullBitmap_.Flush<TConverterTag>(); + } + + void AddValues(const std::vector<TUnversionedRowValues>& rowsValues) + { + for (auto rowValues : rowsValues) { + auto value = rowValues[ColumnIndex_]; + bool isNull = value == nullptr || value->Type == NTableClient::EValueType::Null; + bool data = isNull ? false : value->Data.Boolean; + NullBitmap_.Append(isNull); + Values_.Append(data); + } + } +}; + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateBooleanColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema) +{ + return std::make_unique<TBooleanColumnConverter>(columnIndex, columnSchema); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/boolean_column_converter.h b/yt/yt/library/column_converters/boolean_column_converter.h new file mode 100644 index 0000000000..0495c4a188 --- /dev/null +++ b/yt/yt/library/column_converters/boolean_column_converter.h @@ -0,0 +1,15 @@ +#pragma once + +#include "column_converter.h" + +#include <yt/yt/client/table_client/public.h> + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateBooleanColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace 
NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/column_converter.cpp b/yt/yt/library/column_converters/column_converter.cpp new file mode 100644 index 0000000000..21c9982549 --- /dev/null +++ b/yt/yt/library/column_converters/column_converter.cpp @@ -0,0 +1,91 @@ +#include "column_converter.h" + +#include "boolean_column_converter.h" +#include "floating_point_column_converter.h" +#include "integer_column_converter.h" +#include "null_column_converter.h" +#include "string_column_converter.h" + +#include <yt/yt/client/table_client/row_base.h> +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +namespace NYT::NColumnConverters { + +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateColumnConvert( + const NTableClient::TColumnSchema& columnSchema, + int columnIndex) +{ + switch (columnSchema.GetWireType()) { + case EValueType::Int64: + return CreateInt64ColumnConverter(columnIndex, columnSchema); + + case EValueType::Uint64: + return CreateUint64ColumnConverter(columnIndex, columnSchema); + + case EValueType::Double: + switch (columnSchema.CastToV1Type()) { + case NTableClient::ESimpleLogicalValueType::Float: + return CreateFloatingPoint32ColumnConverter(columnIndex, columnSchema); + default: + return CreateFloatingPoint64ColumnConverter(columnIndex, columnSchema); + } + + case EValueType::String: + return CreateStringConverter(columnIndex, columnSchema); + + case EValueType::Boolean: + return CreateBooleanColumnConverter(columnIndex, columnSchema); + + case EValueType::Any: + return CreateAnyConverter(columnIndex, columnSchema); + + case EValueType::Composite: + return CreateCompositeConverter(columnIndex, columnSchema); + + case EValueType::Null: + return CreateNullConverter(columnIndex); + + case EValueType::Min: + case EValueType::TheBottom: + case EValueType::Max: + break; + } + 
ThrowUnexpectedValueType(columnSchema.GetWireType()); +} + +//////////////////////////////////////////////////////////////////////////////// + + +TConvertedColumnRange ConvertRowsToColumns( + TRange<TUnversionedRow> rows, + const std::vector<TColumnSchema>& columnSchema) +{ + TConvertedColumnRange convertedColumnsRange; + std::vector<TUnversionedRowValues> rowsValues; + rowsValues.reserve(rows.size()); + + for (const auto& row : rows) { + TUnversionedRowValues rowValues; + rowValues.resize(columnSchema.size(), nullptr); + for (const auto* item = row.Begin(); item != row.End(); ++item) { + rowValues[item->Id] = item; + } + rowsValues.push_back(std::move(rowValues)); + } + + for (int columnId = 0; columnId < std::ssize(columnSchema); columnId++) { + auto converter = CreateColumnConvert(columnSchema[columnId], columnId); + auto columns = converter->Convert(rowsValues); + convertedColumnsRange.push_back(columns); + } + return convertedColumnsRange; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/column_converter.h b/yt/yt/library/column_converters/column_converter.h new file mode 100644 index 0000000000..64cec2fd44 --- /dev/null +++ b/yt/yt/library/column_converters/column_converter.h @@ -0,0 +1,54 @@ +#pragma once + +#include <yt/yt/client/table_client/row_batch.h> + +#include <yt/yt/core/misc/bitmap.h> + +#include <library/cpp/yt/memory/ref.h> + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +using TBatchColumn = NTableClient::IUnversionedColumnarRowBatch::TColumn; +using TBatchColumnPtr = std::shared_ptr<TBatchColumn>; +using TUnversionedRowValues = std::vector<const NTableClient::TUnversionedValue*>; + +//////////////////////////////////////////////////////////////////////////////// + +struct TOwningColumn +{ + TBatchColumnPtr Column; + TSharedRef NullBitmap; + 
TSharedRef ValueBuffer; + TSharedRef StringBuffer; +}; + +struct TConvertedColumn +{ + std::vector<TOwningColumn> Columns; + TBatchColumn* RootColumn; +}; + +using TConvertedColumnRange = std::vector<TConvertedColumn>; + +//////////////////////////////////////////////////////////////////////////////// + +struct IColumnConverter + : private TNonCopyable +{ + virtual ~IColumnConverter() = default; + virtual TConvertedColumn Convert(const std::vector<TUnversionedRowValues>& rowsValues) = 0; +}; + +using IColumnConverterPtr = std::unique_ptr<IColumnConverter>; + +//////////////////////////////////////////////////////////////////////////////// + +TConvertedColumnRange ConvertRowsToColumns( + TRange<NTableClient::TUnversionedRow> rows, + const std::vector<NTableClient::TColumnSchema>& columnSchema); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/floating_point_column_converter.cpp b/yt/yt/library/column_converters/floating_point_column_converter.cpp new file mode 100644 index 0000000000..bc18a53f14 --- /dev/null +++ b/yt/yt/library/column_converters/floating_point_column_converter.cpp @@ -0,0 +1,135 @@ +#include "floating_point_column_converter.h" + +#include "helpers.h" + +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +namespace NYT::NColumnConverters { + +using namespace NProto; +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +template <typename T> +void FillColumnarFloatingPointValues( + NTableClient::IUnversionedColumnarRowBatch::TColumn* column, + i64 startIndex, + i64 valueCount, + TRef data) +{ + column->StartIndex = startIndex; + column->ValueCount = valueCount; + + auto& values = column->Values.emplace(); + values.BitWidth = sizeof(T) * 8; + values.Data = data; +} + 
+//////////////////////////////////////////////////////////////////////////////// + +template <typename T> +TSharedRef SerializeFloatingPointVector(const std::vector<T>& values) +{ + auto data = TSharedMutableRef::Allocate<TConverterTag>(values.size() * sizeof(T) + sizeof(ui64), {.InitializeStorage = false}); + *reinterpret_cast<ui64*>(data.Begin()) = static_cast<ui64>(values.size()); + std::memcpy( + data.Begin() + sizeof(ui64), + values.data(), + values.size() * sizeof(T)); + return data; +} + +//////////////////////////////////////////////////////////////////////////////// + +template <class TValue, NTableClient::EValueType ValueType> +class TFloatingPointColumnConverter + : public IColumnConverter +{ +public: + static_assert(std::is_floating_point_v<TValue>); + + TFloatingPointColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema) + : ColumnIndex_(columnIndex) + , ColumnSchema_(columnSchema) + + { } + + TConvertedColumn Convert(const std::vector<TUnversionedRowValues>& rowsValues) + { + Reset(); + AddValues(rowsValues); + auto nullBitmapRef = NullBitmap_.Flush<TConverterTag>(); + auto valuesRef = TSharedRef::MakeCopy<TConverterTag>(TRef(Values_.data(), sizeof(TValue) * Values_.size())); + + auto column = std::make_shared<TBatchColumn>(); + + FillColumnarFloatingPointValues<TValue>( + column.get(), + 0, + rowsValues.size(), + valuesRef); + + FillColumnarNullBitmap( + column.get(), + 0, + rowsValues.size(), + nullBitmapRef); + + column->Type = ColumnSchema_.LogicalType(); + column->Id = ColumnIndex_; + + TOwningColumn owner = { + .Column = std::move(column), + .NullBitmap = std::move(nullBitmapRef), + .ValueBuffer = std::move(valuesRef), + }; + + return {{owner}, owner.Column.get()}; + } + +private: + const int ColumnIndex_; + const TColumnSchema ColumnSchema_; + + std::vector<TValue> Values_; + TBitmapOutput NullBitmap_; + + void Reset() + { + Values_.clear(); + NullBitmap_.Flush<TConverterTag>(); + } + + void AddValues(const 
std::vector<TUnversionedRowValues>& rowsValues) + { + for (auto rowValues : rowsValues) { + auto value = rowValues[ColumnIndex_]; + bool isNull = value == nullptr || value->Type == NTableClient::EValueType::Null; + TValue data = isNull ? 0 : value->Data.Double; + NullBitmap_.Append(isNull); + Values_.push_back(data); + } + } +}; + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateFloatingPoint32ColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema) +{ + return std::make_unique<TFloatingPointColumnConverter<float, NTableClient::EValueType::Double>>(columnIndex, columnSchema); +} + +IColumnConverterPtr CreateFloatingPoint64ColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema) +{ + return std::make_unique<TFloatingPointColumnConverter<double, NTableClient::EValueType::Double>>(columnIndex, columnSchema); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/floating_point_column_converter.h b/yt/yt/library/column_converters/floating_point_column_converter.h new file mode 100644 index 0000000000..3739d4e729 --- /dev/null +++ b/yt/yt/library/column_converters/floating_point_column_converter.h @@ -0,0 +1,15 @@ +#pragma once + +#include "column_converter.h" + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateFloatingPoint32ColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema); + +IColumnConverterPtr CreateFloatingPoint64ColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/helpers.cpp 
b/yt/yt/library/column_converters/helpers.cpp new file mode 100644 index 0000000000..cddac06d79 --- /dev/null +++ b/yt/yt/library/column_converters/helpers.cpp @@ -0,0 +1,59 @@ +#include "helpers.h" + +#include <yt/yt/client/table_client/columnar.h> +#include <yt/yt/client/table_client/logical_type.h> +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +#include <yt/yt/core/misc/bitmap.h> + +namespace NYT::NColumnConverters { + +using namespace NProto; +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +void FillColumnarNullBitmap( + NTableClient::IUnversionedColumnarRowBatch::TColumn* column, + i64 startIndex, + i64 valueCount, + TRef bitmap) +{ + column->StartIndex = startIndex; + column->ValueCount = valueCount; + + auto& nullBitmap = column->NullBitmap.emplace(); + nullBitmap.Data = bitmap; +} + + +void FillColumnarDictionary( + NTableClient::IUnversionedColumnarRowBatch::TColumn* primaryColumn, + NTableClient::IUnversionedColumnarRowBatch::TColumn* dictionaryColumn, + NTableClient::IUnversionedColumnarRowBatch::TDictionaryId dictionaryId, + NTableClient::TLogicalTypePtr type, + i64 startIndex, + i64 valueCount, + TRef ids) +{ + primaryColumn->StartIndex = startIndex; + primaryColumn->ValueCount = valueCount; + + dictionaryColumn->Type = type && type->GetMetatype() == ELogicalMetatype::Optional + ? 
type->AsOptionalTypeRef().GetElement() + : type; + + auto& primaryValues = primaryColumn->Values.emplace(); + primaryValues.BitWidth = 32; + primaryValues.Data = ids; + + auto& dictionary = primaryColumn->Dictionary.emplace(); + dictionary.DictionaryId = dictionaryId; + dictionary.ZeroMeansNull = true; + dictionary.ValueColumn = dictionaryColumn; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/helpers.h b/yt/yt/library/column_converters/helpers.h new file mode 100644 index 0000000000..6957ff13c1 --- /dev/null +++ b/yt/yt/library/column_converters/helpers.h @@ -0,0 +1,39 @@ +#pragma once + +#include <yt/yt/client/table_client/row_batch.h> +#include <yt/yt/client/table_client/schema.h> + +#include <yt/yt/core/misc/common.h> + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +void FillColumnarNullBitmap( + NTableClient::IUnversionedColumnarRowBatch::TColumn* column, + i64 startIndex, + i64 valueCount, + TRef bitmap); + +void FillColumnarDictionary( + NTableClient::IUnversionedColumnarRowBatch::TColumn* primaryColumn, + NTableClient::IUnversionedColumnarRowBatch::TColumn* dictionaryColumn, + NTableClient::IUnversionedColumnarRowBatch::TDictionaryId dictionaryId, + NTableClient::TLogicalTypePtr type, + i64 startIndex, + i64 valueCount, + TRef ids); + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(EUnversionedStringSegmentType, + ((DictionaryDense) (0)) + ((DirectDense) (1)) +); + +struct TConverterTag +{}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/integer_column_converter.cpp b/yt/yt/library/column_converters/integer_column_converter.cpp new file mode 100644 index 0000000000..862c23e5b7 --- /dev/null 
+++ b/yt/yt/library/column_converters/integer_column_converter.cpp @@ -0,0 +1,175 @@ +#include "integer_column_converter.h" + +#include "helpers.h" + +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +#include <library/cpp/yt/coding/zig_zag.h> + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +ui64 EncodeValue(i64 value) +{ + return ZigZagEncode64(value); +} + +ui64 EncodeValue(ui64 value) +{ + return value; +} + +template <class TValue> +typename std::enable_if<std::is_signed<TValue>::value, TValue>::type +GetValue(const NTableClient::TUnversionedValue& value) +{ + return value.Data.Int64; +} + +template <class TValue> +typename std::enable_if<std::is_unsigned<TValue>::value, TValue>::type +GetValue(const NTableClient::TUnversionedValue& value) +{ + return value.Data.Uint64; +} + +//////////////////////////////////////////////////////////////////////////////// + +void FillColumnarIntegerValues( + NTableClient::IUnversionedColumnarRowBatch::TColumn* column, + i64 startIndex, + i64 valueCount, + NTableClient::EValueType valueType, + ui64 baseValue, + TRef data) +{ + column->StartIndex = startIndex; + column->ValueCount = valueCount; + + auto& values = column->Values.emplace(); + values.BaseValue = baseValue; + values.BitWidth = 64; + values.ZigZagEncoded = (valueType == NTableClient::EValueType::Int64); + values.Data = data; +} + +//////////////////////////////////////////////////////////////////////////////// + +// TValue - i64 or ui64. 
+template <class TValue> +class TIntegerColumnConverter + : public IColumnConverter +{ +public: + static_assert(std::is_integral_v<TValue>); + + TIntegerColumnConverter( + int columnIndex, + NTableClient::EValueType ValueType, + NTableClient::TColumnSchema columnSchema) + : ColumnIndex_(columnIndex) + , ColumnSchema_(columnSchema) + , ValueType_(ValueType) + { } + + TConvertedColumn Convert(const std::vector<TUnversionedRowValues>& rowsValues) override + { + Reset(); + AddValues(rowsValues); + for (i64 index = 0; index < std::ssize(Values_); ++index) { + if (!NullBitmap_[index]) { + Values_[index] -= MinValue_; + } + } + + auto nullBitmapRef = NullBitmap_.Flush<TConverterTag>(); + auto valuesRef = TSharedRef::MakeCopy<TConverterTag>(TRef(Values_.data(), sizeof(ui64) * Values_.size())); + auto column = std::make_shared<TBatchColumn>(); + + FillColumnarIntegerValues( + column.get(), + 0, + RowCount_, + ValueType_, + MinValue_, + valuesRef); + + FillColumnarNullBitmap( + column.get(), + 0, + RowCount_, + nullBitmapRef); + + column->Type = ColumnSchema_.LogicalType(); + column->Id = ColumnIndex_; + + TOwningColumn owner = { + .Column = std::move(column), + .NullBitmap = std::move(nullBitmapRef), + .ValueBuffer = std::move(valuesRef), + }; + + return {{owner}, owner.Column.get()}; + } + + +private: + const int ColumnIndex_; + const NTableClient::TColumnSchema ColumnSchema_; + const NTableClient::EValueType ValueType_; + + i64 RowCount_ = 0; + TBitmapOutput NullBitmap_; + std::vector<ui64> Values_; + + ui64 MaxValue_; + ui64 MinValue_; + + void Reset() + { + Values_.clear(); + RowCount_ = 0; + MaxValue_ = 0; + MinValue_ = std::numeric_limits<ui64>::max(); + NullBitmap_.Flush<TConverterTag>(); + } + + void AddValues(const std::vector<TUnversionedRowValues>& rowsValues) + { + for (auto rowValues : rowsValues) { + auto value = rowValues[ColumnIndex_]; + bool isNull = value == nullptr || value->Type == NTableClient::EValueType::Null; + ui64 data = 0; + if (!isNull) { + 
YT_VERIFY(value != nullptr); + data = EncodeValue(GetValue<TValue>(*value)); + } + Values_.push_back(data); + NullBitmap_.Append(isNull); + ++RowCount_; + } + } +}; + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateInt64ColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema) +{ + return std::make_unique<TIntegerColumnConverter<i64>>(columnIndex, NTableClient::EValueType::Int64, columnSchema); +} + + +IColumnConverterPtr CreateUint64ColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema) +{ + return std::make_unique<TIntegerColumnConverter<ui64>>(columnIndex, NTableClient::EValueType::Uint64, columnSchema); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/integer_column_converter.h b/yt/yt/library/column_converters/integer_column_converter.h new file mode 100644 index 0000000000..99b9d86342 --- /dev/null +++ b/yt/yt/library/column_converters/integer_column_converter.h @@ -0,0 +1,17 @@ +#pragma once + +#include "column_converter.h" + +#include <yt/yt/client/table_client/public.h> + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateInt64ColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema); + +std::unique_ptr<IColumnConverter> CreateUint64ColumnConverter(int columnIndex, const NTableClient::TColumnSchema& columnSchema); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/null_column_converter.cpp b/yt/yt/library/column_converters/null_column_converter.cpp new file mode 100644 index 0000000000..d07ab24ceb --- /dev/null +++ 
b/yt/yt/library/column_converters/null_column_converter.cpp @@ -0,0 +1,49 @@ +#include "null_column_converter.h" + +#include <yt/yt/client/table_client/logical_type.h> + +namespace NYT::NColumnConverters { + +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +class TNullColumnWriterConverter + : public IColumnConverter +{ +public: + TNullColumnWriterConverter(int columnIndex) + : ColumnIndex_(columnIndex) + { } + + TConvertedColumn Convert(const std::vector<TUnversionedRowValues>& rowsValues) override + { + auto rowCount = rowsValues.size(); + + auto column = std::make_shared<TBatchColumn>(); + + column->Id = ColumnIndex_; + column->Type = SimpleLogicalType(ESimpleLogicalValueType::Null); + column->ValueCount = rowCount; + + TOwningColumn owner = { + .Column = std::move(column), + }; + + return {{owner}, owner.Column.get()}; + } + +private: + const int ColumnIndex_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateNullConverter(int columnIndex) +{ + return std::make_unique<TNullColumnWriterConverter>(columnIndex); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/null_column_converter.h b/yt/yt/library/column_converters/null_column_converter.h new file mode 100644 index 0000000000..a8f97c84a1 --- /dev/null +++ b/yt/yt/library/column_converters/null_column_converter.h @@ -0,0 +1,13 @@ +#pragma once + +#include "column_converter.h" + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateNullConverter(int columnIndex); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/string_column_converter.cpp 
b/yt/yt/library/column_converters/string_column_converter.cpp new file mode 100644 index 0000000000..c8a4354c47 --- /dev/null +++ b/yt/yt/library/column_converters/string_column_converter.cpp @@ -0,0 +1,375 @@ +#include "string_column_converter.h" + +#include "helpers.h" + +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +#include <yt/yt/core/misc/bit_packed_unsigned_vector.h> + +#include <library/cpp/yt/string/string_builder.h> + +namespace NYT::NColumnConverters { + +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +void FillColumnarStringValues( + NTableClient::IUnversionedColumnarRowBatch::TColumn* column, + i64 startIndex, + i64 valueCount, + ui32 avgLength, + TRef offsets, + TRef stringData) +{ + column->StartIndex = startIndex; + column->ValueCount = valueCount; + + auto& values = column->Values.emplace(); + values.BitWidth = 32; + values.ZigZagEncoded = true; + values.Data = offsets; + + auto& strings = column->Strings.emplace(); + strings.AvgLength = avgLength; + strings.Data = stringData; +} + +bool IsValueNull(TStringBuf lhs) +{ + return !lhs.data(); +} + +//////////////////////////////////////////////////////////////////////////////// + + +template <EValueType ValueType> +class TStringConverter + : public IColumnConverter +{ +public: + TStringConverter( + int columnIndex, + const TColumnSchema& columnSchema) + : ColumnIndex_(columnIndex) + , ColumnSchema_(columnSchema) + { } + + TConvertedColumn Convert(const std::vector<TUnversionedRowValues>& rowsValues) override + { + Reset(); + AddValues(rowsValues); + return GetColumns(); + } + +private: + const int ColumnIndex_; + const TColumnSchema ColumnSchema_; + + ui32 RowCount_ = 0; + ui64 AllStringsSize_ = 0; + ui64 DictionaryByteSize_ = 0; + + std::vector<TStringBuf> Values_; + THashMap<TStringBuf, ui32> Dictionary_; + TStringBuilder DirectBuffer_; + + void Reset() + { + 
AllStringsSize_ = 0; + RowCount_ = 0; + DictionaryByteSize_ = 0; + + DirectBuffer_.Reset(); + Values_.clear(); + Dictionary_.clear(); + } + + TSharedRef GetDirectDenseNullBitmap() const + { + TBitmapOutput nullBitmap(Values_.size()); + + for (auto value : Values_) { + nullBitmap.Append(IsValueNull(value)); + } + + return nullBitmap.Flush<TConverterTag>(); + } + + std::vector<ui32> GetDirectDenseOffsets() const + { + std::vector<ui32> offsets; + offsets.reserve(Values_.size()); + + ui32 offset = 0; + for (auto value : Values_) { + offset += value.length(); + offsets.push_back(offset); + } + + return offsets; + } + + TConvertedColumn GetDirectColumn(TSharedRef nullBitmap) + { + auto offsets = GetDirectDenseOffsets(); + + // Save offsets as diff from expected. + ui32 expectedLength; + ui32 maxDiff; + PrepareDiffFromExpected(&offsets, &expectedLength, &maxDiff); + + auto directData = DirectBuffer_.GetBuffer(); + + auto offsetsRef = TSharedRef::MakeCopy<TConverterTag>(TRef(offsets.data(), sizeof(ui32) * offsets.size())); + auto directDataPtr = TSharedRef::MakeCopy<TConverterTag>(TRef(directData.data(), directData.size())); + auto column = std::make_shared<TBatchColumn>(); + + FillColumnarStringValues( + column.get(), + 0, + RowCount_, + expectedLength, + TRef(offsetsRef), + TRef(directDataPtr)); + + FillColumnarNullBitmap( + column.get(), + 0, + RowCount_, + TRef(nullBitmap)); + + column->Type = ColumnSchema_.LogicalType(); + column->Id = ColumnIndex_; + + TOwningColumn owner = { + .Column = std::move(column), + .NullBitmap = std::move(nullBitmap), + .ValueBuffer = std::move(offsetsRef), + .StringBuffer = std::move(directDataPtr), + }; + return {{owner}, owner.Column.get()}; + } + + TConvertedColumn GetDictionaryColumn() + { + auto dictionaryData = TSharedMutableRef::Allocate<TConverterTag>(DictionaryByteSize_, {.InitializeStorage = false}); + + std::vector<ui32> dictionaryOffsets; + dictionaryOffsets.reserve(Dictionary_.size()); + + std::vector<ui32> ids; + 
ids.reserve(Values_.size()); + + ui32 dictionarySize = 0; + ui32 dictionaryOffset = 0; + for (auto value : Values_) { + if (IsValueNull(value)) { + ids.push_back(0); + continue; + } + + ui32 id = GetOrCrash(Dictionary_, value); + ids.push_back(id); + + if (id > dictionarySize) { + std::memcpy( + dictionaryData.Begin() + dictionaryOffset, + value.data(), + value.length()); + dictionaryOffset += value.length(); + dictionaryOffsets.push_back(dictionaryOffset); + ++dictionarySize; + } + } + + YT_VERIFY(dictionaryOffset == DictionaryByteSize_); + + // 1. Value ids. + auto idsRef = TSharedRef::MakeCopy<TConverterTag>(TRef(ids.data(), sizeof(ui32) * ids.size())); + + // 2. Dictionary offsets. + ui32 expectedLength; + ui32 maxDiff; + PrepareDiffFromExpected(&dictionaryOffsets, &expectedLength, &maxDiff); + auto dictionaryOffsetsRef = TSharedRef::MakeCopy<TConverterTag>(TRef(dictionaryOffsets.data(), sizeof(ui32) * dictionaryOffsets.size())); + + auto primaryColumn = std::make_shared<TBatchColumn>(); + auto dictionaryColumn = std::make_shared<TBatchColumn>(); + + FillColumnarStringValues( + dictionaryColumn.get(), + 0, + dictionaryOffsets.size(), + expectedLength, + TRef(dictionaryOffsetsRef), + dictionaryData); + + FillColumnarDictionary( + primaryColumn.get(), + dictionaryColumn.get(), + NTableClient::IUnversionedColumnarRowBatch::GenerateDictionaryId(), + primaryColumn->Type, + 0, + RowCount_, + idsRef); + + dictionaryColumn->Type = ColumnSchema_.LogicalType(); + primaryColumn->Type = ColumnSchema_.LogicalType(); + primaryColumn->Id = ColumnIndex_; + + TOwningColumn dictOwner = { + .Column = std::move(dictionaryColumn), + .ValueBuffer = std::move(dictionaryOffsetsRef), + .StringBuffer = std::move(dictionaryData), + }; + + TOwningColumn primeOwner = { + .Column = std::move(primaryColumn), + .ValueBuffer = std::move(idsRef), + }; + + return {{primeOwner, dictOwner}, primeOwner.Column.get()}; + } + + TConvertedColumn GetColumns() + { + auto costs = 
GetEncodingMethodsCosts(); + + auto minElement = std::min_element(costs.begin(), costs.end()); + auto type = EUnversionedStringSegmentType(std::distance(costs.begin(), minElement)); + + switch (type) { + + case EUnversionedStringSegmentType::DirectDense: + return GetDirectColumn(GetDirectDenseNullBitmap()); + + case EUnversionedStringSegmentType::DictionaryDense: + return GetDictionaryColumn(); + + default: + YT_ABORT(); + } + } + + TEnumIndexedVector<EUnversionedStringSegmentType, ui64> GetEncodingMethodsCosts() const + { + TEnumIndexedVector<EUnversionedStringSegmentType, ui64> costs; + for (auto type : TEnumTraits<EUnversionedStringSegmentType>::GetDomainValues()) { + costs[type] = GetSpecificEncodingMethodCosts(type); + } + return costs; + } + + ui64 GetSpecificEncodingMethodCosts(EUnversionedStringSegmentType type) const + { + switch (type) { + case EUnversionedStringSegmentType::DictionaryDense: + return GetDictionaryByteSize(); + + case EUnversionedStringSegmentType::DirectDense: + return GetDirectByteSize(); + + default: + YT_ABORT(); + } + } + + void AddValues(const std::vector<TUnversionedRowValues>& rowsValues) + { + for (auto rowValues : rowsValues) { + auto unversionedValue = rowValues[ColumnIndex_]; + YT_VERIFY(unversionedValue != nullptr); + auto value = CaptureValue(*unversionedValue); + Values_.push_back(value); + ++RowCount_; + } + } + + ui64 GetDirectByteSize() const + { + return AllStringsSize_; + } + + ui64 GetDictionaryByteSize() const + { + return DictionaryByteSize_ + Values_.size() * sizeof(ui32); + } + + + TStringBuf CaptureValue(const TUnversionedValue& unversionedValue) + { + if (unversionedValue.Type == EValueType::Null) { + return {}; + } + + auto valueCapacity = IsAnyOrComposite(ValueType) && !IsAnyOrComposite(unversionedValue.Type) + ? 
GetYsonSize(unversionedValue) + : static_cast<i64>(unversionedValue.Length); + + char* buffer = DirectBuffer_.Preallocate(valueCapacity); + if (!buffer) { + // This means, that we reserved nothing, because all strings are either null or empty. + // To distinguish between null and empty, we set preallocated pointer to special value. + static char* const EmptyStringBase = reinterpret_cast<char*>(1); + buffer = EmptyStringBase; + } + + auto start = buffer; + + if (IsAnyOrComposite(ValueType) && !IsAnyOrComposite(unversionedValue.Type)) { + // Any non-any and non-null value convert to YSON. + buffer += WriteYson(buffer, unversionedValue); + } else { + std::memcpy( + buffer, + unversionedValue.Data.String, + unversionedValue.Length); + buffer += unversionedValue.Length; + } + + auto value = TStringBuf(start, buffer); + + YT_VERIFY(value.size() <= valueCapacity); + + DirectBuffer_.Advance(value.size()); + + if (Dictionary_.emplace(value, Dictionary_.size() + 1).second) { + DictionaryByteSize_ += value.size(); + } + AllStringsSize_ += value.size(); + return value; + } +}; + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateStringConverter( + int columnIndex, + const NTableClient::TColumnSchema& columnSchema) +{ + return std::make_unique<TStringConverter<EValueType::String>>(columnIndex, columnSchema); +} + +IColumnConverterPtr CreateAnyConverter( + int columnIndex, + const NTableClient::TColumnSchema& columnSchema) +{ + return std::make_unique<TStringConverter<EValueType::Any>>(columnIndex, columnSchema); +} + +IColumnConverterPtr CreateCompositeConverter( + int columnIndex, + const NTableClient::TColumnSchema& columnSchema) +{ + return std::make_unique<TStringConverter<EValueType::Composite>>(columnIndex, columnSchema); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git 
a/yt/yt/library/column_converters/string_column_converter.h b/yt/yt/library/column_converters/string_column_converter.h new file mode 100644 index 0000000000..b9c3d2bdf7 --- /dev/null +++ b/yt/yt/library/column_converters/string_column_converter.h @@ -0,0 +1,25 @@ +#pragma once + +#include "column_converter.h" + +#include <yt/yt/client/table_client/public.h> + +namespace NYT::NColumnConverters { + +//////////////////////////////////////////////////////////////////////////////// + +IColumnConverterPtr CreateStringConverter( + int columnIndex, + const NTableClient::TColumnSchema& columnSchema); + +IColumnConverterPtr CreateAnyConverter( + int columnIndex, + const NTableClient::TColumnSchema& columnSchema); + +IColumnConverterPtr CreateCompositeConverter( + int columnIndex, + const NTableClient::TColumnSchema& columnSchema); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NColumnConverters diff --git a/yt/yt/library/column_converters/ya.make b/yt/yt/library/column_converters/ya.make new file mode 100644 index 0000000000..55cd9f86c0 --- /dev/null +++ b/yt/yt/library/column_converters/ya.make @@ -0,0 +1,19 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + boolean_column_converter.cpp + column_converter.cpp + floating_point_column_converter.cpp + helpers.cpp + integer_column_converter.cpp + null_column_converter.cpp + string_column_converter.cpp +) + +PEERDIR( + yt/yt/core +) + +END() diff --git a/yt/yt/library/containers/cgroup.cpp b/yt/yt/library/containers/cgroup.cpp new file mode 100644 index 0000000000..b0d46f732a --- /dev/null +++ b/yt/yt/library/containers/cgroup.cpp @@ -0,0 +1,752 @@ +#include "cgroup.h" +#include "private.h" + +#include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/ytree/fluent.h> + +#include <util/string/split.h> +#include <util/system/filemap.h> + +#include <util/system/yield.h> + +#ifdef _linux_ + #include <unistd.h> + #include 
<sys/stat.h> + #include <errno.h> +#endif + +namespace NYT::NContainers { + +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = ContainersLogger; +static const TString CGroupRootPath("/sys/fs/cgroup"); +#ifdef _linux_ +static const int ReadByAll = S_IRUSR | S_IRGRP | S_IROTH; +static const int ReadExecuteByAll = ReadByAll | S_IXUSR | S_IXGRP | S_IXOTH; +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace { + +TString GetParentFor(const TString& type) +{ +#ifdef _linux_ + auto rawData = TUnbufferedFileInput("/proc/self/cgroup") + .ReadAll(); + auto result = ParseProcessCGroups(rawData); + return result[type]; +#else + Y_UNUSED(type); + return "_parent_"; +#endif +} + +#ifdef _linux_ + +std::vector<TString> ReadAllValues(const TString& fileName) +{ + auto raw = TUnbufferedFileInput(fileName) + .ReadAll(); + + YT_LOG_DEBUG("File %v contains %Qv", + fileName, + raw); + + TVector<TString> values; + StringSplitter(raw.data()) + .SplitBySet(" \n") + .SkipEmpty() + .Collect(&values); + return values; +} + +TDuration FromJiffies(ui64 jiffies) +{ + static const auto TicksPerSecond = sysconf(_SC_CLK_TCK); + return TDuration::MicroSeconds(1000 * 1000 * jiffies / TicksPerSecond); +} + +#endif + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +void TKillProcessGroupTool::operator()(const TString& processGroupPath) const +{ + SafeSetUid(0); + TNonOwningCGroup group(processGroupPath); + group.Kill(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TNonOwningCGroup::TNonOwningCGroup(const TString& fullPath) + : FullPath_(fullPath) +{ } + +TNonOwningCGroup::TNonOwningCGroup(const TString& type, const TString& name) + : FullPath_(NFS::CombinePaths({ + CGroupRootPath, + type, + GetParentFor(type), + name + })) +{ } + 
+TNonOwningCGroup::TNonOwningCGroup(TNonOwningCGroup&& other) + : FullPath_(std::move(other.FullPath_)) +{ } + +void TNonOwningCGroup::AddTask(int pid) const +{ + YT_LOG_INFO( + "Adding task to cgroup (Task: %v, Cgroup: %v)", + pid, + FullPath_); + Append("tasks", ToString(pid)); +} + +void TNonOwningCGroup::AddCurrentTask() const +{ + YT_VERIFY(!IsNull()); +#ifdef _linux_ + auto pid = getpid(); + AddTask(pid); +#endif +} + +TString TNonOwningCGroup::Get(const TString& name) const +{ + YT_VERIFY(!IsNull()); + TString result; +#ifdef _linux_ + const auto path = GetPath(name); + result = TFileInput(path).ReadLine(); +#else + Y_UNUSED(name); +#endif + return result; +} + +void TNonOwningCGroup::Set(const TString& name, const TString& value) const +{ + YT_VERIFY(!IsNull()); +#ifdef _linux_ + auto path = GetPath(name); + TUnbufferedFileOutput output(TFile(path, EOpenModeFlag::WrOnly)); + output << value; +#else + Y_UNUSED(name); + Y_UNUSED(value); +#endif +} + +void TNonOwningCGroup::Append(const TString& name, const TString& value) const +{ + YT_VERIFY(!IsNull()); +#ifdef _linux_ + auto path = GetPath(name); + TUnbufferedFileOutput output(TFile(path, EOpenModeFlag::ForAppend)); + output << value; +#else + Y_UNUSED(name); + Y_UNUSED(value); +#endif +} + +bool TNonOwningCGroup::IsRoot() const +{ + return FullPath_ == CGroupRootPath; +} + +bool TNonOwningCGroup::IsNull() const +{ + return FullPath_.empty(); +} + +bool TNonOwningCGroup::Exists() const +{ + return NFS::Exists(FullPath_); +} + +std::vector<int> TNonOwningCGroup::GetProcesses() const +{ + std::vector<int> results; + if (!IsNull()) { +#ifdef _linux_ + auto values = ReadAllValues(GetPath("cgroup.procs")); + for (const auto& value : values) { + int pid = FromString<int>(value); + results.push_back(pid); + } +#endif + } + return results; +} + +std::vector<int> TNonOwningCGroup::GetTasks() const +{ + std::vector<int> results; + if (!IsNull()) { +#ifdef _linux_ + auto values = ReadAllValues(GetPath("tasks")); + for 
(const auto& value : values) { + int pid = FromString<int>(value); + results.push_back(pid); + } +#endif + } + return results; +} + +const TString& TNonOwningCGroup::GetFullPath() const +{ + return FullPath_; +} + +std::vector<TNonOwningCGroup> TNonOwningCGroup::GetChildren() const +{ + // We retry enumerating directories, since it may fail with weird diagnostics if + // number of subcgroups changes. + while (true) { + try { + std::vector<TNonOwningCGroup> result; + + if (IsNull()) { + return result; + } + + auto directories = NFS::EnumerateDirectories(FullPath_); + for (const auto& directory : directories) { + result.emplace_back(NFS::CombinePaths(FullPath_, directory)); + } + return result; + } catch (const std::exception& ex) { + YT_LOG_WARNING(ex, "Failed to list subcgroups (Path: %v)", FullPath_); + } + } +} + +void TNonOwningCGroup::EnsureExistance() const +{ + YT_LOG_INFO("Creating cgroup (Cgroup: %v)", FullPath_); + + YT_VERIFY(!IsNull()); + +#ifdef _linux_ + NFS::MakeDirRecursive(FullPath_, 0755); +#endif +} + +void TNonOwningCGroup::Lock() const +{ + Traverse( + BIND([] (const TNonOwningCGroup& group) { group.DoLock(); }), + BIND([] (const TNonOwningCGroup& /*group*/) {})); +} + +void TNonOwningCGroup::Unlock() const +{ + Traverse( + BIND([] (const TNonOwningCGroup& /*group*/) {}), + BIND([] (const TNonOwningCGroup& group) { group.DoUnlock(); })); +} + +void TNonOwningCGroup::Kill() const +{ + YT_VERIFY(!IsRoot()); + + Traverse( + BIND([] (const TNonOwningCGroup& group) { group.DoKill(); }), + BIND([] (const TNonOwningCGroup& /*group*/) {})); +} + +void TNonOwningCGroup::RemoveAllSubcgroups() const +{ + Traverse( + BIND([] (const TNonOwningCGroup& group) { + group.TryUnlock(); + }), + BIND([this_ = this] (const TNonOwningCGroup& group) { + if (this_ != &group) { + group.DoRemove(); + } + })); +} + +void TNonOwningCGroup::RemoveRecursive() const +{ + RemoveAllSubcgroups(); + DoRemove(); +} + +void TNonOwningCGroup::DoLock() const +{ + YT_LOG_INFO("Locking 
cgroup (Cgroup: %v)", FullPath_); + +#ifdef _linux_ + if (!IsNull()) { + int code = chmod(FullPath_.data(), ReadExecuteByAll); + YT_VERIFY(code == 0); + + code = chmod(GetPath("tasks").data(), ReadByAll); + YT_VERIFY(code == 0); + } +#endif +} + +bool TNonOwningCGroup::TryUnlock() const +{ + YT_LOG_INFO("Unlocking cgroup (Cgroup: %v)", FullPath_); + + if (!Exists()) { + return true; + } + + bool result = true; + +#ifdef _linux_ + if (!IsNull()) { + int code = chmod(GetPath("tasks").data(), ReadByAll | S_IWUSR); + if (code != 0) { + result = false; + } + + code = chmod(FullPath_.data(), ReadExecuteByAll | S_IWUSR); + if (code != 0) { + result = false; + } + } +#endif + + return result; +} + +void TNonOwningCGroup::DoUnlock() const +{ + YT_VERIFY(TryUnlock()); +} + +void TNonOwningCGroup::DoKill() const +{ + YT_LOG_DEBUG("Started killing processes in cgroup (Cgroup: %v)", FullPath_); + +#ifdef _linux_ + while (true) { + auto pids = GetTasks(); + if (pids.empty()) + break; + + YT_LOG_DEBUG("Killing processes (Pids: %v)", pids); + + for (int pid : pids) { + auto result = kill(pid, SIGKILL); + if (result == -1) { + YT_VERIFY(errno == ESRCH); + } + } + + ThreadYield(); + } +#endif + + YT_LOG_DEBUG("Finished killing processes in cgroup (Cgroup: %v)", FullPath_); +} + +void TNonOwningCGroup::DoRemove() const +{ + if (NFS::Exists(FullPath_)) { + NFS::Remove(FullPath_); + } +} + +void TNonOwningCGroup::Traverse( + const TCallback<void(const TNonOwningCGroup&)>& preorderAction, + const TCallback<void(const TNonOwningCGroup&)>& postorderAction) const +{ + preorderAction(*this); + + for (const auto& child : GetChildren()) { + child.Traverse(preorderAction, postorderAction); + } + + postorderAction(*this); +} + +TString TNonOwningCGroup::GetPath(const TString& filename) const +{ + return NFS::CombinePaths(FullPath_, filename); +} + +//////////////////////////////////////////////////////////////////////////////// + +TCGroup::TCGroup(const TString& type, const TString& name) + : 
TNonOwningCGroup(type, name) +{ } + +TCGroup::TCGroup(TCGroup&& other) + : TNonOwningCGroup(std::move(other)) + , Created_(other.Created_) +{ + other.Created_ = false; +} + +TCGroup::TCGroup(TNonOwningCGroup&& other) + : TNonOwningCGroup(std::move(other)) + , Created_(false) +{ } + +TCGroup::~TCGroup() +{ + if (Created_) { + Destroy(); + } +} + +void TCGroup::Create() +{ + EnsureExistance(); + Created_ = true; +} + +void TCGroup::Destroy() +{ + YT_LOG_INFO("Destroying cgroup (Cgroup: %v)", FullPath_); + YT_VERIFY(Created_); + +#ifdef _linux_ + try { + NFS::Remove(FullPath_); + } catch (const std::exception& ex) { + YT_LOG_FATAL(ex, "Failed to destroy cgroup (Cgroup: %v)", FullPath_); + } +#endif + Created_ = false; +} + +bool TCGroup::IsCreated() const +{ + return Created_; +} + +//////////////////////////////////////////////////////////////////////////////// + +const TString TCpuAccounting::Name = "cpuacct"; + +TCpuAccounting::TStatistics& operator-=(TCpuAccounting::TStatistics& lhs, const TCpuAccounting::TStatistics& rhs) +{ + #define XX(name) lhs.name = lhs.name.ValueOrThrow() - rhs.name.ValueOrThrow(); + XX(UserUsageTime) + XX(SystemUsageTime) + XX(WaitTime) + XX(ThrottledTime) + XX(ContextSwitchesDelta) + XX(PeakThreadCount) + #undef XX + return lhs; +} + +TCpuAccounting::TCpuAccounting(const TString& name) + : TCGroup(Name, name) +{ } + +TCpuAccounting::TCpuAccounting(TNonOwningCGroup&& nonOwningCGroup) + : TCGroup(std::move(nonOwningCGroup)) +{ } + +TCpuAccounting::TStatistics TCpuAccounting::GetStatisticsRecursive() const +{ + TCpuAccounting::TStatistics result; +#ifdef _linux_ + try { + auto path = NFS::CombinePaths(GetFullPath(), "cpuacct.stat"); + auto values = ReadAllValues(path); + YT_VERIFY(values.size() == 4); + + TString type[2]; + ui64 jiffies[2]; + + for (int i = 0; i < 2; ++i) { + type[i] = values[2 * i]; + jiffies[i] = FromString<ui64>(values[2 * i + 1]); + } + + for (int i = 0; i < 2; ++i) { + if (type[i] == "user") { + result.UserUsageTime = 
FromJiffies(jiffies[i]); + } else if (type[i] == "system") { + result.SystemUsageTime = FromJiffies(jiffies[i]); + } + } + } catch (const std::exception& ex) { + YT_LOG_FATAL( + ex, + "Failed to retreive CPU statistics from cgroup (Cgroup: %v)", + GetFullPath()); + } +#endif + return result; +} + +TCpuAccounting::TStatistics TCpuAccounting::GetStatistics() const +{ + auto statistics = GetStatisticsRecursive(); + for (auto& cgroup : GetChildren()) { + auto cpuCGroup = TCpuAccounting(std::move(cgroup)); + statistics -= cpuCGroup.GetStatisticsRecursive(); + } + return statistics; +} + + +//////////////////////////////////////////////////////////////////////////////// + +const TString TCpu::Name = "cpu"; + +static const int DefaultCpuShare = 1024; + +TCpu::TCpu(const TString& name) + : TCGroup(Name, name) +{ } + +void TCpu::SetShare(double share) +{ + int cpuShare = static_cast<int>(share * DefaultCpuShare); + Set("cpu.shares", ToString(cpuShare)); +} + +//////////////////////////////////////////////////////////////////////////////// + +const TString TBlockIO::Name = "blkio"; + +TBlockIO::TBlockIO(const TString& name) + : TCGroup(Name, name) +{ } + +// For more information about format of data +// read https://www.kernel.org/doc/Documentation/cgroups/blkio-controller.txt + +TBlockIO::TStatistics TBlockIO::GetStatistics() const +{ + TBlockIO::TStatistics result; +#ifdef _linux_ + auto bytesStats = GetDetailedStatistics("blkio.io_service_bytes"); + for (const auto& item : bytesStats) { + if (item.Type == "Read") { + result.IOReadByte = result.IOReadByte.ValueOrThrow() + item.Value; + } else if (item.Type == "Write") { + result.IOWriteByte = result.IOReadByte.ValueOrThrow() + item.Value; + } + } + + auto ioStats = GetDetailedStatistics("blkio.io_serviced"); + for (const auto& item : ioStats) { + if (item.Type == "Read") { + result.IOReadOps = result.IOReadOps.ValueOrThrow() + item.Value; + result.IOOps = result.IOOps.ValueOrThrow() + item.Value; + } else if (item.Type == 
"Write") { + result.IOWriteOps = result.IOWriteOps.ValueOrThrow() + item.Value; + result.IOOps = result.IOOps.ValueOrThrow() + item.Value; + } + } +#endif + return result; +} + +std::vector<TBlockIO::TStatisticsItem> TBlockIO::GetIOServiceBytes() const +{ + return GetDetailedStatistics("blkio.io_service_bytes"); +} + +std::vector<TBlockIO::TStatisticsItem> TBlockIO::GetIOServiced() const +{ + return GetDetailedStatistics("blkio.io_serviced"); +} + +std::vector<TBlockIO::TStatisticsItem> TBlockIO::GetDetailedStatistics(const char* filename) const +{ + std::vector<TBlockIO::TStatisticsItem> result; +#ifdef _linux_ + try { + auto path = NFS::CombinePaths(GetFullPath(), filename); + auto values = ReadAllValues(path); + + int lineNumber = 0; + while (3 * lineNumber + 2 < std::ssize(values)) { + TStatisticsItem item; + item.DeviceId = values[3 * lineNumber]; + item.Type = values[3 * lineNumber + 1]; + item.Value = FromString<ui64>(values[3 * lineNumber + 2]); + + { + auto guard = Guard(SpinLock_); + DeviceIds_.insert(item.DeviceId); + } + + if (item.Type == "Read" || item.Type == "Write") { + result.push_back(item); + + YT_LOG_DEBUG("IO operations serviced (OperationCount: %v, OperationType: %v, DeviceId: %v)", + item.Value, + item.Type, + item.DeviceId); + } + ++lineNumber; + } + } catch (const std::exception& ex) { + YT_LOG_FATAL( + ex, + "Failed to retreive block IO statistics from cgroup (Cgroup: %v)", + GetFullPath()); + } +#else + Y_UNUSED(filename); +#endif + return result; +} + +void TBlockIO::ThrottleOperations(i64 operations) const +{ + auto guard = Guard(SpinLock_); + for (const auto& deviceId : DeviceIds_) { + auto value = Format("%v %v", deviceId, operations); + Append("blkio.throttle.read_iops_device", value); + Append("blkio.throttle.write_iops_device", value); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +const TString TMemory::Name = "memory"; + +TMemory::TMemory(const TString& name) + : TCGroup(Name, name) 
+{ } + +TMemory::TStatistics TMemory::GetStatistics() const +{ + TMemory::TStatistics result; +#ifdef _linux_ + try { + auto values = ReadAllValues(GetPath("memory.stat")); + int lineNumber = 0; + while (2 * lineNumber + 1 < std::ssize(values)) { + const auto& type = values[2 * lineNumber]; + const auto& unparsedValue = values[2 * lineNumber + 1]; + if (type == "rss") { + result.Rss = FromString<ui64>(unparsedValue); + } + if (type == "mapped_file") { + result.MappedFile = FromString<ui64>(unparsedValue); + } + if (type == "pgmajfault") { + result.MajorPageFaults = FromString<ui64>(unparsedValue); + } + ++lineNumber; + } + } catch (const std::exception& ex) { + YT_LOG_FATAL( + ex, + "Failed to retreive memory statistics from cgroup (Cgroup: %v)", + GetFullPath()); + } +#endif + return result; +} + +i64 TMemory::GetMaxMemoryUsage() const +{ + return FromString<i64>(Get("memory.max_usage_in_bytes")); +} + +void TMemory::SetLimitInBytes(i64 bytes) const +{ + Set("memory.limit_in_bytes", ToString(bytes)); +} + +void TMemory::ForceEmpty() const +{ + Set("memory.force_empty", "0"); +} + +//////////////////////////////////////////////////////////////////////////////// + +const TString TFreezer::Name = "freezer"; + +TFreezer::TFreezer(const TString& name) + : TCGroup(Name, name) +{ } + +TString TFreezer::GetState() const +{ + return Get("freezer.state"); +} + +void TFreezer::Freeze() const +{ + Set("freezer.state", "FROZEN"); +} + +void TFreezer::Unfreeze() const +{ + Set("freezer.state", "THAWED"); +} + +//////////////////////////////////////////////////////////////////////////////// + +std::map<TString, TString> ParseProcessCGroups(const TString& str) +{ + std::map<TString, TString> result; + + TVector<TString> values; + StringSplitter(str.data()).SplitBySet(":\n").SkipEmpty().Collect(&values); + for (size_t i = 0; i + 2 < values.size(); i += 3) { + // Check format. 
+ FromString<int>(values[i]); + + const auto& subsystemsSet = values[i + 1]; + const auto& name = values[i + 2]; + + TVector<TString> subsystems; + StringSplitter(subsystemsSet.data()).Split(',').SkipEmpty().Collect(&subsystems); + for (const auto& subsystem : subsystems) { + if (!subsystem.StartsWith("name=")) { + int start = 0; + if (name.StartsWith("/")) { + start = 1; + } + result[subsystem] = name.substr(start); + } + } + } + + return result; +} + +std::map<TString, TString> GetProcessCGroups(pid_t pid) +{ + auto cgroupsPath = Format("/proc/%v/cgroup", pid); + auto rawCgroups = TFileInput{cgroupsPath}.ReadAll(); + return ParseProcessCGroups(rawCgroups); +} + +bool IsValidCGroupType(const TString& type) +{ + return + type == TCpuAccounting::Name || + type == TCpu::Name || + type == TBlockIO::Name || + type == TMemory::Name || + type == TFreezer::Name; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/cgroup.h b/yt/yt/library/containers/cgroup.h new file mode 100644 index 0000000000..a69f8f8872 --- /dev/null +++ b/yt/yt/library/containers/cgroup.h @@ -0,0 +1,290 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/actions/public.h> + +#include <yt/yt/core/ytree/yson_struct.h> +#include <yt/yt/core/yson/public.h> + +#include <yt/yt/core/misc/property.h> + +#include <library/cpp/yt/threading/spin_lock.h> + +#include <vector> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +void RemoveAllSubcgroups(const TString& path); + +//////////////////////////////////////////////////////////////////////////////// + +struct TKillProcessGroupTool +{ + void operator()(const TString& processGroupPath) const; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TNonOwningCGroup + : private TNonCopyable +{ +public: + DEFINE_BYREF_RO_PROPERTY(TString, 
FullPath); + +public: + TNonOwningCGroup() = default; + explicit TNonOwningCGroup(const TString& fullPath); + TNonOwningCGroup(const TString& type, const TString& name); + TNonOwningCGroup(TNonOwningCGroup&& other); + + void AddTask(int pid) const; + void AddCurrentTask() const; + + bool IsRoot() const; + bool IsNull() const; + bool Exists() const; + + std::vector<int> GetProcesses() const; + std::vector<int> GetTasks() const; + const TString& GetFullPath() const; + + std::vector<TNonOwningCGroup> GetChildren() const; + + void EnsureExistance() const; + + void Lock() const; + void Unlock() const; + + void Kill() const; + + void RemoveAllSubcgroups() const; + void RemoveRecursive() const; + +protected: + TString Get(const TString& name) const; + void Set(const TString& name, const TString& value) const; + void Append(const TString& name, const TString& value) const; + + void DoLock() const; + void DoUnlock() const; + + bool TryUnlock() const; + + void DoKill() const; + + void DoRemove() const; + + void Traverse( + const TCallback<void(const TNonOwningCGroup&)>& preorderAction, + const TCallback<void(const TNonOwningCGroup&)>& postorderAction) const; + + TString GetPath(const TString& filename) const; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TCGroup + : public TNonOwningCGroup +{ +protected: + TCGroup(const TString& type, const TString& name); + TCGroup(TNonOwningCGroup&& other); + TCGroup(TCGroup&& other); + +public: + ~TCGroup(); + + void Create(); + void Destroy(); + + bool IsCreated() const; + +private: + bool Created_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TCpuAccounting + : public TCGroup +{ +public: + static const TString Name; + + struct TStatistics + { + TErrorOr<TDuration> TotalUsageTime; + TErrorOr<TDuration> UserUsageTime; + TErrorOr<TDuration> SystemUsageTime; + TErrorOr<TDuration> WaitTime; + TErrorOr<TDuration> ThrottledTime; + + 
TErrorOr<ui64> ThreadCount; + TErrorOr<ui64> ContextSwitches; + TErrorOr<ui64> ContextSwitchesDelta; + TErrorOr<ui64> PeakThreadCount; + + TErrorOr<TDuration> LimitTime; + TErrorOr<TDuration> GuaranteeTime; + }; + + explicit TCpuAccounting(const TString& name); + + TStatistics GetStatisticsRecursive() const; + TStatistics GetStatistics() const; + +private: + explicit TCpuAccounting(TNonOwningCGroup&& nonOwningCGroup); +}; + +void Serialize(const TCpuAccounting::TStatistics& statistics, NYson::IYsonConsumer* consumer); + +//////////////////////////////////////////////////////////////////////////////// + +class TCpu + : public TCGroup +{ +public: + static const TString Name; + + explicit TCpu(const TString& name); + + void SetShare(double share); +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TBlockIO + : public TCGroup +{ +public: + static const TString Name; + + struct TStatistics + { + TErrorOr<ui64> IOReadByte; + TErrorOr<ui64> IOWriteByte; + TErrorOr<ui64> IOBytesLimit; + + TErrorOr<ui64> IOReadOps; + TErrorOr<ui64> IOWriteOps; + TErrorOr<ui64> IOOps; + TErrorOr<ui64> IOOpsLimit; + + TErrorOr<TDuration> IOTotalTime; + TErrorOr<TDuration> IOWaitTime; + }; + + struct TStatisticsItem + { + TString DeviceId; + TString Type; + ui64 Value = 0; + }; + + explicit TBlockIO(const TString& name); + + TStatistics GetStatistics() const; + void ThrottleOperations(i64 iops) const; + +private: + //! Guards device ids. + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, SpinLock_); + //! Set of all seen device ids. 
+ mutable THashSet<TString> DeviceIds_; + + std::vector<TBlockIO::TStatisticsItem> GetDetailedStatistics(const char* filename) const; + + std::vector<TStatisticsItem> GetIOServiceBytes() const; + std::vector<TStatisticsItem> GetIOServiced() const; +}; + +void Serialize(const TBlockIO::TStatistics& statistics, NYson::IYsonConsumer* consumer); + +//////////////////////////////////////////////////////////////////////////////// + +class TMemory + : public TCGroup +{ +public: + static const TString Name; + + struct TStatistics + { + TErrorOr<ui64> Rss; + TErrorOr<ui64> MappedFile; + TErrorOr<ui64> MinorPageFaults; + TErrorOr<ui64> MajorPageFaults; + + TErrorOr<ui64> FileCacheUsage; + TErrorOr<ui64> AnonUsage; + TErrorOr<ui64> AnonLimit; + TErrorOr<ui64> MemoryUsage; + TErrorOr<ui64> MemoryGuarantee; + TErrorOr<ui64> MemoryLimit; + TErrorOr<ui64> MaxMemoryUsage; + + TErrorOr<ui64> OomKills; + TErrorOr<ui64> OomKillsTotal; + }; + + explicit TMemory(const TString& name); + + TStatistics GetStatistics() const; + i64 GetMaxMemoryUsage() const; + + void SetLimitInBytes(i64 bytes) const; + + void ForceEmpty() const; +}; + +void Serialize(const TMemory::TStatistics& statistics, NYson::IYsonConsumer* consumer); + +//////////////////////////////////////////////////////////////////////////////// + +class TNetwork +{ +public: + struct TStatistics + { + TErrorOr<ui64> TxBytes; + TErrorOr<ui64> TxPackets; + TErrorOr<ui64> TxDrops; + TErrorOr<ui64> TxLimit; + + TErrorOr<ui64> RxBytes; + TErrorOr<ui64> RxPackets; + TErrorOr<ui64> RxDrops; + TErrorOr<ui64> RxLimit; + }; +}; + +void Serialize(const TNetwork::TStatistics& statistics, NYson::IYsonConsumer* consumer); + +//////////////////////////////////////////////////////////////////////////////// + +class TFreezer + : public TCGroup +{ +public: + static const TString Name; + + explicit TFreezer(const TString& name); + + TString GetState() const; + void Freeze() const; + void Unfreeze() const; +}; + 
+//////////////////////////////////////////////////////////////////////////////// + +std::map<TString, TString> ParseProcessCGroups(const TString& str); +std::map<TString, TString> GetProcessCGroups(pid_t pid); +bool IsValidCGroupType(const TString& type); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/config.cpp b/yt/yt/library/containers/config.cpp new file mode 100644 index 0000000000..39e46f2372 --- /dev/null +++ b/yt/yt/library/containers/config.cpp @@ -0,0 +1,64 @@ +#include "config.h" + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +void TPodSpecConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cpu_to_vcpu_factor", &TThis::CpuToVCpuFactor) + .Default(); +} + +//////////////////////////////////////////////////////////////////////////////// + +bool TCGroupConfig::IsCGroupSupported(const TString& cgroupType) const +{ + auto it = std::find_if( + SupportedCGroups.begin(), + SupportedCGroups.end(), + [&] (const TString& type) { + return type == cgroupType; + }); + return it != SupportedCGroups.end(); +} + +void TCGroupConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("supported_cgroups", &TThis::SupportedCGroups) + .Default(); + + registrar.Postprocessor([] (TThis* config) { + for (const auto& type : config->SupportedCGroups) { + if (!IsValidCGroupType(type)) { + THROW_ERROR_EXCEPTION("Invalid cgroup type %Qv", type); + } + } + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TPortoExecutorDynamicConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("retries_timeout", &TThis::RetriesTimeout) + .Default(TDuration::Seconds(10)); + registrar.Parameter("poll_period", &TThis::PollPeriod) + .Default(TDuration::MilliSeconds(100)); + registrar.Parameter("api_timeout", &TThis::ApiTimeout) + 
.Default(TDuration::Minutes(5)); + registrar.Parameter("api_disk_timeout", &TThis::ApiDiskTimeout) + .Default(TDuration::Minutes(30)); + registrar.Parameter("enable_network_isolation", &TThis::EnableNetworkIsolation) + .Default(true); + registrar.Parameter("enable_test_porto_failures", &TThis::EnableTestPortoFailures) + .Default(false); + registrar.Parameter("stub_error_code", &TThis::StubErrorCode) + .Default(EPortoErrorCode::SocketError); + registrar.Parameter("enable_test_porto_not_responding", &TThis::EnableTestPortoNotResponding) + .Default(false); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/config.h b/yt/yt/library/containers/config.h new file mode 100644 index 0000000000..3639274cff --- /dev/null +++ b/yt/yt/library/containers/config.h @@ -0,0 +1,64 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +class TPodSpecConfig + : public virtual NYTree::TYsonStruct +{ +public: + std::optional<double> CpuToVCpuFactor; + + REGISTER_YSON_STRUCT(TPodSpecConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TPodSpecConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TCGroupConfig + : public virtual NYTree::TYsonStruct +{ +public: + std::vector<TString> SupportedCGroups; + + bool IsCGroupSupported(const TString& cgroupType) const; + + REGISTER_YSON_STRUCT(TCGroupConfig); + + static void Register(TRegistrar registrar); +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoExecutorDynamicConfig + : public NYTree::TYsonStruct +{ +public: + TDuration RetriesTimeout; + TDuration PollPeriod; + TDuration ApiTimeout; + TDuration ApiDiskTimeout; + bool EnableNetworkIsolation; + bool 
EnableTestPortoFailures; + bool EnableTestPortoNotResponding; + + EPortoErrorCode StubErrorCode; + + REGISTER_YSON_STRUCT(TPortoExecutorDynamicConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TPortoExecutorDynamicConfig) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/instance.cpp b/yt/yt/library/containers/instance.cpp new file mode 100644 index 0000000000..757ba766a0 --- /dev/null +++ b/yt/yt/library/containers/instance.cpp @@ -0,0 +1,803 @@ +#ifdef __linux__ + +#include "instance.h" + +#include "porto_executor.h" +#include "private.h" + +#include <yt/yt/library/containers/cgroup.h> +#include <yt/yt/library/containers/config.h> + +#include <yt/yt/core/concurrency/scheduler.h> + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/collection_helpers.h> +#include <yt/yt/core/misc/error.h> +#include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/proc.h> + +#include <library/cpp/porto/libporto.hpp> + +#include <util/stream/file.h> + +#include <util/string/cast.h> +#include <util/string/split.h> + +#include <util/system/env.h> + +#include <initializer_list> +#include <string> + +namespace NYT::NContainers { + +using namespace NConcurrency; +using namespace NNet; + +//////////////////////////////////////////////////////////////////////////////// + +namespace NDetail { + +// Porto passes command string to wordexp, where quote (') symbol +// is a delimiter. So we must replace it with concatenation ('"'"'). 
+// Returns a copy of |in| in which every single quote is replaced with '"'"'.
+TString EscapeForWordexp(const char* in)
+{
+    TString buffer;
+    while (*in) {
+        if (*in == '\'') {
+            buffer.append(R"('"'"')");
+        } else {
+            buffer.append(*in);
+        }
+        in++;
+    }
+    return buffer;
+}
+
+// Finds the first occurrence of |pattern| in |input| and parses the integer that
+// follows it, up to |terminator| (or up to the end of |input| if there is none).
+// Throws if |pattern| is absent or if the text after it is not a number.
+i64 Extract(
+    const TString& input,
+    const TString& pattern,
+    const TString& terminator = "\n")
+{
+    auto patternPosition = input.find(pattern);
+    if (patternPosition == input.npos) {
+        // Without this check npos + pattern.length() wraps around and a number
+        // gets parsed from an unrelated offset of |input|.
+        THROW_ERROR_EXCEPTION("Pattern %Qv is not found in Porto property value", pattern);
+    }
+
+    auto start = patternPosition + pattern.length();
+    auto end = input.find(terminator, start);
+    return std::stol(input.substr(start, (end == input.npos) ? end : end - start));
+}
+
+// Sums the integers that follow every "<pattern> ... <delimiter>" occurrence in
+// |input|; each number extends up to |terminator| or to the end of |input|.
+i64 ExtractSum(
+    const TString& input,
+    const TString& pattern,
+    const TString& delimiter,
+    const TString& terminator = "\n")
+{
+    i64 sum = 0;
+    TString::size_type pos = 0;
+    while (pos < input.length()) {
+        pos = input.find(pattern, pos);
+        if (pos == input.npos) {
+            break;
+        }
+        pos += pattern.length();
+
+        pos = input.find(delimiter, pos);
+        if (pos == input.npos) {
+            break;
+        }
+
+        // Skip the whole delimiter rather than just its first character;
+        // an empty delimiter still advances by one character, as before.
+        pos += delimiter.empty() ? 1 : delimiter.length();
+        auto end = input.find(terminator, pos);
+        sum += std::stol(input.substr(pos, (end == input.npos) ? end : end - pos));
+    }
+    return sum;
+}
+
+// A Porto property name paired with the parser applied to that property's value.
+using TPortoStatRule = std::pair<TString, std::function<i64(const TString& input)>>;
+
+// Parses the whole value as an integer.
+static const std::function<i64(const TString&)> LongExtractor = [] (const TString& in) {
+    return std::stol(in);
+};
+
+// Parses the number preceding the first 'c' and multiplies it by 1'000'000'000.
+static const std::function<i64(const TString&)> CoreNsPerSecondExtractor = [] (const TString& in) {
+    // Keep the position unsigned: when there is no 'c', npos makes substr
+    // take the whole string instead of relying on an int round-trip of -1.
+    auto pos = in.find("c", 0);
+    return (std::stod(in.substr(0, pos))) * 1'000'000'000;
+};
+
+// Builds a parser that sums the "hw" entries of a ';'-separated value;
+// |rwMode| is prepended to the ':' that precedes each number.
+static const std::function<i64(const TString&)> GetIOStatExtractor(const TString& rwMode = "")
+{
+    return [rwMode] (const TString& in) {
+        return ExtractSum(in, "hw", rwMode + ":", ";");
+    };
+}
+
+// Builds a parser that returns the integer following |statKey| in a multi-line value.
+static const std::function<i64(const TString&)> GetStatByKeyExtractor(const TString& statKey)
+{
+    return [statKey] (const TString& in) {
+        return Extract(in, statKey);
+    };
+}
+
+const THashMap<EStatField, TPortoStatRule> PortoStatRules = {
+    {EStatField::CpuUsage, {"cpu_usage", LongExtractor}},
+    {EStatField::CpuSystemUsage, {"cpu_usage_system",
LongExtractor}}, + {EStatField::CpuWait, {"cpu_wait", LongExtractor}}, + {EStatField::CpuThrottled, {"cpu_throttled", LongExtractor}}, + {EStatField::ThreadCount, {"thread_count", LongExtractor}}, + {EStatField::CpuLimit, {"cpu_limit_bound", CoreNsPerSecondExtractor}}, + {EStatField::CpuGuarantee, {"cpu_guarantee_bound", CoreNsPerSecondExtractor}}, + {EStatField::Rss, {"memory.stat", GetStatByKeyExtractor("total_rss")}}, + {EStatField::MappedFile, {"memory.stat", GetStatByKeyExtractor("total_mapped_file")}}, + {EStatField::MinorPageFaults, {"minor_faults", LongExtractor}}, + {EStatField::MajorPageFaults, {"major_faults", LongExtractor}}, + {EStatField::FileCacheUsage, {"cache_usage", LongExtractor}}, + {EStatField::AnonMemoryUsage, {"anon_usage", LongExtractor}}, + {EStatField::AnonMemoryLimit, {"anon_limit_total", LongExtractor}}, + {EStatField::MemoryUsage, {"memory_usage", LongExtractor}}, + {EStatField::MemoryGuarantee, {"memory_guarantee", LongExtractor}}, + {EStatField::MemoryLimit, {"memory_limit_total", LongExtractor}}, + {EStatField::MaxMemoryUsage, {"memory.max_usage_in_bytes", LongExtractor}}, + {EStatField::OomKills, {"oom_kills", LongExtractor}}, + {EStatField::OomKillsTotal, {"oom_kills_total", LongExtractor}}, + + {EStatField::IOReadByte, {"io_read", GetIOStatExtractor()}}, + {EStatField::IOWriteByte, {"io_write", GetIOStatExtractor()}}, + {EStatField::IOBytesLimit, {"io_limit", GetIOStatExtractor()}}, + {EStatField::IOReadOps, {"io_read_ops", GetIOStatExtractor()}}, + {EStatField::IOWriteOps, {"io_write_ops", GetIOStatExtractor()}}, + {EStatField::IOOps, {"io_ops", GetIOStatExtractor()}}, + {EStatField::IOOpsLimit, {"io_ops_limit", GetIOStatExtractor()}}, + {EStatField::IOTotalTime, {"io_time", GetIOStatExtractor()}}, + {EStatField::IOWaitTime, {"io_wait", GetIOStatExtractor()}}, + + {EStatField::NetTxBytes, {"net_tx_bytes[veth]", LongExtractor}}, + {EStatField::NetTxPackets, {"net_tx_packets[veth]", LongExtractor}}, + {EStatField::NetTxDrops, 
{"net_tx_drops[veth]", LongExtractor}}, + {EStatField::NetTxLimit, {"net_limit[veth]", LongExtractor}}, + {EStatField::NetRxBytes, {"net_rx_bytes[veth]", LongExtractor}}, + {EStatField::NetRxPackets, {"net_rx_packets[veth]", LongExtractor}}, + {EStatField::NetRxDrops, {"net_rx_drops[veth]", LongExtractor}}, + {EStatField::NetRxLimit, {"net_rx_limit[veth]", LongExtractor}}, +}; + +std::optional<TString> GetParentName(const TString& name) +{ + if (name.empty()) { + return std::nullopt; + } + + auto slashPosition = name.rfind('/'); + if (slashPosition == TString::npos) { + return ""; + } + + return name.substr(0, slashPosition); +} + +std::optional<TString> GetRootName(const TString& name) +{ + if (name.empty()) { + return std::nullopt; + } + + if (name == "/") { + return name; + } + + auto slashPosition = name.find('/'); + if (slashPosition == TString::npos) { + return name; + } + + return name.substr(0, slashPosition); +} + +} // namespace NDetail + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoInstanceLauncher + : public IInstanceLauncher +{ +public: + TPortoInstanceLauncher(const TString& name, IPortoExecutorPtr executor) + : Executor_(std::move(executor)) + , Logger(ContainersLogger.WithTag("Container: %v", name)) + { + Spec_.Name = name; + Spec_.CGroupControllers = { + "freezer", + "cpu", + "cpuacct", + "net_cls", + "blkio", + "devices", + "pids" + }; + } + + const TString& GetName() const override + { + return Spec_.Name; + } + + bool HasRoot() const override + { + return static_cast<bool>(Spec_.RootFS); + } + + void SetStdIn(const TString& inputPath) override + { + Spec_.StdinPath = inputPath; + } + + void SetStdOut(const TString& outPath) override + { + Spec_.StdoutPath = outPath; + } + + void SetStdErr(const TString& errorPath) override + { + Spec_.StderrPath = errorPath; + } + + void SetCwd(const TString& pwd) override + { + Spec_.CurrentWorkingDirectory = pwd; + } + + void SetCoreDumpHandler(const 
std::optional<TString>& handler) override + { + if (handler) { + Spec_.CoreCommand = *handler; + Spec_.EnableCoreDumps = true; + } else { + Spec_.EnableCoreDumps = false; + } + } + + void SetRoot(const TRootFS& rootFS) override + { + Spec_.RootFS = rootFS; + } + + void SetThreadLimit(i64 threadLimit) override + { + Spec_.ThreadLimit = threadLimit; + } + + void SetDevices(const std::vector<TDevice>& devices) override + { + Spec_.Devices = devices; + } + + void SetEnablePorto(EEnablePorto enablePorto) override + { + Spec_.EnablePorto = enablePorto; + } + + void SetIsolate(bool isolate) override + { + Spec_.Isolate = isolate; + } + + void EnableMemoryTracking() override + { + Spec_.CGroupControllers.push_back("memory"); + } + + void SetGroup(int groupId) override + { + Spec_.GroupId = groupId; + } + + void SetUser(const TString& user) override + { + Spec_.User = user; + } + + void SetIPAddresses(const std::vector<NNet::TIP6Address>& addresses, bool enableNat64) override + { + Spec_.IPAddresses = addresses; + Spec_.EnableNat64 = enableNat64; + Spec_.DisableNetwork = false; + } + + void DisableNetwork() override + { + Spec_.DisableNetwork = true; + Spec_.IPAddresses.clear(); + Spec_.EnableNat64 = false; + } + + void SetHostName(const TString& hostName) override + { + Spec_.HostName = hostName; + } + + TFuture<IInstancePtr> Launch( + const TString& path, + const std::vector<TString>& args, + const THashMap<TString, TString>& env) override + { + TStringBuilder commandBuilder; + auto append = [&] (const auto& value) { + commandBuilder.AppendString("'"); + commandBuilder.AppendString(NDetail::EscapeForWordexp(value.c_str())); + commandBuilder.AppendString("' "); + }; + + append(path); + for (const auto& arg : args) { + append(arg); + } + + Spec_.Command = commandBuilder.Flush(); + YT_LOG_DEBUG("Executing Porto container (Name: %v, Command: %v)", + Spec_.Name, + Spec_.Command); + + Spec_.Env = env; + + auto onContainerCreated = [this, this_ = MakeStrong(this)] (const TError& 
error) -> IInstancePtr { + if (!error.IsOK()) { + THROW_ERROR_EXCEPTION(EErrorCode::FailedToStartContainer, "Unable to start container") + << error; + } + + return GetPortoInstance(Executor_, Spec_.Name); + }; + + return Executor_->CreateContainer(Spec_, /* start */ true) + .Apply(BIND(onContainerCreated)); + } + +private: + IPortoExecutorPtr Executor_; + TRunnableContainerSpec Spec_; + const NLogging::TLogger Logger; +}; + +IInstanceLauncherPtr CreatePortoInstanceLauncher(const TString& name, IPortoExecutorPtr executor) +{ + return New<TPortoInstanceLauncher>(name, executor); +} + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoInstance + : public IInstance +{ +public: + static IInstancePtr GetSelf(IPortoExecutorPtr executor) + { + return New<TPortoInstance>(GetSelfContainerName(executor), executor); + } + + static IInstancePtr GetInstance(IPortoExecutorPtr executor, const TString& name) + { + return New<TPortoInstance>(name, executor); + } + + void Kill(int signal) override + { + auto error = WaitFor(Executor_->KillContainer(Name_, signal)); + // Killing already finished process is not an error. + if (error.FindMatching(EPortoErrorCode::InvalidState)) { + return; + } + if (!error.IsOK()) { + THROW_ERROR_EXCEPTION("Failed to send signal to Porto instance") + << TErrorAttribute("signal", signal) + << TErrorAttribute("container", Name_) + << error; + } + } + + void Destroy() override + { + WaitFor(Executor_->DestroyContainer(Name_)) + .ThrowOnError(); + Destroyed_ = true; + } + + void Stop() override + { + WaitFor(Executor_->StopContainer(Name_)) + .ThrowOnError(); + } + + TErrorOr<ui64> CalculateCpuUserUsage( + TErrorOr<ui64>& cpuUsage, + TErrorOr<ui64>& cpuSystemUsage) const + { + if (cpuUsage.IsOK() && cpuSystemUsage.IsOK()) { + return cpuUsage.Value() > cpuSystemUsage.Value() ? 
cpuUsage.Value() - cpuSystemUsage.Value() : 0;
+        } else if (cpuUsage.IsOK()) {
+            return TError("Missing property %Qlv in Porto response", EStatField::CpuSystemUsage)
+                << TErrorAttribute("container", Name_);
+        } else {
+            return TError("Missing property %Qlv in Porto response", EStatField::CpuUsage)
+                << TErrorAttribute("container", Name_);
+        }
+    }
+
+    // Collects the requested |fields| for this container. Fields listed in
+    // NDetail::PortoStatRules are read from Porto properties and parsed by the
+    // matching rule; context switch and user CPU figures are computed here.
+    // A failure for an individual property is stored as an error in the result,
+    // while an unknown field or a failed Porto request throws.
+    TResourceUsage GetResourceUsage(
+        const std::vector<EStatField>& fields) const override
+    {
+        std::vector<TString> properties;
+        properties.push_back("absolute_name");
+
+        bool userTimeRequested = false;
+        bool contextSwitchesRequested = false;
+        for (auto field : fields) {
+            if (auto it = NDetail::PortoStatRules.find(field); it != NDetail::PortoStatRules.end()) {
+                const auto& rule = it->second;
+                properties.push_back(rule.first);
+            } else if (field == EStatField::ContextSwitchesDelta || field == EStatField::ContextSwitches) {
+                contextSwitchesRequested = true;
+            } else if (field == EStatField::CpuUserUsage) {
+                userTimeRequested = true;
+            } else {
+                THROW_ERROR_EXCEPTION("Unknown resource field %Qlv requested", field)
+                    << TErrorAttribute("container", Name_);
+            }
+        }
+
+        auto propertyMap = WaitFor(Executor_->GetContainerProperties(Name_, properties))
+            .ValueOrThrow();
+
+        TResourceUsage result;
+
+        for (auto field : fields) {
+            auto ruleIt = NDetail::PortoStatRules.find(field);
+            if (ruleIt == NDetail::PortoStatRules.end()) {
+                continue;
+            }
+
+            const auto& [property, callback] = ruleIt->second;
+            auto& record = result[field];
+            if (auto responseIt = propertyMap.find(property); responseIt != propertyMap.end()) {
+                const auto& valueOrError = responseIt->second;
+                if (valueOrError.IsOK()) {
+                    const auto& value = valueOrError.Value();
+
+                    try {
+                        record = callback(value);
+                    } catch (const std::exception& ex) {
+                        record = TError("Error parsing Porto property %Qlv", field)
+                            << TErrorAttribute("container", Name_)
+                            << TErrorAttribute("property_value", value)
+                            << ex;
+                    }
+                } else {
+                    record = TError("Error getting Porto property %Qlv", field)
+                        << TErrorAttribute("container", Name_)
+                        << valueOrError;
+                }
+            } else {
+                record = TError("Missing property %Qlv in Porto response", field)
+                    << TErrorAttribute("container", Name_);
+            }
+        }
+
+        // We should maintain context switch information even if this field
+        // is not requested since metrics of individual containers can go up and down.
+        auto subcontainers = WaitFor(Executor_->ListSubcontainers(Name_, /*includeRoot*/ true))
+            .ValueOrThrow();
+
+        auto metricMap = WaitFor(Executor_->GetContainerMetrics(subcontainers, "ctxsw"))
+            .ValueOrThrow();
+
+        // TODO(don-dron): remove diff calculation from GetResourceUsage, because GetResourceUsage must return only snapshot stat.
+        {
+            auto guard = Guard(ContextSwitchMapLock_);
+
+            for (const auto& [container, newValue] : metricMap) {
+                auto& prevValue = ContextSwitchMap_[container];
+                TotalContextSwitches_ += std::max<i64>(0LL, newValue - prevValue);
+                prevValue = newValue;
+            }
+
+            if (contextSwitchesRequested) {
+                result[EStatField::ContextSwitchesDelta] = TotalContextSwitches_;
+            }
+        }
+
+        if (contextSwitchesRequested) {
+            ui64 totalContextSwitches = 0;
+
+            for (const auto& [container, newValue] : metricMap) {
+                // Clamp in the signed domain first: converting a negative value to
+                // ui64 before the comparison would turn it into a huge positive number.
+                // NOTE(review): assumes metric values are signed, as the i64
+                // ContextSwitchMap_ suggests -- confirm against GetContainerMetrics.
+                totalContextSwitches += static_cast<ui64>(std::max<i64>(0LL, newValue));
+            }
+
+            result[EStatField::ContextSwitches] = totalContextSwitches;
+        }
+
+        if (userTimeRequested) {
+            result[EStatField::CpuUserUsage] = CalculateCpuUserUsage(
+                result[EStatField::CpuUsage],
+                result[EStatField::CpuSystemUsage]);
+        }
+
+        return result;
+    }
+
+    TResourceLimits GetResourceLimits() const override
+    {
+        std::vector<TString> properties;
+        static TString memoryLimitProperty = "memory_limit_total";
+        static TString cpuLimitProperty = "cpu_limit_bound";
+        static TString cpuGuaranteeProperty = "cpu_guarantee_bound";
+        properties.push_back(memoryLimitProperty);
+        properties.push_back(cpuLimitProperty);
+        properties.push_back(cpuGuaranteeProperty);
+
+        auto responseOrError = WaitFor(Executor_->GetContainerProperties(Name_, properties));
+
THROW_ERROR_EXCEPTION_IF_FAILED(responseOrError, "Failed to get Porto container resource limits"); + + const auto& response = responseOrError.Value(); + + const auto& memoryLimitRsp = response.at(memoryLimitProperty); + THROW_ERROR_EXCEPTION_IF_FAILED(memoryLimitRsp, "Failed to get memory limit from Porto"); + + i64 memoryLimit; + if (!TryFromString<i64>(memoryLimitRsp.Value(), memoryLimit)) { + THROW_ERROR_EXCEPTION("Failed to parse memory limit value from Porto") + << TErrorAttribute(memoryLimitProperty, memoryLimitRsp.Value()); + } + + const auto& cpuLimitRsp = response.at(cpuLimitProperty); + THROW_ERROR_EXCEPTION_IF_FAILED(cpuLimitRsp, "Failed to get CPU limit from Porto"); + + double cpuLimit; + YT_VERIFY(cpuLimitRsp.Value().EndsWith('c')); + auto cpuLimitValue = TStringBuf(cpuLimitRsp.Value().begin(), cpuLimitRsp.Value().size() - 1); + if (!TryFromString<double>(cpuLimitValue, cpuLimit)) { + THROW_ERROR_EXCEPTION("Failed to parse CPU limit value from Porto") + << TErrorAttribute(cpuLimitProperty, cpuLimitRsp.Value()); + } + + const auto& cpuGuaranteeRsp = response.at(cpuGuaranteeProperty); + THROW_ERROR_EXCEPTION_IF_FAILED(cpuGuaranteeRsp, "Failed to get CPU guarantee from Porto"); + + double cpuGuarantee; + if (!cpuGuaranteeRsp.Value()) { + // XXX: hack for missing response from porto. 
+ cpuGuarantee = 0.0; + } else { + YT_VERIFY(cpuGuaranteeRsp.Value().EndsWith('c')); + auto cpuGuaranteeValue = TStringBuf(cpuGuaranteeRsp.Value().begin(), cpuGuaranteeRsp.Value().size() - 1); + if (!TryFromString<double>(cpuGuaranteeValue, cpuGuarantee)) { + THROW_ERROR_EXCEPTION("Failed to parse CPU guarantee value from Porto") + << TErrorAttribute(cpuGuaranteeProperty, cpuGuaranteeRsp.Value()); + } + } + + return TResourceLimits{ + .CpuLimit = cpuLimit, + .CpuGuarantee = cpuGuarantee, + .Memory = memoryLimit, + }; + } + + void SetCpuGuarantee(double cores) override + { + SetProperty("cpu_guarantee", ToString(cores) + "c"); + } + + void SetCpuLimit(double cores) override + { + SetProperty("cpu_limit", ToString(cores) + "c"); + } + + void SetCpuWeight(double weight) override + { + SetProperty("cpu_weight", weight); + } + + void SetMemoryGuarantee(i64 memoryGuarantee) override + { + SetProperty("memory_guarantee", memoryGuarantee); + } + + void SetIOWeight(double weight) override + { + SetProperty("io_weight", weight); + } + + void SetIOThrottle(i64 operations) override + { + SetProperty("io_ops_limit", operations); + } + + TString GetStderr() const override + { + return *WaitFor(Executor_->GetContainerProperty(Name_, "stderr")) + .ValueOrThrow(); + } + + TString GetName() const override + { + return Name_; + } + + std::optional<TString> GetParentName() const override + { + return NDetail::GetParentName(Name_); + } + + std::optional<TString> GetRootName() const override + { + return NDetail::GetRootName(Name_); + } + + pid_t GetPid() const override + { + auto pid = *WaitFor(Executor_->GetContainerProperty(Name_, "root_pid")) + .ValueOrThrow(); + return std::stoi(pid); + } + + i64 GetMajorPageFaultCount() const override + { + auto faults = WaitFor(Executor_->GetContainerProperty(Name_, "major_faults")) + .ValueOrThrow(); + return faults + ? 
std::stoll(*faults) + : 0; + } + + std::vector<pid_t> GetPids() const override + { + auto getPidCgroup = [&] (const TString& cgroups) { + for (TStringBuf cgroup : StringSplitter(cgroups).SplitByString("; ")) { + if (cgroup.StartsWith("pids:")) { + auto startPosition = cgroup.find('/'); + YT_VERIFY(startPosition != TString::npos); + return cgroup.substr(startPosition); + } + } + THROW_ERROR_EXCEPTION("Pids cgroup not found for container %Qv", GetName()) + << TErrorAttribute("cgroups", cgroups); + }; + + auto cgroups = *WaitFor(Executor_->GetContainerProperty(Name_, "cgroups")) + .ValueOrThrow(); + // Porto returns full cgroup name, with mount prefix, such as "/sys/fs/cgroup/pids". + auto instanceCgroup = getPidCgroup(cgroups); + + std::vector<pid_t> pids; + for (auto pid : ListPids()) { + std::map<TString, TString> cgroups; + try { + cgroups = GetProcessCGroups(pid); + } catch (const std::exception& ex) { + YT_LOG_DEBUG(ex, "Failed to get CGroups for process (Pid: %v)", pid); + continue; + } + + // Pid cgroups are returned in short form. 
+ auto processPidCgroup = cgroups["pids"]; + if (!processPidCgroup.empty() && instanceCgroup.EndsWith(processPidCgroup)) { + pids.push_back(pid); + } + } + + return pids; + } + + TFuture<void> Wait() override + { + return Executor_->PollContainer(Name_) + .Apply(BIND([] (int status) { + StatusToError(status) + .ThrowOnError(); + })); + } + +private: + const TString Name_; + const IPortoExecutorPtr Executor_; + const NLogging::TLogger Logger; + + bool Destroyed_ = false; + + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, ContextSwitchMapLock_); + mutable i64 TotalContextSwitches_ = 0; + mutable THashMap<TString, i64> ContextSwitchMap_; + + TPortoInstance(TString name, IPortoExecutorPtr executor) + : Name_(std::move(name)) + , Executor_(std::move(executor)) + , Logger(ContainersLogger.WithTag("Container: %v", Name_)) + { } + + void SetProperty(const TString& key, const TString& value) + { + WaitFor(Executor_->SetContainerProperty(Name_, key, value)) + .ThrowOnError(); + } + + void SetProperty(const TString& key, i64 value) + { + SetProperty(key, ToString(value)); + } + + void SetProperty(const TString& key, double value) + { + SetProperty(key, ToString(value)); + } + + DECLARE_NEW_FRIEND() +}; + +//////////////////////////////////////////////////////////////////////////////// + +TString GetSelfContainerName(const IPortoExecutorPtr& executor) +{ + try { + auto properties = WaitFor(executor->GetContainerProperties( + "self", + std::vector<TString>{"absolute_name", "absolute_namespace"})) + .ValueOrThrow(); + + auto absoluteName = properties.at("absolute_name") + .ValueOrThrow(); + auto absoluteNamespace = properties.at("absolute_namespace") + .ValueOrThrow(); + + if (absoluteName == "/") { + return absoluteName; + } + + if (absoluteName.length() < absoluteNamespace.length()) { + YT_VERIFY(absoluteName + "/" == absoluteNamespace); + return ""; + } else { + YT_VERIFY(absoluteName.StartsWith(absoluteNamespace)); + return absoluteName.substr(absoluteNamespace.length()); + } + 
} catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Failed to get name for container \"self\"") + << ex; + } +} + +IInstancePtr GetSelfPortoInstance(IPortoExecutorPtr executor) +{ + return TPortoInstance::GetSelf(executor); +} + +IInstancePtr GetPortoInstance(IPortoExecutorPtr executor, const TString& name) +{ + return TPortoInstance::GetInstance(executor, name); +} + +IInstancePtr GetRootPortoInstance(IPortoExecutorPtr executor) +{ + auto self = GetSelfPortoInstance(executor); + return TPortoInstance::GetInstance(executor, *self->GetRootName()); +} + +double GetSelfPortoInstanceVCpuFactor() +{ + auto config = New<TPortoExecutorDynamicConfig>(); + auto executorPtr = CreatePortoExecutor(config, ""); + auto currentContainer = GetSelfPortoInstance(executorPtr); + double cpuLimit = currentContainer->GetResourceLimits().CpuLimit; + if (cpuLimit <= 0) { + THROW_ERROR_EXCEPTION("Cpu limit must be greater than 0"); + } + + // DEPLOY_VCPU_LIMIT stores value in millicores + if (TString vcpuLimitStr = GetEnv("DEPLOY_VCPU_LIMIT"); !vcpuLimitStr.Empty()) { + double vcpuLimit = FromString<double>(vcpuLimitStr) / 1000.0; + return vcpuLimit / cpuLimit; + } + THROW_ERROR_EXCEPTION("Failed to get vcpu limit from env variable"); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers + +#endif diff --git a/yt/yt/library/containers/instance.h b/yt/yt/library/containers/instance.h new file mode 100644 index 0000000000..719190a18f --- /dev/null +++ b/yt/yt/library/containers/instance.h @@ -0,0 +1,167 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/actions/future.h> + +#include <yt/yt/core/net/address.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +using TResourceUsage = THashMap<EStatField, TErrorOr<ui64>>; + +const std::vector<EStatField> InstanceStatFields{ + EStatField::CpuUsage, + EStatField::CpuUserUsage, + 
EStatField::CpuSystemUsage,
+    EStatField::CpuWait,
+    EStatField::CpuThrottled,
+    EStatField::ContextSwitches,
+    EStatField::ContextSwitchesDelta,
+    EStatField::ThreadCount,
+    EStatField::CpuLimit,
+    EStatField::CpuGuarantee,
+
+    EStatField::Rss,
+    EStatField::MappedFile,
+    EStatField::MajorPageFaults,
+    EStatField::MinorPageFaults,
+    EStatField::FileCacheUsage,
+    EStatField::AnonMemoryUsage,
+    EStatField::AnonMemoryLimit,
+    EStatField::MemoryUsage,
+    EStatField::MemoryGuarantee,
+    EStatField::MemoryLimit,
+    EStatField::MaxMemoryUsage,
+    EStatField::OomKills,
+    EStatField::OomKillsTotal,
+
+    EStatField::IOReadByte,
+    EStatField::IOWriteByte,
+    EStatField::IOBytesLimit,
+    EStatField::IOReadOps,
+    EStatField::IOWriteOps,
+    EStatField::IOOps,
+    EStatField::IOOpsLimit,
+    EStatField::IOTotalTime,
+    EStatField::IOWaitTime,
+
+    EStatField::NetTxBytes,
+    EStatField::NetTxPackets,
+    EStatField::NetTxDrops,
+    EStatField::NetTxLimit,
+    EStatField::NetRxBytes,
+    EStatField::NetRxPackets,
+    EStatField::NetRxDrops,
+    EStatField::NetRxLimit,
+};
+
+//! Container limits as returned by IInstance::GetResourceLimits().
+struct TResourceLimits
+{
+    double CpuLimit;
+    double CpuGuarantee;
+    i64 Memory;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Collects container settings through the setters below;
+//! Launch() asynchronously yields the resulting IInstance.
+struct IInstanceLauncher
+    : public TRefCounted
+{
+    virtual bool HasRoot() const = 0;
+    virtual const TString& GetName() const = 0;
+
+    virtual void SetStdIn(const TString& inputPath) = 0;
+    virtual void SetStdOut(const TString& outPath) = 0;
+    virtual void SetStdErr(const TString& errorPath) = 0;
+    virtual void SetCwd(const TString& pwd) = 0;
+
+    // Null core dump handler implies disabled core dumps.
+    virtual void SetCoreDumpHandler(const std::optional<TString>& handler) = 0;
+    virtual void SetRoot(const TRootFS& rootFS) = 0;
+
+    virtual void SetThreadLimit(i64 threadLimit) = 0;
+    virtual void SetDevices(const std::vector<TDevice>& devices) = 0;
+
+    virtual void SetEnablePorto(EEnablePorto enablePorto) = 0;
+    virtual void SetIsolate(bool isolate) = 0;
+    virtual void EnableMemoryTracking() = 0;
+    virtual void SetGroup(int groupId) = 0;
+    virtual void SetUser(const TString& user) = 0;
+    virtual void SetIPAddresses(
+        const std::vector<NNet::TIP6Address>& addresses,
+        bool enableNat64 = false) = 0;
+    virtual void DisableNetwork() = 0;
+    virtual void SetHostName(const TString& hostName) = 0;
+
+    virtual TFuture<IInstancePtr> Launch(
+        const TString& path,
+        const std::vector<TString>& args,
+        const THashMap<TString, TString>& env) = 0;
+};
+
+DEFINE_REFCOUNTED_TYPE(IInstanceLauncher)
+
+#ifdef _linux_
+IInstanceLauncherPtr CreatePortoInstanceLauncher(const TString& name, IPortoExecutorPtr executor);
+#endif
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Interface for controlling and inspecting a single container instance.
+struct IInstance
+    : public TRefCounted
+{
+    virtual void Kill(int signal) = 0;
+    virtual void Stop() = 0;
+    virtual void Destroy() = 0;
+
+    //! By default all fields listed in InstanceStatFields are requested.
+    virtual TResourceUsage GetResourceUsage(
+        const std::vector<EStatField>& fields = InstanceStatFields) const = 0;
+    virtual TResourceLimits GetResourceLimits() const = 0;
+    virtual void SetCpuGuarantee(double cores) = 0;
+    virtual void SetCpuLimit(double cores) = 0;
+    virtual void SetCpuWeight(double weight) = 0;
+    virtual void SetIOWeight(double weight) = 0;
+    virtual void SetIOThrottle(i64 operations) = 0;
+    virtual void SetMemoryGuarantee(i64 memoryGuarantee) = 0;
+
+    virtual TString GetStderr() const = 0;
+
+    virtual TString GetName() const = 0;
+    virtual std::optional<TString> GetParentName() const = 0;
+    virtual std::optional<TString> GetRootName() const = 0;
+
+    //! Returns externally visible pid of the root process inside container.
+    //! Throws if container is not running.
+    virtual pid_t GetPid() const = 0;
+    //! Returns the list of externally visible pids of processes running inside container.
+    virtual std::vector<pid_t> GetPids() const = 0;
+
+    virtual i64 GetMajorPageFaultCount() const = 0;
+
+    //! Future is set when container reaches terminal state (stopped or dead).
+    //! Resulting error is OK iff container exited with code 0.
+    virtual TFuture<void> Wait() = 0;
+};
+
+DEFINE_REFCOUNTED_TYPE(IInstance)
+
+////////////////////////////////////////////////////////////////////////////////
+
+#ifdef _linux_
+TString GetSelfContainerName(const IPortoExecutorPtr& executor);
+
+IInstancePtr GetSelfPortoInstance(IPortoExecutorPtr executor);
+IInstancePtr GetRootPortoInstance(IPortoExecutorPtr executor);
+IInstancePtr GetPortoInstance(IPortoExecutorPtr executor, const TString& name);
+
+//! Works only in Yandex.Deploy pod environment where env DEPLOY_VCPU_LIMIT is set.
+//! Throws if this env is absent.
+double GetSelfPortoInstanceVCpuFactor();
+#endif
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NContainers
diff --git a/yt/yt/library/containers/instance_limits_tracker.cpp b/yt/yt/library/containers/instance_limits_tracker.cpp
new file mode 100644
index 0000000000..ae280fa561
--- /dev/null
+++ b/yt/yt/library/containers/instance_limits_tracker.cpp
@@ -0,0 +1,173 @@
+#include "public.h"
+#include "instance_limits_tracker.h"
+#include "instance.h"
+#include "porto_resource_tracker.h"
+#include "private.h"
+
+#include <yt/yt/core/concurrency/periodic_executor.h>
+
+#include <yt/yt/core/ytree/fluent.h>
+#include <yt/yt/core/ytree/ypath_service.h>
+
+namespace NYT::NContainers {
+
+using namespace NYTree;
+
+////////////////////////////////////////////////////////////////////////////////
+
+static const auto& Logger = ContainersLogger;
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Creates a periodic executor that runs DoUpdateLimits in #invoker every #updatePeriod.
+//! On Linux, additionally creates Porto resource trackers for #instance and #root
+//! (both constructed with updatePeriod / 2); elsewhere the instances are unused.
+TInstanceLimitsTracker::TInstanceLimitsTracker(
+    IInstancePtr instance,
+    IInstancePtr root,
+    IInvokerPtr invoker,
+    TDuration updatePeriod)
+    : Invoker_(std::move(invoker))
+    , Executor_(New<NConcurrency::TPeriodicExecutor>(
+        Invoker_,
+        BIND(&TInstanceLimitsTracker::DoUpdateLimits, MakeWeak(this)),
+        updatePeriod))
+{
+#ifdef _linux_
+    SelfTracker_ = New<TPortoResourceTracker>(std::move(instance), updatePeriod / 2);
+    RootTracker_ = New<TPortoResourceTracker>(std::move(root), updatePeriod / 2);
+#else
+    Y_UNUSED(instance);
+    Y_UNUSED(root);
+#endif
+}
+
+//! Starts the periodic executor; does nothing if already running.
+void TInstanceLimitsTracker::Start()
+{
+    if (!Running_) {
+        Executor_->Start();
+        Running_ = true;
+        YT_LOG_INFO("Instance limits tracker started");
+    }
+}
+
+//! Requests the periodic executor to stop (the stop future is not awaited);
+//! does nothing if not running.
+void TInstanceLimitsTracker::Stop()
+{
+    if (Running_) {
+        YT_UNUSED_FUTURE(Executor_->Stop());
+        Running_ = false;
+        YT_LOG_INFO("Instance limits tracker stopped");
+    }
+}
+
+//! Periodic callback: re-reads statistics from the trackers, refreshes the cached
+//! memory usage and CPU guarantee, and fires LimitsUpdated when the limits change.
+//! The body is compiled on Linux only.
+void TInstanceLimitsTracker::DoUpdateLimits()
+{
+    VERIFY_INVOKER_AFFINITY(Invoker_);
+
+#ifdef _linux_
+    YT_LOG_DEBUG("Checking for instance limits update");
+
+    // Copies the value into #destination on success. On error #destination is left
+    // untouched and the error is logged: as an alert when #alert is set, and
+    // unconditionally at debug level.
+    auto setIfOk = [] (auto* destination, const auto& valueOrError, const TString& fieldName, bool alert = true) {
+        if (valueOrError.IsOK()) {
+            *destination = valueOrError.Value();
+        } else {
+            YT_LOG_ALERT_IF(alert, valueOrError, "Failed to get container limit (Field: %v)",
+                fieldName);
+
+            YT_LOG_DEBUG(valueOrError, "Failed to get container limit (Field: %v)",
+                fieldName);
+        }
+    };
+
+    try {
+        // NB: Network statistics are taken from the root tracker, the rest from the self tracker.
+        auto memoryStatistics = SelfTracker_->GetMemoryStatistics();
+        auto netStatistics = RootTracker_->GetNetworkStatistics();
+        auto cpuStatistics = SelfTracker_->GetCpuStatistics();
+
+        setIfOk(&MemoryUsage_, memoryStatistics.Rss, "MemoryRss");
+
+        // NOTE(review): these stay default-constructed when the corresponding statistic
+        // is an error, and are still compared/stored below.
+        TDuration cpuGuarantee;
+        TDuration cpuLimit;
+        setIfOk(&cpuGuarantee, cpuStatistics.GuaranteeTime, "CpuGuarantee");
+        setIfOk(&cpuLimit, cpuStatistics.LimitTime, "CpuLimit");
+
+        if (CpuGuarantee_ != cpuGuarantee) {
+            YT_LOG_INFO("Instance CPU guarantee updated (OldCpuGuarantee: %v, NewCpuGuarantee: %v)",
+                CpuGuarantee_,
+                cpuGuarantee);
+            CpuGuarantee_ = cpuGuarantee;
+            // NB: We do not fire LimitsUpdated since this value is used only for diagnostics.
+        }
+
+        TInstanceLimits limits;
+        limits.Cpu = cpuLimit.SecondsFloat();
+
+        // When both limits are available: take the smaller one if both are positive,
+        // otherwise the anon limit if it is positive, otherwise the total memory limit.
+        if (memoryStatistics.AnonLimit.IsOK() && memoryStatistics.MemoryLimit.IsOK()) {
+            i64 anonLimit = memoryStatistics.AnonLimit.Value();
+            i64 memoryLimit = memoryStatistics.MemoryLimit.Value();
+
+            if (anonLimit > 0 && memoryLimit > 0) {
+                limits.Memory = std::min(anonLimit, memoryLimit);
+            } else if (anonLimit > 0) {
+                limits.Memory = anonLimit;
+            } else {
+                limits.Memory = memoryLimit;
+            }
+        } else {
+            setIfOk(&limits.Memory, memoryStatistics.MemoryLimit, "MemoryLimit");
+        }
+
+        // Value-initialized, i.e. false: failures to read network limits are logged
+        // at debug level only.
+        static constexpr bool DontFireAlertOnError = {};
+        setIfOk(&limits.NetTx, netStatistics.TxLimit, "NetTxLimit", DontFireAlertOnError);
+        setIfOk(&limits.NetRx, netStatistics.RxLimit, "NetRxLimit", DontFireAlertOnError);
+
+        if (InstanceLimits_ != limits) {
+            YT_LOG_INFO("Instance limits updated (OldLimits: %v, NewLimits: %v)",
+                InstanceLimits_,
+                limits);
+            InstanceLimits_ = limits;
+            LimitsUpdated_.Fire(limits);
+        }
+    } catch (const std::exception& ex) {
+        YT_LOG_WARNING(ex, "Failed to get instance limits");
+    }
+#endif
+}
+
+//! Returns a YPath service produced by DoBuildOrchid and routed via Invoker_.
+IYPathServicePtr TInstanceLimitsTracker::GetOrchidService()
+{
+    return IYPathService::FromProducer(BIND(&TInstanceLimitsTracker::DoBuildOrchid, MakeStrong(this)))
+        ->Via(Invoker_);
+}
+
+//! Writes a map containing only those of cpu_limit, cpu_guarantee, memory_limit
+//! and memory_usage whose cached value is currently set.
+void TInstanceLimitsTracker::DoBuildOrchid(NYson::IYsonConsumer* consumer) const
+{
+    NYTree::BuildYsonFluently(consumer)
+        .BeginMap()
+            .DoIf(static_cast<bool>(InstanceLimits_), [&] (auto fluent) {
+                fluent.Item("cpu_limit").Value(InstanceLimits_->Cpu);
+            })
+            .DoIf(static_cast<bool>(CpuGuarantee_), [&] (auto fluent) {
+                fluent.Item("cpu_guarantee").Value(*CpuGuarantee_);
+            })
+            .DoIf(static_cast<bool>(InstanceLimits_), [&] (auto fluent) {
+                fluent.Item("memory_limit").Value(InstanceLimits_->Memory);
+            })
+            .DoIf(static_cast<bool>(MemoryUsage_), [&] (auto fluent) {
+                fluent.Item("memory_usage").Value(*MemoryUsage_);
+            })
+        .EndMap();
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Formats the limits as "{Cpu: ..., Memory: ..., NetTx: ..., NetRx: ...}"; the format spec is ignored.
+void FormatValue(TStringBuilderBase* builder, const TInstanceLimits& limits, TStringBuf /*format*/)
+{
+    builder->AppendFormat(
+        "{Cpu: %v, Memory: %v, NetTx: %v, NetRx: %v}",
+        limits.Cpu,
+        limits.Memory,
+        limits.NetTx,
+        limits.NetRx);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NContainers
diff --git a/yt/yt/library/containers/instance_limits_tracker.h b/yt/yt/library/containers/instance_limits_tracker.h
new file mode 100644
index 0000000000..e652fff446
--- /dev/null
+++ b/yt/yt/library/containers/instance_limits_tracker.h
@@ -0,0 +1,59 @@
+#pragma once
+
+#include "public.h"
+
+#include <yt/yt/core/actions/signal.h>
+
+#include <yt/yt/core/concurrency/public.h>
+
+#include <yt/yt/core/yson/public.h>
+
+#include <yt/yt/core/ytree/public.h>
+
+namespace NYT::NContainers {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Periodically polls container statistics and caches the derived limits,
+//! CPU guarantee and memory usage.
+class TInstanceLimitsTracker
+    : public TRefCounted
+{
+public:
+    //! Raised when container limits change.
+    DEFINE_SIGNAL(void(const TInstanceLimits&), LimitsUpdated);
+
+public:
+    TInstanceLimitsTracker(
+        IInstancePtr instance,
+        IInstancePtr root,
+        IInvokerPtr invoker,
+        TDuration updatePeriod);
+
+    void Start();
+    void Stop();
+
+    NYTree::IYPathServicePtr GetOrchidService();
+
+private:
+    void DoUpdateLimits();
+    void DoBuildOrchid(NYson::IYsonConsumer* consumer) const;
+
+    TPortoResourceTrackerPtr SelfTracker_;
+    TPortoResourceTrackerPtr RootTracker_;
+    const IInvokerPtr Invoker_;
+    const NConcurrency::TPeriodicExecutorPtr Executor_;
+
+    // Cached values; empty until first assigned.
+    std::optional<TDuration> CpuGuarantee_;
+    std::optional<TInstanceLimits> InstanceLimits_;
+    std::optional<i64> MemoryUsage_;
+    bool Running_ = false;
+};
+
+DEFINE_REFCOUNTED_TYPE(TInstanceLimitsTracker)
+
+////////////////////////////////////////////////////////////////////////////////
+
+void FormatValue(TStringBuilderBase* builder, const TInstanceLimits& limits, TStringBuf format);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NContainers
diff --git a/yt/yt/library/containers/porto_executor.cpp b/yt/yt/library/containers/porto_executor.cpp
new file mode 100644
index 0000000000..ea96afe007
--- /dev/null
+++ b/yt/yt/library/containers/porto_executor.cpp
@@ -0,0 +1,1079 @@
+#include "porto_executor.h"
+#include "config.h"
+
+#include "private.h"
+
+#include <yt/yt/core/concurrency/action_queue.h>
+#include <yt/yt/core/concurrency/periodic_executor.h>
+#include <yt/yt/core/concurrency/scheduler.h>
+
+#include <yt/yt/core/logging/log.h>
+
+#include <yt/yt/core/misc/fs.h>
+
+#include <yt/yt/core/profiling/timing.h>
+
+#include <yt/yt/core/ytree/convert.h>
+
+#include <library/cpp/porto/proto/rpc.pb.h>
+
+#include <library/cpp/yt/memory/atomic_intrusive_ptr.h>
+
+#include <string>
+
+namespace NYT::NContainers {
+
+using namespace NConcurrency;
+using Porto::EError;
+
+////////////////////////////////////////////////////////////////////////////////
+
+#ifdef _linux_
+
+static const NLogging::TLogger& Logger = ContainersLogger; +static constexpr auto RetryInterval = TDuration::MilliSeconds(100); + +//////////////////////////////////////////////////////////////////////////////// + +TString PortoErrorCodeFormatter(int code) +{ + return TEnumTraits<EPortoErrorCode>::ToString(static_cast<EPortoErrorCode>(code)); +} + +YT_DEFINE_ERROR_CODE_RANGE(12000, 13999, "NYT::NContainers::EPortoErrorCode", PortoErrorCodeFormatter); + +//////////////////////////////////////////////////////////////////////////////// + +EPortoErrorCode ConvertPortoErrorCode(EError portoError) +{ + return static_cast<EPortoErrorCode>(PortoErrorCodeBase + portoError); +} + +bool IsRetriableErrorCode(EPortoErrorCode error, bool idempotent) +{ + return + error == EPortoErrorCode::Unknown || + // TODO(babenko): it's not obvious that we can always retry SocketError + // but this is how it has used to work for a while. + error == EPortoErrorCode::SocketError || + error == EPortoErrorCode::SocketTimeout && idempotent; +} + +THashMap<TString, TErrorOr<TString>> ParsePortoGetResponse( + const Porto::TGetResponse_TContainerGetListResponse& response) +{ + THashMap<TString, TErrorOr<TString>> result; + for (const auto& property : response.keyval()) { + if (property.error() == EError::Success) { + result[property.variable()] = property.value(); + } else { + result[property.variable()] = TError(ConvertPortoErrorCode(property.error()), property.errormsg()) + << TErrorAttribute("porto_error", ConvertPortoErrorCode(property.error())); + } + } + return result; +} + +THashMap<TString, TErrorOr<TString>> ParseSinglePortoGetResponse( + const TString& name, + const Porto::TGetResponse& getResponse) +{ + for (const auto& container : getResponse.list()) { + if (container.name() == name) { + return ParsePortoGetResponse(container); + } + } + THROW_ERROR_EXCEPTION("Unable to get properties from Porto") + << TErrorAttribute("container", name); +} + +THashMap<TString, THashMap<TString, 
TErrorOr<TString>>> ParseMultiplePortoGetResponse( + const Porto::TGetResponse& getResponse) +{ + THashMap<TString, THashMap<TString, TErrorOr<TString>>> result; + for (const auto& container : getResponse.list()) { + result[container.name()] = ParsePortoGetResponse(container); + } + return result; +} + +TString FormatEnablePorto(EEnablePorto value) +{ + switch (value) { + case EEnablePorto::None: return "none"; + case EEnablePorto::Isolate: return "isolate"; + case EEnablePorto::Full: return "full"; + default: YT_ABORT(); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoExecutor + : public IPortoExecutor +{ +public: + TPortoExecutor( + TPortoExecutorDynamicConfigPtr config, + const TString& threadNameSuffix, + const NProfiling::TProfiler& profiler) + : Config_(std::move(config)) + , Queue_(New<TActionQueue>(Format("Porto:%v", threadNameSuffix))) + , Profiler_(profiler) + , PollExecutor_(New<TPeriodicExecutor>( + Queue_->GetInvoker(), + BIND(&TPortoExecutor::DoPoll, MakeWeak(this)), + Config_->PollPeriod)) + { + DynamicConfig_.Store(New<TPortoExecutorDynamicConfig>()); + + Api_->SetTimeout(Config_->ApiTimeout.Seconds()); + Api_->SetDiskTimeout(Config_->ApiDiskTimeout.Seconds()); + + PollExecutor_->Start(); + } + + void SubscribeFailed(const TCallback<void (const TError&)>& callback) override + { + Failed_.Subscribe(callback); + } + + void UnsubscribeFailed(const TCallback<void (const TError&)>& callback) override + { + Failed_.Unsubscribe(callback); + } + + void OnDynamicConfigChanged(const TPortoExecutorDynamicConfigPtr& newConfig) override + { + DynamicConfig_.Store(newConfig); + } + +private: + template <class T, class... TArgs1, class... TArgs2> + auto ExecutePortoApiAction( + T(TPortoExecutor::*Method)(TArgs1...), + const TString& command, + TArgs2&&... args) + { + YT_LOG_DEBUG("Enqueue Porto API action (Command: %v)", command); + return BIND(Method, MakeStrong(this), std::forward<TArgs2>(args)...) 
+            .AsyncVia(Queue_->GetInvoker())
+            .Run();
+    };
+
+public:
+    // Each wrapper below enqueues the matching Do* method into the Porto action queue.
+    // The string argument is the command label printed in the "Enqueue Porto API action"
+    // debug log line, so it is kept distinct per wrapper.
+
+    TFuture<void> CreateContainer(const TString& container) override
+    {
+        return ExecutePortoApiAction(
+            &TPortoExecutor::DoCreateContainer,
+            "CreateContainer",
+            container);
+    }
+
+    TFuture<void> CreateContainer(const TRunnableContainerSpec& containerSpec, bool start) override
+    {
+        return ExecutePortoApiAction(
+            &TPortoExecutor::DoCreateContainerFromSpec,
+            "CreateContainerFromSpec",
+            containerSpec,
+            start);
+    }
+
+    //! Fetches a single property of a single container.
+    TFuture<std::optional<TString>> GetContainerProperty(
+        const TString& container,
+        const TString& property) override
+    {
+        return ExecutePortoApiAction(
+            &TPortoExecutor::DoGetContainerProperty,
+            "GetContainerProperty",
+            container,
+            property);
+    }
+
+    //! Fetches several properties of a single container.
+    TFuture<THashMap<TString, TErrorOr<TString>>> GetContainerProperties(
+        const TString& container,
+        const std::vector<TString>& properties) override
+    {
+        return ExecutePortoApiAction(
+            &TPortoExecutor::DoGetContainerProperties,
+            // Was "GetContainerProperty": indistinguishable from the single-property call in logs.
+            "GetContainerProperties",
+            container,
+            properties);
+    }
+
+    //! Fetches several properties of several containers.
+    TFuture<THashMap<TString, THashMap<TString, TErrorOr<TString>>>> GetContainerProperties(
+        const std::vector<TString>& containers,
+        const std::vector<TString>& properties) override
+    {
+        return ExecutePortoApiAction(
+            &TPortoExecutor::DoGetContainerMultipleProperties,
+            // Was "GetContainerProperty"; named after the Do* method, like CreateContainerFromSpec.
+            "GetContainerMultipleProperties",
+            containers,
+            properties);
+    }
+
+    TFuture<THashMap<TString, i64>> GetContainerMetrics(
+        const std::vector<TString>& containers,
+        const TString& metric) override
+    {
+        return ExecutePortoApiAction(
+            &TPortoExecutor::DoGetContainerMetrics,
+            "GetContainerMetrics",
+            containers,
+            metric);
+    }
+
+    TFuture<void> SetContainerProperty(
+        const TString& container,
+        const TString& property,
+        const TString& value) override
+    {
+        return ExecutePortoApiAction(
+            &TPortoExecutor::DoSetContainerProperty,
+            "SetContainerProperty",
+            container,
+            property,
+            value);
+    }
+
+    TFuture<void> DestroyContainer(const TString& container) override
+    {
+        return 
ExecutePortoApiAction( + &TPortoExecutor::DoDestroyContainer, + "DestroyContainer", + container); + } + + TFuture<void> StopContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoStopContainer, + "StopContainer", + container); + } + + TFuture<void> StartContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoStartContainer, + "StartContainer", + container); + } + + TFuture<TString> ConvertPath(const TString& path, const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoConvertPath, + "ConvertPath", + path, + container); + } + + TFuture<void> KillContainer(const TString& container, int signal) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoKillContainer, + "KillContainer", + container, + signal); + } + + TFuture<std::vector<TString>> ListSubcontainers( + const TString& rootContainer, + bool includeRoot) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoListSubcontainers, + "ListSubcontainers", + rootContainer, + includeRoot); + } + + TFuture<int> PollContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoPollContainer, + "PollContainer", + container); + } + + TFuture<int> WaitContainer(const TString& container) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoWaitContainer, + "WaitContainer", + container); + } + + // This method allocates porto "resources", so it should be uncancellable. + TFuture<TString> CreateVolume( + const TString& path, + const THashMap<TString, TString>& properties) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoCreateVolume, + "CreateVolume", + path, + properties) + .ToUncancelable(); + } + + // This method allocates porto "resources", so it should be uncancellable. 
+ TFuture<void> LinkVolume( + const TString& path, + const TString& name) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoLinkVolume, + "LinkVolume", + path, + name) + .ToUncancelable(); + } + + // This method deallocates porto "resources", so it should be uncancellable. + TFuture<void> UnlinkVolume( + const TString& path, + const TString& name) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoUnlinkVolume, + "UnlinkVolume", + path, + name) + .ToUncancelable(); + } + + TFuture<std::vector<TString>> ListVolumePaths() override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoListVolumePaths, + "ListVolumePaths"); + } + + // This method allocates porto "resources", so it should be uncancellable. + TFuture<void> ImportLayer(const TString& archivePath, const TString& layerId, const TString& place) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoImportLayer, + "ImportLayer", + archivePath, + layerId, + place) + .ToUncancelable(); + } + + // This method deallocates porto "resources", so it should be uncancellable. 
+ TFuture<void> RemoveLayer(const TString& layerId, const TString& place, bool async) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoRemoveLayer, + "RemoveLayer", + layerId, + place, + async) + .ToUncancelable(); + } + + TFuture<std::vector<TString>> ListLayers(const TString& place) override + { + return ExecutePortoApiAction( + &TPortoExecutor::DoListLayers, + "ListLayers", + place); + } + + IInvokerPtr GetInvoker() const override + { + return Queue_->GetInvoker(); + } + +private: + const TPortoExecutorDynamicConfigPtr Config_; + const TActionQueuePtr Queue_; + const NProfiling::TProfiler Profiler_; + const std::unique_ptr<Porto::TPortoApi> Api_ = std::make_unique<Porto::TPortoApi>(); + const TPeriodicExecutorPtr PollExecutor_; + TAtomicIntrusivePtr<TPortoExecutorDynamicConfig> DynamicConfig_; + + std::vector<TString> Containers_; + THashMap<TString, TPromise<int>> ContainerMap_; + TSingleShotCallbackList<void(const TError&)> Failed_; + + struct TCommandEntry + { + explicit TCommandEntry(const NProfiling::TProfiler& registry) + : TimeGauge(registry.Timer("/command_time")) + , RetryCounter(registry.Counter("/command_retries")) + , SuccessCounter(registry.Counter("/command_successes")) + , FailureCounter(registry.Counter("/command_failures")) + { } + + NProfiling::TEventTimer TimeGauge; + NProfiling::TCounter RetryCounter; + NProfiling::TCounter SuccessCounter; + NProfiling::TCounter FailureCounter; + }; + + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, CommandLock_); + THashMap<TString, TCommandEntry> CommandToEntry_; + + static const std::vector<TString> ContainerRequestVars_; + + bool IsTestPortoFailureEnabled() const + { + auto config = DynamicConfig_.Acquire(); + return config->EnableTestPortoFailures; + } + + bool IsTestPortoTimeout() const + { + auto config = DynamicConfig_.Acquire(); + return config->EnableTestPortoNotResponding; + } + + EPortoErrorCode GetFailedStubError() const + { + auto config = DynamicConfig_.Acquire(); + return 
config->StubErrorCode; + } + + static TError CreatePortoError(EPortoErrorCode errorCode, const TString& message) + { + return TError(errorCode, "Porto API error") + << TErrorAttribute("original_porto_error_code", static_cast<int>(errorCode) - PortoErrorCodeBase) + << TErrorAttribute("porto_error_message", message); + } + + THashMap<TString, TErrorOr<TString>> DoGetContainerProperties( + const TString& container, + const std::vector<TString>& properties) + { + auto response = DoRequestContainerProperties({container}, properties); + return ParseSinglePortoGetResponse(container, response); + } + + THashMap<TString, THashMap<TString, TErrorOr<TString>>> DoGetContainerMultipleProperties( + const std::vector<TString>& containers, + const std::vector<TString>& properties) + { + auto response = DoRequestContainerProperties(containers, properties); + return ParseMultiplePortoGetResponse(response); + } + + std::optional<TString> DoGetContainerProperty( + const TString& container, + const TString& property) + { + auto response = DoRequestContainerProperties({container}, {property}); + auto parsedResponse = ParseSinglePortoGetResponse(container, response); + auto it = parsedResponse.find(property); + if (it == parsedResponse.end()) { + return std::nullopt; + } else { + return it->second.ValueOrThrow(); + } + } + + void DoCreateContainer(const TString& container) + { + ExecuteApiCall( + [&] { return Api_->Create(container); }, + "Create", + /*idempotent*/ false); + } + + void DoCreateContainerFromSpec(const TRunnableContainerSpec& spec, bool start) + { + Porto::TContainerSpec portoSpec; + + // Required properties. 
+ portoSpec.set_name(spec.Name); + portoSpec.set_command(spec.Command); + + portoSpec.set_enable_porto(FormatEnablePorto(spec.EnablePorto)); + portoSpec.set_isolate(spec.Isolate); + + if (spec.StdinPath) { + portoSpec.set_stdin_path(*spec.StdinPath); + } + if (spec.StdoutPath) { + portoSpec.set_stdout_path(*spec.StdoutPath); + } + if (spec.StderrPath) { + portoSpec.set_stderr_path(*spec.StderrPath); + } + + if (spec.CurrentWorkingDirectory) { + portoSpec.set_cwd(*spec.CurrentWorkingDirectory); + } + + if (spec.CoreCommand) { + portoSpec.set_core_command(*spec.CoreCommand); + } + if (spec.User) { + portoSpec.set_user(*spec.User); + } + + // Useful for jobs, where we operate with numeric group ids. + if (spec.GroupId) { + portoSpec.set_group(ToString(*spec.GroupId)); + } + + if (spec.ThreadLimit) { + portoSpec.set_thread_limit(*spec.ThreadLimit); + } + + if (spec.HostName) { + // To get a reasonable and unique host name inside container. + portoSpec.set_hostname(*spec.HostName); + if (!spec.IPAddresses.empty()) { + const auto& address = spec.IPAddresses[0]; + auto etcHosts = Format("%v %v\n", address, *spec.HostName); + // To be able to resolve hostname into IP inside container. + portoSpec.set_etc_hosts(etcHosts); + } + } + + if (spec.DisableNetwork) { + auto* netConfig = portoSpec.mutable_net()->add_cfg(); + netConfig->set_opt("none"); + } else if (!spec.IPAddresses.empty() && Config_->EnableNetworkIsolation) { + // This label is intended for HBF-agent: YT-12512. + auto* label = portoSpec.mutable_labels()->add_map(); + label->set_key("HBF.ignore_address"); + label->set_val("1"); + + auto* netConfig = portoSpec.mutable_net()->add_cfg(); + netConfig->set_opt("L3"); + netConfig->add_arg("veth0"); + + for (const auto& address : spec.IPAddresses) { + auto* ipConfig = portoSpec.mutable_ip()->add_cfg(); + ipConfig->set_dev("veth0"); + ipConfig->set_ip(ToString(address)); + } + + if (spec.EnableNat64) { + // Behave like nanny does. 
+ portoSpec.set_resolv_conf("nameserver fd64::1;nameserver 2a02:6b8:0:3400::5005;options attempts:1 timeout:1"); + } + } + + for (const auto& [key, value] : spec.Labels) { + auto* map = portoSpec.mutable_labels()->add_map(); + map->set_key(key); + map->set_val(value); + } + + for (const auto& [name, value] : spec.Env) { + auto* var = portoSpec.mutable_env()->add_var(); + var->set_name(name); + var->set_value(value); + } + + for (const auto& controller : spec.CGroupControllers) { + portoSpec.mutable_controllers()->add_controller(controller); + } + + for (const auto& device : spec.Devices) { + auto* portoDevice = portoSpec.mutable_devices()->add_device(); + portoDevice->set_device(device.DeviceName); + portoDevice->set_access(device.Enabled ? "rw" : "-"); + } + + auto addBind = [&] (const TBind& bind) { + auto* portoBind = portoSpec.mutable_bind()->add_bind(); + portoBind->set_target(bind.TargetPath); + portoBind->set_source(bind.SourcePath); + portoBind->add_flag(bind.ReadOnly ? "ro" : "rw"); + }; + + if (spec.RootFS) { + portoSpec.set_root_readonly(spec.RootFS->IsRootReadOnly); + portoSpec.set_root(spec.RootFS->RootPath); + + for (const auto& bind : spec.RootFS->Binds) { + addBind(bind); + } + } + + { + auto* ulimit = portoSpec.mutable_ulimit()->add_ulimit(); + ulimit->set_type("core"); + if (spec.EnableCoreDumps) { + ulimit->set_unlimited(true); + } else { + ulimit->set_hard(0); + ulimit->set_soft(0); + } + } + + // Set some universal defaults. 
+ portoSpec.set_oom_is_fatal(false); + + ExecuteApiCall( + [&] { return Api_->CreateFromSpec(portoSpec, {}, start); }, + "CreateFromSpec", + /*idempotent*/ false); + } + + void DoSetContainerProperty(const TString& container, const TString& property, const TString& value) + { + ExecuteApiCall( + [&] { return Api_->SetProperty(container, property, value); }, + "SetProperty", + /*idempotent*/ true); + } + + void DoDestroyContainer(const TString& container) + { + try { + ExecuteApiCall( + [&] { return Api_->Destroy(container); }, + "Destroy", + /*idempotent*/ true); + } catch (const TErrorException& ex) { + if (!ex.Error().FindMatching(EPortoErrorCode::ContainerDoesNotExist)) { + throw; + } + } + } + + void DoStopContainer(const TString& container) + { + ExecuteApiCall( + [&] { return Api_->Stop(container); }, + "Stop", + /*idempotent*/ true); + } + + void DoStartContainer(const TString& container) + { + ExecuteApiCall( + [&] { return Api_->Start(container); }, + "Start", + /*idempotent*/ false); + } + + TString DoConvertPath(const TString& path, const TString& container) + { + TString result; + ExecuteApiCall( + [&] { return Api_->ConvertPath(path, container, "self", result); }, + "ConvertPath", + /*idempotent*/ true); + return result; + } + + void DoKillContainer(const TString& container, int signal) + { + ExecuteApiCall( + [&] { return Api_->Kill(container, signal); }, + "Kill", + /*idempotent*/ false); + } + + std::vector<TString> DoListSubcontainers(const TString& rootContainer, bool includeRoot) + { + Porto::TListContainersRequest req; + auto filter = req.add_filters(); + filter->set_name(rootContainer + "/*"); + if (includeRoot) { + auto rootFilter = req.add_filters(); + rootFilter->set_name(rootContainer); + } + auto fieldOptions = req.mutable_field_options(); + fieldOptions->add_properties("absolute_name"); + TVector<Porto::TContainer> containers; + ExecuteApiCall( + [&] { return Api_->ListContainersBy(req, containers); }, + "ListContainersBy", + 
/*idempotent*/ true); + + std::vector<TString> containerNames; + containerNames.reserve(containers.size()); + for (const auto& container : containers) { + const auto& absoluteName = container.status().absolute_name(); + if (!absoluteName.empty()) { + containerNames.push_back(absoluteName); + } + } + return containerNames; + } + + TFuture<int> DoWaitContainer(const TString& container) + { + auto result = NewPromise<int>(); + auto waitCallback = [=, this, this_ = MakeStrong(this)] (const Porto::TWaitResponse& rsp) { + return OnContainerTerminated(rsp, result); + }; + + ExecuteApiCall( + [&] { return Api_->AsyncWait({container}, {}, waitCallback); }, + "AsyncWait", + /*idempotent*/ false); + + return result.ToFuture().ToImmediatelyCancelable(); + } + + void OnContainerTerminated(const Porto::TWaitResponse& portoWaitResponse, TPromise<int> result) + { + const auto& container = portoWaitResponse.name(); + const auto& state = portoWaitResponse.state(); + if (state != "dead" && state != "stopped") { + result.TrySet(TError("Container finished with unexpected state") + << TErrorAttribute("container_name", container) + << TErrorAttribute("container_state", state)); + return; + } + + // TODO(max42): switch to Subscribe. 
+ YT_UNUSED_FUTURE(GetContainerProperty(container, "exit_status").Apply(BIND( + [=] (const TErrorOr<std::optional<TString>>& errorOrExitCode) { + if (!errorOrExitCode.IsOK()) { + result.TrySet(TError("Container finished, but exit status is unknown") + << errorOrExitCode); + return; + } + + const auto& optionalExitCode = errorOrExitCode.Value(); + if (!optionalExitCode) { + result.TrySet(TError("Container finished, but exit status is unknown") + << TErrorAttribute("container_name", container) + << TErrorAttribute("container_state", state)); + return; + } + + try { + int exitStatus = FromString<int>(*optionalExitCode); + result.TrySet(exitStatus); + } catch (const std::exception& ex) { + auto error = TError("Failed to parse porto exit status") + << TErrorAttribute("container_name", container) + << TErrorAttribute("exit_status", optionalExitCode.value()); + error.MutableInnerErrors()->push_back(TError(ex)); + result.TrySet(error); + } + }))); + } + + TFuture<int> DoPollContainer(const TString& container) + { + auto [it, inserted] = ContainerMap_.insert({container, NewPromise<int>()}); + if (!inserted) { + YT_LOG_WARNING("Container already added for polling (Container: %v)", + container); + } else { + Containers_.push_back(container); + } + return it->second.ToFuture(); + } + + Porto::TGetResponse DoRequestContainerProperties( + const std::vector<TString>& containers, + const std::vector<TString>& vars) + { + TVector<TString> containers_(containers.begin(), containers.end()); + TVector<TString> vars_(vars.begin(), vars.end()); + + const Porto::TGetResponse* getResponse; + + ExecuteApiCall( + [&] { + getResponse = Api_->Get(containers_, vars_); + return getResponse ? 
EError::Success : EError::Unknown; + }, + "Get", + /*idempotent*/ true); + + YT_VERIFY(getResponse); + return *getResponse; + } + + THashMap<TString, i64> DoGetContainerMetrics( + const std::vector<TString>& containers, + const TString& metric) + { + TVector<TString> containers_(containers.begin(), containers.end()); + + TMap<TString, uint64_t> result; + + ExecuteApiCall( + [&] { return Api_->GetProcMetric(containers_, metric, result); }, + "GetProcMetric", + /*idempotent*/ true); + + return {result.begin(), result.end()}; + } + + void DoPoll() + { + try { + if (Containers_.empty()) { + return; + } + + auto getResponse = DoRequestContainerProperties(Containers_, ContainerRequestVars_); + + if (getResponse.list().empty()) { + return; + } + + auto getProperty = [] ( + const Porto::TGetResponse::TContainerGetListResponse& container, + const TString& name) -> Porto::TGetResponse::TContainerGetValueResponse + { + for (const auto& property : container.keyval()) { + if (property.variable() == name) { + return property; + } + } + + return {}; + }; + + for (const auto& container : getResponse.list()) { + auto state = getProperty(container, "state"); + if (state.error() == EError::ContainerDoesNotExist) { + HandleResult(container.name(), state); + } else if (state.value() == "dead" || state.value() == "stopped") { + HandleResult(container.name(), getProperty(container, "exit_status")); + } + //TODO(dcherednik): other states + } + } catch (const std::exception& ex) { + YT_LOG_ERROR(ex, "Fatal exception occurred while polling Porto"); + Failed_.Fire(TError(ex)); + } + } + + TString DoCreateVolume( + const TString& path, + const THashMap<TString, TString>& properties) + { + auto volume = path; + TMap<TString, TString> propertyMap(properties.begin(), properties.end()); + ExecuteApiCall( + [&] { return Api_->CreateVolume(volume, propertyMap); }, + "CreateVolume", + /*idempotent*/ false); + return volume; + } + + void DoLinkVolume(const TString& path, const TString& container) + { 
+ ExecuteApiCall( + [&] { return Api_->LinkVolume(path, container); }, + "LinkVolume", + /*idempotent*/ false); + } + + void DoUnlinkVolume(const TString& path, const TString& container) + { + ExecuteApiCall( + [&] { return Api_->UnlinkVolume(path, container); }, + "UnlinkVolume", + /*idempotent*/ false); + } + + std::vector<TString> DoListVolumePaths() + { + TVector<TString> volumes; + ExecuteApiCall( + [&] { return Api_->ListVolumes(volumes); }, + "ListVolume", + /*idempotent*/ true); + return {volumes.begin(), volumes.end()}; + } + + void DoImportLayer(const TString& archivePath, const TString& layerId, const TString& place) + { + ExecuteApiCall( + [&] { return Api_->ImportLayer(layerId, archivePath, false, place); }, + "ImportLayer", + /*idempotent*/ false); + } + + void DoRemoveLayer(const TString& layerId, const TString& place, bool async) + { + ExecuteApiCall( + [&] { return Api_->RemoveLayer(layerId, place, async); }, + "RemoveLayer", + /*idempotent*/ false); + } + + std::vector<TString> DoListLayers(const TString& place) + { + TVector<TString> layers; + ExecuteApiCall( + [&] { return Api_->ListLayers(layers, place); }, + "ListLayers", + /*idempotent*/ true); + return {layers.begin(), layers.end()}; + } + + TCommandEntry* GetCommandEntry(const TString& command) + { + auto guard = Guard(CommandLock_); + if (auto it = CommandToEntry_.find(command)) { + return &it->second; + } + return &CommandToEntry_.emplace(command, TCommandEntry(Profiler_.WithTag("command", command))).first->second; + } + + void ExecuteApiCall( + std::function<EError()> callback, + const TString& command, + bool idempotent) + { + YT_LOG_DEBUG("Porto API call started (Command: %v)", command); + + if (IsTestPortoTimeout()) { + YT_LOG_DEBUG("Testing porto timeout (Command: %v)", command); + + auto config = DynamicConfig_.Acquire(); + TDelayedExecutor::WaitForDuration(config->ApiTimeout); + + THROW_ERROR CreatePortoError(GetFailedStubError(), "Porto timeout"); + } + + if 
(IsTestPortoFailureEnabled()) { + YT_LOG_DEBUG("Testing porto failure (Command: %v)", command); + THROW_ERROR CreatePortoError(GetFailedStubError(), "Porto stub error"); + } + + auto* entry = GetCommandEntry(command); + auto startTime = NProfiling::GetInstant(); + while (true) { + EError error; + + { + NProfiling::TWallTimer timer; + error = callback(); + entry->TimeGauge.Record(timer.GetElapsedTime()); + } + + if (error == EError::Success) { + entry->SuccessCounter.Increment(); + break; + } + + entry->FailureCounter.Increment(); + HandleApiError(command, startTime, idempotent); + + YT_LOG_DEBUG("Sleeping and retrying Porto API call (Command: %v)", command); + entry->RetryCounter.Increment(); + + TDelayedExecutor::WaitForDuration(RetryInterval); + } + + YT_LOG_DEBUG("Porto API call completed (Command: %v)", command); + } + + void HandleApiError( + const TString& command, + TInstant startTime, + bool idempotent) + { + TString errorMessage; + auto error = ConvertPortoErrorCode(Api_->GetLastError(errorMessage)); + + // These errors are typical during job cleanup: we might try to kill a container that is already stopped. + bool debug = (error == EPortoErrorCode::ContainerDoesNotExist || error == EPortoErrorCode::InvalidState); + YT_LOG_EVENT( + Logger, + debug ? 
NLogging::ELogLevel::Debug : NLogging::ELogLevel::Error, + "Porto API call error (Error: %v, Command: %v, Message: %v)", + error, + command, + errorMessage); + + if (!IsRetriableErrorCode(error, idempotent) || NProfiling::GetInstant() - startTime > Config_->RetriesTimeout) { + THROW_ERROR CreatePortoError(error, errorMessage); + } + } + + void HandleResult(const TString& container, const Porto::TGetResponse::TContainerGetValueResponse& rsp) + { + auto portoErrorCode = ConvertPortoErrorCode(rsp.error()); + auto it = ContainerMap_.find(container); + if (it == ContainerMap_.end()) { + YT_LOG_ERROR("Got an unexpected container " + "(Container: %v, ResponseError: %v, ErrorMessage: %v, Value: %v)", + container, + portoErrorCode, + rsp.errormsg(), + rsp.value()); + return; + } else { + if (portoErrorCode != EPortoErrorCode::Success) { + YT_LOG_ERROR("Container finished with Porto API error " + "(Container: %v, ResponseError: %v, ErrorMessage: %v, Value: %v)", + container, + portoErrorCode, + rsp.errormsg(), + rsp.value()); + it->second.Set(CreatePortoError(portoErrorCode, rsp.errormsg())); + } else { + try { + int exitStatus = std::stoi(rsp.value()); + YT_LOG_DEBUG("Container finished with exit code (Container: %v, ExitCode: %v)", + container, + exitStatus); + + it->second.Set(exitStatus); + } catch (const std::exception& ex) { + it->second.Set(TError("Failed to parse Porto exit status") << ex); + } + } + } + RemoveFromPoller(container); + } + + void RemoveFromPoller(const TString& container) + { + ContainerMap_.erase(container); + + Containers_.clear(); + for (const auto& [name, pid] : ContainerMap_) { + Containers_.push_back(name); + } + } +}; + +const std::vector<TString> TPortoExecutor::ContainerRequestVars_ = { + "state", + "exit_status" +}; + +//////////////////////////////////////////////////////////////////////////////// + +IPortoExecutorPtr CreatePortoExecutor( + TPortoExecutorDynamicConfigPtr config, + const TString& threadNameSuffix, + const 
NProfiling::TProfiler& profiler) +{ + return New<TPortoExecutor>( + std::move(config), + threadNameSuffix, + profiler); +} + +//////////////////////////////////////////////////////////////////////////////// + +#else + +IPortoExecutorPtr CreatePortoExecutor( + TPortoExecutorDynamicConfigPtr /* config */, + const TString& /* threadNameSuffix */, + const NProfiling::TProfiler& /* profiler */) +{ + THROW_ERROR_EXCEPTION("Porto executor is not available on this platform"); +} + +#endif + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_executor.h b/yt/yt/library/containers/porto_executor.h new file mode 100644 index 0000000000..d629ab6275 --- /dev/null +++ b/yt/yt/library/containers/porto_executor.h @@ -0,0 +1,142 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/profiling/sensor.h> + +#include <yt/yt/core/actions/future.h> +#include <yt/yt/core/actions/signal.h> + +#include <yt/yt/core/net/address.h> + +#include <library/cpp/porto/libporto.hpp> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +struct TVolumeId +{ + TString Path; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TRunnableContainerSpec +{ + TString Name; + TString Command; + + EEnablePorto EnablePorto = EEnablePorto::None; + bool Isolate = true; + + std::optional<TString> StdinPath; + std::optional<TString> StdoutPath; + std::optional<TString> StderrPath; + std::optional<TString> CurrentWorkingDirectory; + std::optional<TString> CoreCommand; + std::optional<TString> User; + std::optional<int> GroupId; + + bool EnableCoreDumps = true; + + std::optional<i64> ThreadLimit; + + std::optional<TString> HostName; + std::vector<NYT::NNet::TIP6Address> IPAddresses; + bool EnableNat64 = false; + bool DisableNetwork = false; + + THashMap<TString, TString> Labels; + 
THashMap<TString, TString> Env; + std::vector<TString> CGroupControllers; + std::vector<TDevice> Devices; + std::optional<TRootFS> RootFS; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct IPortoExecutor + : public TRefCounted +{ + virtual void OnDynamicConfigChanged(const TPortoExecutorDynamicConfigPtr& newConfig) = 0; + + virtual TFuture<void> CreateContainer(const TString& container) = 0; + + virtual TFuture<void> CreateContainer(const TRunnableContainerSpec& containerSpec, bool start) = 0; + + virtual TFuture<void> SetContainerProperty( + const TString& container, + const TString& property, + const TString& value) = 0; + + virtual TFuture<std::optional<TString>> GetContainerProperty( + const TString& container, + const TString& property) = 0; + + virtual TFuture<THashMap<TString, TErrorOr<TString>>> GetContainerProperties( + const TString& container, + const std::vector<TString>& properties) = 0; + virtual TFuture<THashMap<TString, THashMap<TString, TErrorOr<TString>>>> GetContainerProperties( + const std::vector<TString>& containers, + const std::vector<TString>& properties) = 0; + + virtual TFuture<THashMap<TString, i64>> GetContainerMetrics( + const std::vector<TString>& containers, + const TString& metric) = 0; + virtual TFuture<void> DestroyContainer(const TString& container) = 0; + virtual TFuture<void> StopContainer(const TString& container) = 0; + virtual TFuture<void> StartContainer(const TString& container) = 0; + virtual TFuture<void> KillContainer(const TString& container, int signal) = 0; + + virtual TFuture<TString> ConvertPath(const TString& path, const TString& container) = 0; + + // Returns absolute names of immediate children only. + virtual TFuture<std::vector<TString>> ListSubcontainers( + const TString& rootContainer, + bool includeRoot) = 0; + // Starts polling a given container, returns future with exit code of finished process. 
+ virtual TFuture<int> PollContainer(const TString& container) = 0; + + // Returns future with exit code of finished process. + // NB: temporarily broken, see https://st.yandex-team.ru/PORTO-846 for details. + virtual TFuture<int> WaitContainer(const TString& container) = 0; + + virtual TFuture<TString> CreateVolume( + const TString& path, + const THashMap<TString, TString>& properties) = 0; + virtual TFuture<void> LinkVolume( + const TString& path, + const TString& name) = 0; + virtual TFuture<void> UnlinkVolume( + const TString& path, + const TString& name) = 0; + virtual TFuture<std::vector<TString>> ListVolumePaths() = 0; + + virtual TFuture<void> ImportLayer( + const TString& archivePath, + const TString& layerId, + const TString& place) = 0; + virtual TFuture<void> RemoveLayer( + const TString& layerId, + const TString& place, + bool async) = 0; + virtual TFuture<std::vector<TString>> ListLayers(const TString& place) = 0; + + virtual IInvokerPtr GetInvoker() const = 0; + + DECLARE_INTERFACE_SIGNAL(void(const TError&), Failed); +}; + +DEFINE_REFCOUNTED_TYPE(IPortoExecutor) + +//////////////////////////////////////////////////////////////////////////////// + +IPortoExecutorPtr CreatePortoExecutor( + TPortoExecutorDynamicConfigPtr config, + const TString& threadNameSuffix, + const NProfiling::TProfiler& profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_health_checker.cpp b/yt/yt/library/containers/porto_health_checker.cpp new file mode 100644 index 0000000000..1a86fef841 --- /dev/null +++ b/yt/yt/library/containers/porto_health_checker.cpp @@ -0,0 +1,69 @@ + +#include "porto_health_checker.h" + +#include "porto_executor.h" +#include "private.h" +#include "config.h" + +#include <yt/yt/core/actions/future.h> + +#include <yt/yt/core/misc/fs.h> + +#include <util/random/random.h> + +namespace NYT::NContainers { + +using namespace 
NConcurrency; +using namespace NLogging; +using namespace NProfiling; + +//////////////////////////////////////////////////////////////////////////////// + +TPortoHealthChecker::TPortoHealthChecker( + TPortoExecutorDynamicConfigPtr config, + IInvokerPtr invoker, + TLogger logger) + : Config_(std::move(config)) + , Logger(std::move(logger)) + , CheckInvoker_(std::move(invoker)) + , Executor_(CreatePortoExecutor( + Config_, + "porto_check")) +{ } + +void TPortoHealthChecker::Start() +{ + YT_LOG_DEBUG("Porto health checker started"); + + PeriodicExecutor_ = New<TPeriodicExecutor>( + CheckInvoker_, + BIND(&TPortoHealthChecker::OnCheck, MakeWeak(this)), + Config_->RetriesTimeout); + PeriodicExecutor_->Start(); +} + +void TPortoHealthChecker::OnDynamicConfigChanged(const TPortoExecutorDynamicConfigPtr& newConfig) +{ + YT_LOG_DEBUG( + "Porto health checker dynamic config changed (EnableTestPortoFailures: %v, StubErrorCode: %v)", + Config_->EnableTestPortoFailures, + Config_->StubErrorCode); + + Executor_->OnDynamicConfigChanged(newConfig); +} + +void TPortoHealthChecker::OnCheck() +{ + YT_LOG_DEBUG("Run porto health check"); + + auto result = WaitFor(Executor_->ListVolumePaths().AsVoid()); + if (result.IsOK()) { + Success_.Fire(); + } else { + Failed_.Fire(result); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_health_checker.h b/yt/yt/library/containers/porto_health_checker.h new file mode 100644 index 0000000000..f0fb8f0908 --- /dev/null +++ b/yt/yt/library/containers/porto_health_checker.h @@ -0,0 +1,52 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/profiling/sensor.h> + +#include <yt/yt/core/actions/signal.h> + +#include <yt/yt/core/concurrency/periodic_executor.h> + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/error.h> + +#include <atomic> + +namespace NYT::NContainers { + 
+//////////////////////////////////////////////////////////////////////////////// + +class TPortoHealthChecker + : public TRefCounted +{ +public: + TPortoHealthChecker( + TPortoExecutorDynamicConfigPtr config, + IInvokerPtr invoker, + NLogging::TLogger logger); + + void Start(); + + void OnDynamicConfigChanged(const TPortoExecutorDynamicConfigPtr& newConfig); + + DEFINE_SIGNAL(void(), Success); + + DEFINE_SIGNAL(void(const TError&), Failed); + +private: + const TPortoExecutorDynamicConfigPtr Config_; + const NLogging::TLogger Logger; + const IInvokerPtr CheckInvoker_; + const IPortoExecutorPtr Executor_; + NConcurrency::TPeriodicExecutorPtr PeriodicExecutor_; + + void OnCheck(); +}; + +DEFINE_REFCOUNTED_TYPE(TPortoHealthChecker) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_resource_tracker.cpp b/yt/yt/library/containers/porto_resource_tracker.cpp new file mode 100644 index 0000000000..871379514c --- /dev/null +++ b/yt/yt/library/containers/porto_resource_tracker.cpp @@ -0,0 +1,711 @@ +#include "porto_resource_tracker.h" +#include "private.h" + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/misc/error.h> + +#include <yt/yt/core/net/address.h> + +#include <yt/yt/core/ytree/public.h> + +#include <yt/yt/library/process/process.h> + +#include <yt/yt/library/containers/cgroup.h> +#include <yt/yt/library/containers/config.h> +#include <yt/yt/library/containers/instance.h> +#include <yt/yt/library/containers/porto_executor.h> +#include <yt/yt/library/containers/public.h> + +namespace NYT::NContainers { + +using namespace NProfiling; + +static const auto& Logger = ContainersLogger; + +#ifdef _linux_ + +//////////////////////////////////////////////////////////////////////////////// + +struct TPortoProfilers + : public TRefCounted +{ + TPortoResourceProfilerPtr DaemonProfiler; + TPortoResourceProfilerPtr ContainerProfiler; + + TPortoProfilers( + 
TPortoResourceProfilerPtr daemonProfiler, + TPortoResourceProfilerPtr containerProfiler) + : DaemonProfiler(std::move(daemonProfiler)) + , ContainerProfiler(std::move(containerProfiler)) + { } +}; + +DEFINE_REFCOUNTED_TYPE(TPortoProfilers) + +//////////////////////////////////////////////////////////////////////////////// + +static TErrorOr<ui64> GetFieldOrError( + const TResourceUsage& usage, + EStatField field) +{ + auto it = usage.find(field); + if (it == usage.end()) { + return TError("Resource usage is missing %Qlv field", field); + } + const auto& errorOrValue = it->second; + if (errorOrValue.FindMatching(EPortoErrorCode::NotSupported)) { + return TError("Property %Qlv not supported in Porto response", field); + } + return errorOrValue; +} + +//////////////////////////////////////////////////////////////////////////////// + +TPortoResourceTracker::TPortoResourceTracker( + IInstancePtr instance, + TDuration updatePeriod, + bool isDeltaTracker, + bool isForceUpdate) + : Instance_(std::move(instance)) + , UpdatePeriod_(updatePeriod) + , IsDeltaTracker_(isDeltaTracker) + , IsForceUpdate_(isForceUpdate) +{ + ResourceUsage_ = { + {EStatField::IOReadByte, 0}, + {EStatField::IOWriteByte, 0}, + {EStatField::IOBytesLimit, 0}, + {EStatField::IOReadOps, 0}, + {EStatField::IOWriteOps, 0}, + {EStatField::IOOps, 0}, + {EStatField::IOOpsLimit, 0}, + {EStatField::IOTotalTime, 0}, + {EStatField::IOWaitTime, 0} + }; + ResourceUsageDelta_ = ResourceUsage_; +} + +static TErrorOr<TDuration> ExtractDuration(TErrorOr<ui64> timeNs) +{ + if (timeNs.IsOK()) { + return TErrorOr<TDuration>(TDuration::MicroSeconds(timeNs.Value() / 1000)); + } else { + return TError(timeNs); + } +} + +TCpuStatistics TPortoResourceTracker::ExtractCpuStatistics(const TResourceUsage& resourceUsage) const +{ + // NB: Job proxy uses last sample of CPU statistics but we are interested in + // peak thread count value. 
+ auto currentThreadCountPeak = GetFieldOrError(resourceUsage, EStatField::ThreadCount); + + PeakThreadCount_ = currentThreadCountPeak.IsOK() && PeakThreadCount_.IsOK() + ? std::max<ui64>( + PeakThreadCount_.Value(), + currentThreadCountPeak.Value()) + : currentThreadCountPeak.IsOK() ? currentThreadCountPeak : PeakThreadCount_; + + auto totalTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuUsage); + auto systemTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuSystemUsage); + auto userTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuUserUsage); + auto waitTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuWait); + auto throttledNs = GetFieldOrError(resourceUsage, EStatField::CpuThrottled); + auto limitTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuLimit); + auto guaranteeTimeNs = GetFieldOrError(resourceUsage, EStatField::CpuGuarantee); + + return TCpuStatistics{ + .TotalUsageTime = ExtractDuration(totalTimeNs), + .UserUsageTime = ExtractDuration(userTimeNs), + .SystemUsageTime = ExtractDuration(systemTimeNs), + .WaitTime = ExtractDuration(waitTimeNs), + .ThrottledTime = ExtractDuration(throttledNs), + .ThreadCount = GetFieldOrError(resourceUsage, EStatField::ThreadCount), + .ContextSwitches = GetFieldOrError(resourceUsage, EStatField::ContextSwitches), + .ContextSwitchesDelta = GetFieldOrError(resourceUsage, EStatField::ContextSwitchesDelta), + .PeakThreadCount = PeakThreadCount_, + .LimitTime = ExtractDuration(limitTimeNs), + .GuaranteeTime = ExtractDuration(guaranteeTimeNs), + }; +} + +TMemoryStatistics TPortoResourceTracker::ExtractMemoryStatistics(const TResourceUsage& resourceUsage) const +{ + return TMemoryStatistics{ + .Rss = GetFieldOrError(resourceUsage, EStatField::Rss), + .MappedFile = GetFieldOrError(resourceUsage, EStatField::MappedFile), + .MinorPageFaults = GetFieldOrError(resourceUsage, EStatField::MinorPageFaults), + .MajorPageFaults = GetFieldOrError(resourceUsage, EStatField::MajorPageFaults), + .FileCacheUsage = 
GetFieldOrError(resourceUsage, EStatField::FileCacheUsage), + .AnonUsage = GetFieldOrError(resourceUsage, EStatField::AnonMemoryUsage), + .AnonLimit = GetFieldOrError(resourceUsage, EStatField::AnonMemoryLimit), + .MemoryUsage = GetFieldOrError(resourceUsage, EStatField::MemoryUsage), + .MemoryGuarantee = GetFieldOrError(resourceUsage, EStatField::MemoryGuarantee), + .MemoryLimit = GetFieldOrError(resourceUsage, EStatField::MemoryLimit), + .MaxMemoryUsage = GetFieldOrError(resourceUsage, EStatField::MaxMemoryUsage), + .OomKills = GetFieldOrError(resourceUsage, EStatField::OomKills), + .OomKillsTotal = GetFieldOrError(resourceUsage, EStatField::OomKillsTotal) + }; +} + +TBlockIOStatistics TPortoResourceTracker::ExtractBlockIOStatistics(const TResourceUsage& resourceUsage) const +{ + auto totalTimeNs = GetFieldOrError(resourceUsage, EStatField::IOTotalTime); + auto waitTimeNs = GetFieldOrError(resourceUsage, EStatField::IOWaitTime); + + return TBlockIOStatistics{ + .IOReadByte = GetFieldOrError(resourceUsage, EStatField::IOReadByte), + .IOWriteByte = GetFieldOrError(resourceUsage, EStatField::IOWriteByte), + .IOBytesLimit = GetFieldOrError(resourceUsage, EStatField::IOBytesLimit), + .IOReadOps = GetFieldOrError(resourceUsage, EStatField::IOReadOps), + .IOWriteOps = GetFieldOrError(resourceUsage, EStatField::IOWriteOps), + .IOOps = GetFieldOrError(resourceUsage, EStatField::IOOps), + .IOOpsLimit = GetFieldOrError(resourceUsage, EStatField::IOOpsLimit), + .IOTotalTime = ExtractDuration(totalTimeNs), + .IOWaitTime = ExtractDuration(waitTimeNs) + }; +} + +TNetworkStatistics TPortoResourceTracker::ExtractNetworkStatistics(const TResourceUsage& resourceUsage) const +{ + return TNetworkStatistics{ + .TxBytes = GetFieldOrError(resourceUsage, EStatField::NetTxBytes), + .TxPackets = GetFieldOrError(resourceUsage, EStatField::NetTxPackets), + .TxDrops = GetFieldOrError(resourceUsage, EStatField::NetTxDrops), + .TxLimit = GetFieldOrError(resourceUsage, EStatField::NetTxLimit), + 
+ .RxBytes = GetFieldOrError(resourceUsage, EStatField::NetRxBytes), + .RxPackets = GetFieldOrError(resourceUsage, EStatField::NetRxPackets), + .RxDrops = GetFieldOrError(resourceUsage, EStatField::NetRxDrops), + .RxLimit = GetFieldOrError(resourceUsage, EStatField::NetRxLimit), + }; +} + +TTotalStatistics TPortoResourceTracker::ExtractTotalStatistics(const TResourceUsage& resourceUsage) const +{ + return TTotalStatistics{ + .CpuStatistics = ExtractCpuStatistics(resourceUsage), + .MemoryStatistics = ExtractMemoryStatistics(resourceUsage), + .BlockIOStatistics = ExtractBlockIOStatistics(resourceUsage), + .NetworkStatistics = ExtractNetworkStatistics(resourceUsage), + }; +} + +TCpuStatistics TPortoResourceTracker::GetCpuStatistics() const +{ + return GetStatistics( + CachedCpuStatistics_, + "CPU", + [&] (TResourceUsage& resourceUsage) { + return ExtractCpuStatistics(resourceUsage); + }); +} + +TMemoryStatistics TPortoResourceTracker::GetMemoryStatistics() const +{ + return GetStatistics( + CachedMemoryStatistics_, + "memory", + [&] (TResourceUsage& resourceUsage) { + return ExtractMemoryStatistics(resourceUsage); + }); +} + +TBlockIOStatistics TPortoResourceTracker::GetBlockIOStatistics() const +{ + return GetStatistics( + CachedBlockIOStatistics_, + "block IO", + [&] (TResourceUsage& resourceUsage) { + return ExtractBlockIOStatistics(resourceUsage); + }); +} + +TNetworkStatistics TPortoResourceTracker::GetNetworkStatistics() const +{ + return GetStatistics( + CachedNetworkStatistics_, + "network", + [&] (TResourceUsage& resourceUsage) { + return ExtractNetworkStatistics(resourceUsage); + }); +} + +TTotalStatistics TPortoResourceTracker::GetTotalStatistics() const +{ + return GetStatistics( + CachedTotalStatistics_, + "total", + [&] (TResourceUsage& resourceUsage) { + return ExtractTotalStatistics(resourceUsage); + }); +} + +template <class T, class F> +T TPortoResourceTracker::GetStatistics( + std::optional<T>& cachedStatistics, + const TString& statisticsKind, + F 
extractor) const +{ + UpdateResourceUsageStatisticsIfExpired(); + + auto guard = Guard(SpinLock_); + try { + auto newStatistics = extractor(IsDeltaTracker_ ? ResourceUsageDelta_ : ResourceUsage_); + cachedStatistics = newStatistics; + return newStatistics; + } catch (const std::exception& ex) { + if (!cachedStatistics) { + THROW_ERROR_EXCEPTION("Unable to get %v statistics", statisticsKind) + << ex; + } + YT_LOG_WARNING(ex, "Unable to get %v statistics; using the last one", statisticsKind); + return *cachedStatistics; + } +} + +bool TPortoResourceTracker::AreResourceUsageStatisticsExpired() const +{ + return TInstant::Now() - LastUpdateTime_.load() > UpdatePeriod_; +} + +TInstant TPortoResourceTracker::GetLastUpdateTime() const +{ + return LastUpdateTime_.load(); +} + +void TPortoResourceTracker::UpdateResourceUsageStatisticsIfExpired() const +{ + if (IsForceUpdate_ || AreResourceUsageStatisticsExpired()) { + DoUpdateResourceUsage(); + } +} + +TErrorOr<ui64> TPortoResourceTracker::CalculateCounterDelta( + const TErrorOr<ui64>& oldValue, + const TErrorOr<ui64>& newValue) const +{ + if (oldValue.IsOK() && newValue.IsOK()) { + return newValue.Value() - oldValue.Value(); + } else if (newValue.IsOK()) { + // It is better to return an error than an incorrect value. 
+ return oldValue; + } else { + return newValue; + } +} + +static bool IsCumulativeStatistics(EStatField statistic) +{ + return + statistic == EStatField::CpuUsage || + statistic == EStatField::CpuUserUsage || + statistic == EStatField::CpuSystemUsage || + statistic == EStatField::CpuWait || + statistic == EStatField::CpuThrottled || + + statistic == EStatField::ContextSwitches || + + statistic == EStatField::MinorPageFaults || + statistic == EStatField::MajorPageFaults || + statistic == EStatField::OomKills || + statistic == EStatField::OomKillsTotal || + + statistic == EStatField::IOReadByte || + statistic == EStatField::IOWriteByte || + statistic == EStatField::IOReadOps || + statistic == EStatField::IOWriteOps || + statistic == EStatField::IOOps || + statistic == EStatField::IOTotalTime || + statistic == EStatField::IOWaitTime || + + statistic == EStatField::NetTxBytes || + statistic == EStatField::NetTxPackets || + statistic == EStatField::NetTxDrops || + statistic == EStatField::NetRxBytes || + statistic == EStatField::NetRxPackets || + statistic == EStatField::NetRxDrops; +} + +void TPortoResourceTracker::ReCalculateResourceUsage(const TResourceUsage& newResourceUsage) const +{ + auto guard = Guard(SpinLock_); + + TResourceUsage resourceUsage; + TResourceUsage resourceUsageDelta; + + for (const auto& stat : InstanceStatFields) { + TErrorOr<ui64> oldValue; + TErrorOr<ui64> newValue; + + if (auto newValueIt = newResourceUsage.find(stat); newValueIt.IsEnd()) { + newValue = TError("Missing property %Qlv in Porto response", stat) + << TErrorAttribute("container", Instance_->GetName()); + } else { + newValue = newValueIt->second; + } + + if (auto oldValueIt = ResourceUsage_.find(stat); oldValueIt.IsEnd()) { + oldValue = newValue; + } else { + oldValue = oldValueIt->second; + } + + if (newValue.IsOK()) { + resourceUsage[stat] = newValue; + } else { + resourceUsage[stat] = oldValue; + } + + if (IsCumulativeStatistics(stat)) { + resourceUsageDelta[stat] = 
CalculateCounterDelta(oldValue, newValue); + } else { + if (newValue.IsOK()) { + resourceUsageDelta[stat] = newValue; + } else { + resourceUsageDelta[stat] = oldValue; + } + } + } + + ResourceUsage_ = resourceUsage; + ResourceUsageDelta_ = resourceUsageDelta; + LastUpdateTime_.store(TInstant::Now()); +} + +void TPortoResourceTracker::DoUpdateResourceUsage() const +{ + try { + ReCalculateResourceUsage(Instance_->GetResourceUsage()); + } catch (const std::exception& ex) { + YT_LOG_ERROR( + ex, + "Couldn't get metrics from porto"); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +TPortoResourceProfiler::TPortoResourceProfiler( + TPortoResourceTrackerPtr tracker, + TPodSpecConfigPtr podSpec, + const TProfiler& profiler) + : ResourceTracker_(std::move(tracker)) + , PodSpec_(std::move(podSpec)) +{ + profiler.AddProducer("", MakeStrong(this)); +} + +static void WriteGaugeIfOk( + ISensorWriter* writer, + const TString& path, + TErrorOr<ui64> valueOrError) +{ + if (valueOrError.IsOK()) { + i64 value = static_cast<i64>(valueOrError.Value()); + + if (value >= 0) { + writer->AddGauge(path, value); + } + } +} + +static void WriteCumulativeGaugeIfOk( + ISensorWriter* writer, + const TString& path, + TErrorOr<ui64> valueOrError, + i64 timeDeltaUsec) +{ + if (valueOrError.IsOK()) { + i64 value = static_cast<i64>(valueOrError.Value()); + + if (value >= 0) { + writer->AddGauge(path, + 1.0 * value * ResourceUsageUpdatePeriod.MicroSeconds() / timeDeltaUsec); + } + } +} + +void TPortoResourceProfiler::WriteCpuMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec) +{ + { + if (totalStatistics.CpuStatistics.UserUsageTime.IsOK()) { + i64 userUsageTimeUs = totalStatistics.CpuStatistics.UserUsageTime.Value().MicroSeconds(); + double userUsagePercent = std::max<double>(0.0, 100. 
* userUsageTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/user", userUsagePercent); + } + + if (totalStatistics.CpuStatistics.SystemUsageTime.IsOK()) { + i64 systemUsageTimeUs = totalStatistics.CpuStatistics.SystemUsageTime.Value().MicroSeconds(); + double systemUsagePercent = std::max<double>(0.0, 100. * systemUsageTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/system", systemUsagePercent); + } + + if (totalStatistics.CpuStatistics.WaitTime.IsOK()) { + i64 waitTimeUs = totalStatistics.CpuStatistics.WaitTime.Value().MicroSeconds(); + double waitPercent = std::max<double>(0.0, 100. * waitTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/wait", waitPercent); + } + + if (totalStatistics.CpuStatistics.ThrottledTime.IsOK()) { + i64 throttledTimeUs = totalStatistics.CpuStatistics.ThrottledTime.Value().MicroSeconds(); + double throttledPercent = std::max<double>(0.0, 100. * throttledTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/throttled", throttledPercent); + } + + if (totalStatistics.CpuStatistics.TotalUsageTime.IsOK()) { + i64 totalUsageTimeUs = totalStatistics.CpuStatistics.TotalUsageTime.Value().MicroSeconds(); + double totalUsagePercent = std::max<double>(0.0, 100. * totalUsageTimeUs / timeDeltaUsec); + writer->AddGauge("/cpu/total", totalUsagePercent); + } + + if (totalStatistics.CpuStatistics.GuaranteeTime.IsOK()) { + i64 guaranteeTimeUs = totalStatistics.CpuStatistics.GuaranteeTime.Value().MicroSeconds(); + double guaranteePercent = std::max<double>(0.0, (100. * guaranteeTimeUs) / (1'000'000L)); + writer->AddGauge("/cpu/guarantee", guaranteePercent); + } + + if (totalStatistics.CpuStatistics.LimitTime.IsOK()) { + i64 limitTimeUs = totalStatistics.CpuStatistics.LimitTime.Value().MicroSeconds(); + double limitPercent = std::max<double>(0.0, (100. 
* limitTimeUs) / (1'000'000L)); + writer->AddGauge("/cpu/limit", limitPercent); + } + } + + if (PodSpec_->CpuToVCpuFactor) { + auto factor = *PodSpec_->CpuToVCpuFactor; + + writer->AddGauge("/cpu_to_vcpu_factor", factor); + + if (totalStatistics.CpuStatistics.UserUsageTime.IsOK()) { + i64 userUsageTimeUs = totalStatistics.CpuStatistics.UserUsageTime.Value().MicroSeconds(); + double userUsagePercent = std::max<double>(0.0, 100. * userUsageTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/user", userUsagePercent); + } + + if (totalStatistics.CpuStatistics.SystemUsageTime.IsOK()) { + i64 systemUsageTimeUs = totalStatistics.CpuStatistics.SystemUsageTime.Value().MicroSeconds(); + double systemUsagePercent = std::max<double>(0.0, 100. * systemUsageTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/system", systemUsagePercent); + } + + if (totalStatistics.CpuStatistics.WaitTime.IsOK()) { + i64 waitTimeUs = totalStatistics.CpuStatistics.WaitTime.Value().MicroSeconds(); + double waitPercent = std::max<double>(0.0, 100. * waitTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/wait", waitPercent); + } + + if (totalStatistics.CpuStatistics.ThrottledTime.IsOK()) { + i64 throttledTimeUs = totalStatistics.CpuStatistics.ThrottledTime.Value().MicroSeconds(); + double throttledPercent = std::max<double>(0.0, 100. * throttledTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/throttled", throttledPercent); + } + + if (totalStatistics.CpuStatistics.TotalUsageTime.IsOK()) { + i64 totalUsageTimeUs = totalStatistics.CpuStatistics.TotalUsageTime.Value().MicroSeconds(); + double totalUsagePercent = std::max<double>(0.0, 100. * totalUsageTimeUs * factor / timeDeltaUsec); + writer->AddGauge("/vcpu/total", totalUsagePercent); + } + + if (totalStatistics.CpuStatistics.GuaranteeTime.IsOK()) { + i64 guaranteeTimeUs = totalStatistics.CpuStatistics.GuaranteeTime.Value().MicroSeconds(); + double guaranteePercent = std::max<double>(0.0, 100. 
* guaranteeTimeUs * factor / 1'000'000L); + writer->AddGauge("/vcpu/guarantee", guaranteePercent); + } + + if (totalStatistics.CpuStatistics.LimitTime.IsOK()) { + i64 limitTimeUs = totalStatistics.CpuStatistics.LimitTime.Value().MicroSeconds(); + double limitPercent = std::max<double>(0.0, 100. * limitTimeUs * factor / 1'000'000L); + writer->AddGauge("/vcpu/limit", limitPercent); + } + } + + WriteGaugeIfOk(writer, "/cpu/thread_count", totalStatistics.CpuStatistics.ThreadCount); + WriteGaugeIfOk(writer, "/cpu/context_switches", totalStatistics.CpuStatistics.ContextSwitches); +} + +void TPortoResourceProfiler::WriteMemoryMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec) +{ + WriteCumulativeGaugeIfOk(writer, + "/memory/minor_page_faults", + totalStatistics.MemoryStatistics.MinorPageFaults, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/memory/major_page_faults", + totalStatistics.MemoryStatistics.MajorPageFaults, + timeDeltaUsec); + + WriteGaugeIfOk(writer, "/memory/oom_kills", totalStatistics.MemoryStatistics.OomKills); + WriteGaugeIfOk(writer, "/memory/oom_kills_total", totalStatistics.MemoryStatistics.OomKillsTotal); + + WriteGaugeIfOk(writer, "/memory/file_cache_usage", totalStatistics.MemoryStatistics.FileCacheUsage); + WriteGaugeIfOk(writer, "/memory/anon_usage", totalStatistics.MemoryStatistics.AnonUsage); + WriteGaugeIfOk(writer, "/memory/anon_limit", totalStatistics.MemoryStatistics.AnonLimit); + WriteGaugeIfOk(writer, "/memory/memory_usage", totalStatistics.MemoryStatistics.MemoryUsage); + WriteGaugeIfOk(writer, "/memory/memory_guarantee", totalStatistics.MemoryStatistics.MemoryGuarantee); + WriteGaugeIfOk(writer, "/memory/memory_limit", totalStatistics.MemoryStatistics.MemoryLimit); +} + +void TPortoResourceProfiler::WriteBlockingIOMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec) +{ + WriteCumulativeGaugeIfOk(writer, + "/io/read_bytes", + 
totalStatistics.BlockIOStatistics.IOReadByte, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/io/write_bytes", + totalStatistics.BlockIOStatistics.IOWriteByte, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/io/read_ops", + totalStatistics.BlockIOStatistics.IOReadOps, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/io/write_ops", + totalStatistics.BlockIOStatistics.IOWriteOps, + timeDeltaUsec); + WriteCumulativeGaugeIfOk(writer, + "/io/ops", + totalStatistics.BlockIOStatistics.IOOps, + timeDeltaUsec); + + WriteGaugeIfOk(writer, + "/io/bytes_limit", + totalStatistics.BlockIOStatistics.IOBytesLimit); + WriteGaugeIfOk(writer, + "/io/ops_limit", + totalStatistics.BlockIOStatistics.IOOpsLimit); + + if (totalStatistics.BlockIOStatistics.IOTotalTime.IsOK()) { + i64 totalTimeUs = totalStatistics.BlockIOStatistics.IOTotalTime.Value().MicroSeconds(); + double totalPercent = std::max<double>(0.0, 100. * totalTimeUs / timeDeltaUsec); + writer->AddGauge("/io/total", totalPercent); + } + + if (totalStatistics.BlockIOStatistics.IOWaitTime.IsOK()) { + i64 waitTimeUs = totalStatistics.BlockIOStatistics.IOWaitTime.Value().MicroSeconds(); + double waitPercent = std::max<double>(0.0, 100. 
* waitTimeUs / timeDeltaUsec); + writer->AddGauge("/io/wait", waitPercent); + } +} + +void TPortoResourceProfiler::WriteNetworkMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec) +{ + WriteCumulativeGaugeIfOk( + writer, + "/network/rx_bytes", + totalStatistics.NetworkStatistics.RxBytes, + timeDeltaUsec); + WriteCumulativeGaugeIfOk( + writer, + "/network/rx_drops", + totalStatistics.NetworkStatistics.RxDrops, + timeDeltaUsec); + WriteCumulativeGaugeIfOk( + writer, + "/network/rx_packets", + totalStatistics.NetworkStatistics.RxPackets, + timeDeltaUsec); + WriteGaugeIfOk( + writer, + "/network/rx_limit", + totalStatistics.NetworkStatistics.RxLimit); + + WriteCumulativeGaugeIfOk( + writer, + "/network/tx_bytes", + totalStatistics.NetworkStatistics.TxBytes, + timeDeltaUsec); + WriteCumulativeGaugeIfOk( + writer, + "/network/tx_drops", + totalStatistics.NetworkStatistics.TxDrops, + timeDeltaUsec); + WriteCumulativeGaugeIfOk( + writer, + "/network/tx_packets", + totalStatistics.NetworkStatistics.TxPackets, + timeDeltaUsec); + WriteGaugeIfOk( + writer, + "/network/tx_limit", + totalStatistics.NetworkStatistics.TxLimit); +} + +void TPortoResourceProfiler::CollectSensors(ISensorWriter* writer) +{ + i64 lastUpdate = ResourceTracker_->GetLastUpdateTime().MicroSeconds(); + + auto totalStatistics = ResourceTracker_->GetTotalStatistics(); + i64 timeDeltaUsec = TInstant::Now().MicroSeconds() - lastUpdate; + + WriteCpuMetrics(writer, totalStatistics, timeDeltaUsec); + WriteMemoryMetrics(writer, totalStatistics, timeDeltaUsec); + WriteBlockingIOMetrics(writer, totalStatistics, timeDeltaUsec); + WriteNetworkMetrics(writer, totalStatistics, timeDeltaUsec); +} + +//////////////////////////////////////////////////////////////////////////////// + +TPortoResourceProfilerPtr CreatePortoProfilerWithTags( + const IInstancePtr& instance, + const TString containerCategory, + const TPodSpecConfigPtr& podSpec) +{ + auto portoResourceTracker = 
New<TPortoResourceTracker>( + instance, + ResourceUsageUpdatePeriod, + true, + true); + + return New<TPortoResourceProfiler>( + portoResourceTracker, + podSpec, + TProfiler("/porto") + .WithTag("container_category", containerCategory)); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif + +#ifdef __linux__ +void EnablePortoResourceTracker(const TPodSpecConfigPtr& podSpec) +{ + try { + auto executor = CreatePortoExecutor(New<TPortoExecutorDynamicConfig>(), "porto-tracker"); + + executor->SubscribeFailed(BIND([=] (const TError& error) { + YT_LOG_ERROR(error, "Fatal error during Porto polling"); + })); + + LeakyRefCountedSingleton<TPortoProfilers>( + CreatePortoProfilerWithTags(GetSelfPortoInstance(executor), "daemon", podSpec), + CreatePortoProfilerWithTags(GetRootPortoInstance(executor), "pod", podSpec)); + } catch(const std::exception& exception) { + YT_LOG_ERROR(exception, "Failed to enable porto profiler"); + } +} +#else +void EnablePortoResourceTracker(const TPodSpecConfigPtr& /*podSpec*/) +{ + YT_LOG_WARNING("Porto resource tracker not supported"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/porto_resource_tracker.h b/yt/yt/library/containers/porto_resource_tracker.h new file mode 100644 index 0000000000..8a0f781949 --- /dev/null +++ b/yt/yt/library/containers/porto_resource_tracker.h @@ -0,0 +1,158 @@ +#pragma once + +#include <yt/yt/library/containers/instance.h> +#include <yt/yt/library/containers/public.h> + +#include <yt/yt/library/containers/cgroup.h> + +#include <yt/yt/core/misc/singleton.h> +#include <yt/yt/core/net/address.h> +#include <yt/yt/core/ytree/public.h> + +#include <yt/yt/library/process/process.h> +#include <yt/yt/library/profiling/producer.h> + +namespace NYT::NContainers { + +using namespace NProfiling; + 
+//////////////////////////////////////////////////////////////////////////////// + +static constexpr auto ResourceUsageUpdatePeriod = TDuration::MilliSeconds(1000); + +//////////////////////////////////////////////////////////////////////////////// + +using TCpuStatistics = TCpuAccounting::TStatistics; +using TBlockIOStatistics = TBlockIO::TStatistics; +using TMemoryStatistics = TMemory::TStatistics; +using TNetworkStatistics = TNetwork::TStatistics; + +struct TTotalStatistics +{ +public: + TCpuStatistics CpuStatistics; + TMemoryStatistics MemoryStatistics; + TBlockIOStatistics BlockIOStatistics; + TNetworkStatistics NetworkStatistics; +}; + +#ifdef _linux_ + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoResourceTracker + : public TRefCounted +{ +public: + TPortoResourceTracker( + IInstancePtr instance, + TDuration updatePeriod, + bool isDeltaTracker = false, + bool isForceUpdate = false); + + TCpuStatistics GetCpuStatistics() const; + + TBlockIOStatistics GetBlockIOStatistics() const; + + TMemoryStatistics GetMemoryStatistics() const; + + TNetworkStatistics GetNetworkStatistics() const; + + TTotalStatistics GetTotalStatistics() const; + + bool AreResourceUsageStatisticsExpired() const; + + TInstant GetLastUpdateTime() const; + +private: + const IInstancePtr Instance_; + const TDuration UpdatePeriod_; + const bool IsDeltaTracker_; + const bool IsForceUpdate_; + + mutable std::atomic<TInstant> LastUpdateTime_ = {}; + + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, SpinLock_); + mutable TResourceUsage ResourceUsage_; + mutable TResourceUsage ResourceUsageDelta_; + + mutable std::optional<TCpuStatistics> CachedCpuStatistics_; + mutable std::optional<TMemoryStatistics> CachedMemoryStatistics_; + mutable std::optional<TBlockIOStatistics> CachedBlockIOStatistics_; + mutable std::optional<TNetworkStatistics> CachedNetworkStatistics_; + mutable std::optional<TTotalStatistics> CachedTotalStatistics_; + mutable TErrorOr<ui64> 
PeakThreadCount_ = 0; + + template <class T, class F> + T GetStatistics( + std::optional<T>& cachedStatistics, + const TString& statisticsKind, + F extractor) const; + + TCpuStatistics ExtractCpuStatistics(const TResourceUsage& resourceUsage) const; + TMemoryStatistics ExtractMemoryStatistics(const TResourceUsage& resourceUsage) const; + TBlockIOStatistics ExtractBlockIOStatistics(const TResourceUsage& resourceUsage) const; + TNetworkStatistics ExtractNetworkStatistics(const TResourceUsage& resourceUsage) const; + TTotalStatistics ExtractTotalStatistics(const TResourceUsage& resourceUsage) const; + + TErrorOr<ui64> CalculateCounterDelta( + const TErrorOr<ui64>& oldValue, + const TErrorOr<ui64>& newValue) const; + + void ReCalculateResourceUsage(const TResourceUsage& newResourceUsage) const; + + void UpdateResourceUsageStatisticsIfExpired() const; + + void DoUpdateResourceUsage() const; +}; + +DEFINE_REFCOUNTED_TYPE(TPortoResourceTracker) + +//////////////////////////////////////////////////////////////////////////////// + +class TPortoResourceProfiler + : public ISensorProducer +{ +public: + TPortoResourceProfiler( + TPortoResourceTrackerPtr tracker, + TPodSpecConfigPtr podSpec, + const TProfiler& profiler = TProfiler{"/porto"}); + + void CollectSensors(ISensorWriter* writer) override; + +private: + const TPortoResourceTrackerPtr ResourceTracker_; + const TPodSpecConfigPtr PodSpec_; + + void WriteCpuMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec); + + void WriteMemoryMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec); + + void WriteBlockingIOMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec); + + void WriteNetworkMetrics( + ISensorWriter* writer, + TTotalStatistics& totalStatistics, + i64 timeDeltaUsec); +}; + +DECLARE_REFCOUNTED_TYPE(TPortoResourceProfiler) +DEFINE_REFCOUNTED_TYPE(TPortoResourceProfiler) + 
+//////////////////////////////////////////////////////////////////////////////// + +#endif + +void EnablePortoResourceTracker(const TPodSpecConfigPtr& podSpec); + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/private.h b/yt/yt/library/containers/private.h new file mode 100644 index 0000000000..62682cb364 --- /dev/null +++ b/yt/yt/library/containers/private.h @@ -0,0 +1,13 @@ +#pragma once + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger ContainersLogger("Containers"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/process.cpp b/yt/yt/library/containers/process.cpp new file mode 100644 index 0000000000..ad1c8d35dc --- /dev/null +++ b/yt/yt/library/containers/process.cpp @@ -0,0 +1,154 @@ +#ifdef __linux__ + +#include "process.h" + +#include <yt/yt/library/containers/instance.h> + +#include <yt/yt/core/misc/proc.h> +#include <yt/yt/core/misc/fs.h> + +namespace NYT::NContainers { + +using namespace NPipes; +using namespace NNet; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +static inline const NLogging::TLogger Logger("Process"); + +static constexpr pid_t InvalidProcessId = -1; + +//////////////////////////////////////////////////////////////////////////////// + +TPortoProcess::TPortoProcess( + const TString& path, + IInstanceLauncherPtr containerLauncher, + bool copyEnv) + : TProcessBase(path) + , ContainerLauncher_(std::move(containerLauncher)) +{ + AddArgument(NFS::GetFileName(path)); + if (copyEnv) { + for (char** envIt = environ; *envIt; ++envIt) { + Env_.push_back(Capture(*envIt)); + } + } +} + +void TPortoProcess::Kill(int signal) +{ + if (auto instance = GetInstance()) { + instance->Kill(signal); + } +} + +void 
TPortoProcess::DoSpawn() +{ + YT_VERIFY(ProcessId_ == InvalidProcessId && !Finished_); + YT_VERIFY(!GetInstance()); + YT_VERIFY(!Started_); + YT_VERIFY(!Args_.empty()); + + if (!WorkingDirectory_.empty()) { + ContainerLauncher_->SetCwd(WorkingDirectory_); + } + + Started_ = true; + + try { + // TPortoProcess doesn't support running processes inside rootFS. + YT_VERIFY(!ContainerLauncher_->HasRoot()); + std::vector<TString> args(Args_.begin() + 1, Args_.end()); + auto instance = WaitFor(ContainerLauncher_->Launch(ResolvedPath_, args, DecomposeEnv())) + .ValueOrThrow(); + ContainerInstance_.Store(instance); + FinishedPromise_.SetFrom(instance->Wait()); + + try { + ProcessId_ = instance->GetPid(); + } catch (const std::exception& ex) { + // This could happen if Porto container has already died or pid namespace of + // parent container is not a parent of pid namespace of child container. + // It's not a problem, since for Porto process pid is used for logging purposes only. + YT_LOG_DEBUG(ex, "Failed to get pid of root process (Container: %v)", + instance->GetName()); + } + + YT_LOG_DEBUG("Process inside Porto spawned successfully (Path: %v, ExternalPid: %v, Container: %v)", + ResolvedPath_, + ProcessId_, + instance->GetName()); + + FinishedPromise_.ToFuture().Subscribe(BIND([=, this, this_ = MakeStrong(this)] (const TError& exitStatus) { + Finished_ = true; + if (exitStatus.IsOK()) { + YT_LOG_DEBUG("Process inside Porto exited gracefully (ExternalPid: %v, Container: %v)", + ProcessId_, + instance->GetName()); + } else { + YT_LOG_DEBUG(exitStatus, "Process inside Porto exited with an error (ExternalPid: %v, Container: %v)", + ProcessId_, + instance->GetName()); + } + })); + } catch (const std::exception& ex) { + Finished_ = true; + THROW_ERROR_EXCEPTION("Failed to start child process inside Porto") + << TErrorAttribute("path", ResolvedPath_) + << TErrorAttribute("container", ContainerLauncher_->GetName()) + << ex; + } +} + +IInstancePtr TPortoProcess::GetInstance() +{ 
+ return ContainerInstance_.Acquire(); +} + +THashMap<TString, TString> TPortoProcess::DecomposeEnv() const +{ + THashMap<TString, TString> result; + for (const auto& env : Env_) { + TStringBuf name, value; + TStringBuf(env).TrySplit('=', name, value); + result[name] = value; + } + return result; +} + +static TString CreateStdIONamedPipePath() +{ + const TString name = ToString(TGuid::Create()); + return NFS::GetRealPath(NFS::CombinePaths("/tmp", name)); +} + +IConnectionWriterPtr TPortoProcess::GetStdInWriter() +{ + auto pipe = TNamedPipe::Create(CreateStdIONamedPipePath()); + ContainerLauncher_->SetStdIn(pipe->GetPath()); + NamedPipes_.push_back(pipe); + return pipe->CreateAsyncWriter(); +} + +IConnectionReaderPtr TPortoProcess::GetStdOutReader() +{ + auto pipe = TNamedPipe::Create(CreateStdIONamedPipePath()); + ContainerLauncher_->SetStdOut(pipe->GetPath()); + NamedPipes_.push_back(pipe); + return pipe->CreateAsyncReader(); +} + +IConnectionReaderPtr TPortoProcess::GetStdErrReader() +{ + auto pipe = TNamedPipe::Create(CreateStdIONamedPipePath()); + ContainerLauncher_->SetStdErr(pipe->GetPath()); + NamedPipes_.push_back(pipe); + return pipe->CreateAsyncReader(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers + +#endif diff --git a/yt/yt/library/containers/process.h b/yt/yt/library/containers/process.h new file mode 100644 index 0000000000..75255165d8 --- /dev/null +++ b/yt/yt/library/containers/process.h @@ -0,0 +1,46 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/process/process.h> + +#include <library/cpp/yt/memory/atomic_intrusive_ptr.h> + +#include <library/cpp/porto/libporto.hpp> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +// NB(psushin): this class is deprecated and only used to run job proxy. +// ToDo(psushin): kill me. 
+class TPortoProcess + : public TProcessBase +{ +public: + TPortoProcess( + const TString& path, + NContainers::IInstanceLauncherPtr containerLauncher, + bool copyEnv = true); + void Kill(int signal) override; + NNet::IConnectionWriterPtr GetStdInWriter() override; + NNet::IConnectionReaderPtr GetStdOutReader() override; + NNet::IConnectionReaderPtr GetStdErrReader() override; + + NContainers::IInstancePtr GetInstance(); + +private: + const NContainers::IInstanceLauncherPtr ContainerLauncher_; + + TAtomicIntrusivePtr<NContainers::IInstance> ContainerInstance_; + std::vector<NPipes::TNamedPipePtr> NamedPipes_; + + void DoSpawn() override; + THashMap<TString, TString> DecomposeEnv() const; +}; + +DEFINE_REFCOUNTED_TYPE(TPortoProcess) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/public.h b/yt/yt/library/containers/public.h new file mode 100644 index 0000000000..d8e3cf3491 --- /dev/null +++ b/yt/yt/library/containers/public.h @@ -0,0 +1,163 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +#include <library/cpp/porto/proto/rpc.pb.h> +#include <library/cpp/yt/misc/enum.h> + +namespace NYT::NContainers { + +//////////////////////////////////////////////////////////////////////////////// + +const int PortoErrorCodeBase = 12000; + +DEFINE_ENUM(EPortoErrorCode, + ((Success) ((PortoErrorCodeBase + Porto::EError::Success))) + ((Unknown) ((PortoErrorCodeBase + Porto::EError::Unknown))) + ((InvalidMethod) ((PortoErrorCodeBase + Porto::EError::InvalidMethod))) + ((ContainerAlreadyExists) ((PortoErrorCodeBase + Porto::EError::ContainerAlreadyExists))) + ((ContainerDoesNotExist) ((PortoErrorCodeBase + Porto::EError::ContainerDoesNotExist))) + ((InvalidProperty) ((PortoErrorCodeBase + Porto::EError::InvalidProperty))) + ((InvalidData) ((PortoErrorCodeBase + Porto::EError::InvalidData))) + ((InvalidValue) ((PortoErrorCodeBase + Porto::EError::InvalidValue))) + 
((InvalidState) ((PortoErrorCodeBase + Porto::EError::InvalidState))) + ((NotSupported) ((PortoErrorCodeBase + Porto::EError::NotSupported))) + ((ResourceNotAvailable) ((PortoErrorCodeBase + Porto::EError::ResourceNotAvailable))) + ((Permission) ((PortoErrorCodeBase + Porto::EError::Permission))) + ((VolumeAlreadyExists) ((PortoErrorCodeBase + Porto::EError::VolumeAlreadyExists))) + ((VolumeNotFound) ((PortoErrorCodeBase + Porto::EError::VolumeNotFound))) + ((NoSpace) ((PortoErrorCodeBase + Porto::EError::NoSpace))) + ((Busy) ((PortoErrorCodeBase + Porto::EError::Busy))) + ((VolumeAlreadyLinked) ((PortoErrorCodeBase + Porto::EError::VolumeAlreadyLinked))) + ((VolumeNotLinked) ((PortoErrorCodeBase + Porto::EError::VolumeNotLinked))) + ((LayerAlreadyExists) ((PortoErrorCodeBase + Porto::EError::LayerAlreadyExists))) + ((LayerNotFound) ((PortoErrorCodeBase + Porto::EError::LayerNotFound))) + ((NoValue) ((PortoErrorCodeBase + Porto::EError::NoValue))) + ((VolumeNotReady) ((PortoErrorCodeBase + Porto::EError::VolumeNotReady))) + ((InvalidCommand) ((PortoErrorCodeBase + Porto::EError::InvalidCommand))) + ((LostError) ((PortoErrorCodeBase + Porto::EError::LostError))) + ((DeviceNotFound) ((PortoErrorCodeBase + Porto::EError::DeviceNotFound))) + ((InvalidPath) ((PortoErrorCodeBase + Porto::EError::InvalidPath))) + ((InvalidNetworkAddress) ((PortoErrorCodeBase + Porto::EError::InvalidNetworkAddress))) + ((PortoFrozen) ((PortoErrorCodeBase + Porto::EError::PortoFrozen))) + ((LabelNotFound) ((PortoErrorCodeBase + Porto::EError::LabelNotFound))) + ((InvalidLabel) ((PortoErrorCodeBase + Porto::EError::InvalidLabel))) + ((NotFound) ((PortoErrorCodeBase + Porto::EError::NotFound))) + ((SocketError) ((PortoErrorCodeBase + Porto::EError::SocketError))) + ((SocketUnavailable) ((PortoErrorCodeBase + Porto::EError::SocketUnavailable))) + ((SocketTimeout) ((PortoErrorCodeBase + Porto::EError::SocketTimeout))) + ((Taint) ((PortoErrorCodeBase + Porto::EError::Taint))) + ((Queued) 
((PortoErrorCodeBase + Porto::EError::Queued))) +); + +//////////////////////////////////////////////////////////////////////////////// + +YT_DEFINE_ERROR_ENUM( + ((FailedToStartContainer) (14000)) +); + +DEFINE_ENUM(EStatField, + // CPU + (CpuUsage) + (CpuUserUsage) + (CpuSystemUsage) + (CpuWait) + (CpuThrottled) + (ContextSwitches) + (ContextSwitchesDelta) + (ThreadCount) + (CpuLimit) + (CpuGuarantee) + + // Memory + (Rss) + (MappedFile) + (MajorPageFaults) + (MinorPageFaults) + (FileCacheUsage) + (AnonMemoryUsage) + (AnonMemoryLimit) + (MemoryUsage) + (MemoryGuarantee) + (MemoryLimit) + (MaxMemoryUsage) + (OomKills) + (OomKillsTotal) + + // IO + (IOReadByte) + (IOWriteByte) + (IOBytesLimit) + (IOReadOps) + (IOWriteOps) + (IOOps) + (IOOpsLimit) + (IOTotalTime) + (IOWaitTime) + + // Network + (NetTxBytes) + (NetTxPackets) + (NetTxDrops) + (NetTxLimit) + (NetRxBytes) + (NetRxPackets) + (NetRxDrops) + (NetRxLimit) +); + +DEFINE_ENUM(EEnablePorto, + (None) + (Isolate) + (Full) +); + +struct TBind +{ + TString SourcePath; + TString TargetPath; + bool ReadOnly; +}; + +struct TRootFS +{ + TString RootPath; + bool IsRootReadOnly; + std::vector<TBind> Binds; +}; + +struct TDevice +{ + TString DeviceName; + bool Enabled; +}; + +struct TInstanceLimits +{ + double Cpu = 0; + i64 Memory = 0; + std::optional<i64> NetTx; + std::optional<i64> NetRx; + + bool operator==(const TInstanceLimits&) const = default; +}; + +DECLARE_REFCOUNTED_STRUCT(IContainerManager) +DECLARE_REFCOUNTED_STRUCT(IInstanceLauncher) +DECLARE_REFCOUNTED_STRUCT(IInstance) +DECLARE_REFCOUNTED_STRUCT(IPortoExecutor) + +DECLARE_REFCOUNTED_CLASS(TPortoHealthChecker) +DECLARE_REFCOUNTED_CLASS(TInstanceLimitsTracker) +DECLARE_REFCOUNTED_CLASS(TPortoProcess) +DECLARE_REFCOUNTED_CLASS(TPortoResourceTracker) +DECLARE_REFCOUNTED_CLASS(TPortoExecutorDynamicConfig) +DECLARE_REFCOUNTED_CLASS(TPodSpecConfig) + +//////////////////////////////////////////////////////////////////////////////// + +bool IsValidCGroupType(const 
TString& type); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NContainers diff --git a/yt/yt/library/containers/ya.make b/yt/yt/library/containers/ya.make new file mode 100644 index 0000000000..19461fc51e --- /dev/null +++ b/yt/yt/library/containers/ya.make @@ -0,0 +1,36 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + cgroup.cpp + config.cpp + instance.cpp + instance_limits_tracker.cpp + process.cpp + porto_executor.cpp + porto_resource_tracker.cpp + porto_health_checker.cpp +) + +PEERDIR( + library/cpp/porto/proto + yt/yt/core +) + +IF(OS_LINUX) + PEERDIR( + library/cpp/porto + ) +ENDIF() + +END() + +RECURSE( + disk_manager + cri +) + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/heavy_schema_validation/schema_validation.cpp b/yt/yt/library/heavy_schema_validation/schema_validation.cpp new file mode 100644 index 0000000000..686919a079 --- /dev/null +++ b/yt/yt/library/heavy_schema_validation/schema_validation.cpp @@ -0,0 +1,471 @@ +#include "schema_validation.h" + +// TODO(sandello,lukyan): Refine these dependencies. +#include <yt/yt/library/query/base/query_preparer.h> +#include <yt/yt/library/query/base/functions.h> + +#include <yt/yt/client/table_client/column_sort_schema.h> +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/logical_type.h> + +#include <yt/yt/client/complex_types/check_type_compatibility.h> + +#include <yt/yt/core/ytree/convert.h> + +namespace NYT::NTableClient { + +using namespace NYTree; +using namespace NQueryClient; +using namespace NChunkClient; + +//////////////////////////////////////////////////////////////////////////////// +//! Validates the column schema update. +/*! + * \pre{oldColumn and newColumn should have the same stable name.} + * + * Validates that: + * - New column type is compatible with the old one. + * - Optional column doesn't become required. + * - Column expression remains the same. 
+ * - Column aggregate method either was introduced or remains the same. + * - Column sort order either changes to std::nullopt or remains the same. + */ +void ValidateColumnSchemaUpdate(const TColumnSchema& oldColumn, const TColumnSchema& newColumn) +{ + YT_VERIFY(oldColumn.StableName() == newColumn.StableName()); + + auto compatibility = NComplexTypes::CheckTypeCompatibility( + oldColumn.LogicalType(), + newColumn.LogicalType()); + + if (compatibility.first != ESchemaCompatibility::FullyCompatible) { + THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::IncompatibleSchemas, "Type mismatch for column %v", + oldColumn.GetDiagnosticNameString()) + << compatibility.second; + } + + if (newColumn.SortOrder().operator bool() && newColumn.SortOrder() != oldColumn.SortOrder()) { + THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::IncompatibleSchemas, "Sort order mismatch for column %v: old %Qlv, new %Qlv", + oldColumn.GetDiagnosticNameString(), + oldColumn.SortOrder(), + newColumn.SortOrder()); + } + + if (newColumn.Expression() != oldColumn.Expression()) { + THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::IncompatibleSchemas, "Expression mismatch for column %v: old %Qv, new %Qv", + oldColumn.GetDiagnosticNameString(), + oldColumn.Expression(), + newColumn.Expression()); + } + + if (oldColumn.Aggregate() && oldColumn.Aggregate() != newColumn.Aggregate()) { + THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::IncompatibleSchemas, "Aggregate mode mismatch for column %v: old %Qv, new %Qv", + oldColumn.GetDiagnosticNameString(), + oldColumn.Aggregate(), + newColumn.Aggregate()); + } + + if (oldColumn.SortOrder() && oldColumn.Lock() != newColumn.Lock()) { + THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::IncompatibleSchemas, "Lock mismatch for key column %v: old %Qv, new %Qv", + oldColumn.GetDiagnosticNameString(), + oldColumn.Lock(), + newColumn.Lock()); + } + + if (oldColumn.MaxInlineHunkSize() && !newColumn.MaxInlineHunkSize()) { + 
THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::IncompatibleSchemas, "Cannot reset max inline hunk size for column %v", + oldColumn.GetDiagnosticNameString()); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +//! Validates that all columns from the old schema are present in the new schema, +//! potentially among the deleted ones. +static void ValidateColumnRemoval( + const TTableSchema& oldSchema, + const TTableSchema& newSchema, + TSchemaUpdateEnabledFeatures enabledFeatures, + bool isTableDynamic) +{ + YT_VERIFY(newSchema.GetStrict()); + for (const auto& oldColumn : oldSchema.Columns()) { + if (newSchema.FindColumnByStableName(oldColumn.StableName())) { + continue; + } + + if (!enabledFeatures.EnableStaticTableDropColumn && !isTableDynamic || + !enabledFeatures.EnableDynamicTableDropColumn && isTableDynamic) { + THROW_ERROR_EXCEPTION("Cannot remove column %v from a strict schema", + oldColumn.GetDiagnosticNameString()); + } + + if (!newSchema.FindDeletedColumn(oldColumn.StableName())) { + THROW_ERROR_EXCEPTION("To remove column %v from a strict schema, put it into " + "deleted columns.", oldColumn.GetDiagnosticNameString()); + } + + if (oldColumn.SortOrder() && newSchema.FindDeletedColumn(oldColumn.StableName())) { + THROW_ERROR_EXCEPTION("Key column %v may not be deleted", + oldColumn.GetDiagnosticNameString()); + } + } + if (!newSchema.DeletedColumns().empty()) { + if (!enabledFeatures.EnableDynamicTableDropColumn && isTableDynamic) { + THROW_ERROR_EXCEPTION("Deleting columns is not allowed on a dynamic table, " + "got %v deleted columns", std::ssize(newSchema.DeletedColumns())); + } + + if (!enabledFeatures.EnableStaticTableDropColumn && !isTableDynamic) { + THROW_ERROR_EXCEPTION("Deleting columns is not allowed on a static table, " + "got %v deleted columns", std::ssize(newSchema.DeletedColumns())); + } + } + for (const auto& oldDeletedColumn : oldSchema.DeletedColumns()) { + if 
(!newSchema.FindDeletedColumn(oldDeletedColumn.StableName())) { + THROW_ERROR_EXCEPTION("Deleted column %v must remain in the deleted column list", + oldDeletedColumn.StableName().Get()); + } + } +} + +//! Validates that all columns from the new schema are present in the old schema. +void ValidateColumnsNotInserted(const TTableSchema& oldSchema, const TTableSchema& newSchema) +{ + YT_VERIFY(!oldSchema.GetStrict()); + for (const auto& newColumn : newSchema.Columns()) { + if (!oldSchema.FindColumnByStableName(newColumn.StableName())) { + THROW_ERROR_EXCEPTION("Cannot insert a new column %v into non-strict schema", + newColumn.GetDiagnosticNameString()); + } + } +} + +//! Validates that table schema columns match. +/*! + * Validates that: + * - For each column present in both #oldSchema and #newSchema, its declarations match each other. + * - Key columns are not removed (but they may become non-key). + * - If any key columns are removed, the unique_keys is set to false. + */ +void ValidateColumnsMatch(const TTableSchema& oldSchema, const TTableSchema& newSchema) +{ + int commonKeyColumnPrefix = 0; + for (int oldColumnIndex = 0; oldColumnIndex < oldSchema.GetColumnCount(); ++oldColumnIndex) { + const auto& oldColumn = oldSchema.Columns()[oldColumnIndex]; + const auto* newColumnPtr = newSchema.FindColumnByStableName(oldColumn.StableName()); + if (!newColumnPtr) { + // We consider only columns present both in oldSchema and newSchema. 
+ continue; + } + const auto& newColumn = *newColumnPtr; + ValidateColumnSchemaUpdate(oldColumn, newColumn); + + if (oldColumn.SortOrder() && newColumn.SortOrder()) { + int newColumnIndex = newSchema.GetColumnIndex(newColumn); + if (oldColumnIndex != newColumnIndex) { + THROW_ERROR_EXCEPTION("Cannot change position of a key column %v: old %v, new %v", + oldColumn.GetDiagnosticNameString(), + oldColumnIndex, + newColumnIndex); + } + if (commonKeyColumnPrefix <= oldColumnIndex) { + commonKeyColumnPrefix = oldColumnIndex + 1; + } + } + } + + // Check that all columns from the commonKeyColumnPrefix in oldSchema are actually present in newSchema. + for (int oldColumnIndex = 0; oldColumnIndex < commonKeyColumnPrefix; ++oldColumnIndex) { + const auto& oldColumn = oldSchema.Columns()[oldColumnIndex]; + if (!newSchema.FindColumnByStableName(oldColumn.StableName())) { + THROW_ERROR_EXCEPTION("Key column with %v is missing in new schema", oldColumn.GetDiagnosticNameString()); + } + } + + if (commonKeyColumnPrefix < oldSchema.GetKeyColumnCount() && newSchema.GetUniqueKeys()) { + THROW_ERROR_EXCEPTION("Table cannot have unique keys since some of its key columns were removed"); + } +} + +void ValidateNoRequiredColumnsAdded(const TTableSchema& oldSchema, const TTableSchema& newSchema) +{ + for (const auto& newColumn : newSchema.Columns()) { + if (newColumn.Required()) { + const auto* oldColumn = oldSchema.FindColumnByStableName(newColumn.StableName()); + if (!oldColumn) { + THROW_ERROR_EXCEPTION("Cannot insert a new required column %v into a non-empty table", + newColumn.GetDiagnosticNameString()); + } + } + } +} + +static bool IsPhysicalType(ESimpleLogicalValueType logicalType) +{ + return static_cast<ui32>(logicalType) == static_cast<ui32>(GetPhysicalType(logicalType)); +} + +//! Validates aggregated columns. +/*! + * Validates that: + * - Aggregated columns are non-key. + * - Aggregate function appears in a list of pre-defined aggregate functions. 
+ * - Type of an aggregated column matches the type of an aggregate function. + */ +void ValidateAggregatedColumns(const TTableSchema& schema) +{ + for (int index = 0; index < schema.GetColumnCount(); ++index) { + const auto& columnSchema = schema.Columns()[index]; + if (columnSchema.Aggregate()) { + if (index < schema.GetKeyColumnCount()) { + THROW_ERROR_EXCEPTION("Key column %v cannot be aggregated", columnSchema.GetDiagnosticNameString()); + } + if (!columnSchema.IsOfV1Type() || !IsPhysicalType(columnSchema.CastToV1Type())) { + THROW_ERROR_EXCEPTION("Aggregated column %v is forbidden to have logical type %Qlv", + columnSchema.GetDiagnosticNameString(), + *columnSchema.LogicalType()); + } + + const auto& name = *columnSchema.Aggregate(); + auto typeInferrer = GetBuiltinTypeInferrers()->GetFunction(name); + if (auto descriptor = typeInferrer->As<TAggregateTypeInferrer>()) { + TTypeSet constraint; + std::optional<EValueType> stateType; + std::optional<EValueType> resultType; + + descriptor->GetNormalizedConstraints(&constraint, &stateType, &resultType, name); + if (!constraint.Get(columnSchema.GetWireType())) { + THROW_ERROR_EXCEPTION("Argument type mismatch in aggregate function %Qv from column %v: expected %Qlv, got %Qlv", + *columnSchema.Aggregate(), + columnSchema.GetDiagnosticNameString(), + constraint, + columnSchema.GetWireType()); + } + + if (stateType && *stateType != columnSchema.GetWireType()) { + THROW_ERROR_EXCEPTION("Aggregate function %Qv state type %Qlv differs from column %v type %Qlv", + *columnSchema.Aggregate(), + stateType, + columnSchema.GetDiagnosticNameString(), + columnSchema.GetWireType()); + } + + if (resultType && *resultType != columnSchema.GetWireType()) { + THROW_ERROR_EXCEPTION("Aggregate function %Qv result type %Qlv differs from column %v type %Qlv", + *columnSchema.Aggregate(), + resultType, + columnSchema.GetDiagnosticNameString(), + columnSchema.GetWireType()); + } + } else if (auto descriptor = 
typeInferrer->As<TAggregateFunctionTypeInferrer>()) { + std::vector<TTypeSet> typeConstraints; + std::vector<int> argumentIndexes; + + auto [_, resultIndex] = descriptor->GetNormalizedConstraints( + &typeConstraints, + &argumentIndexes); + auto& resultConstraint = typeConstraints[resultIndex]; + + if (!resultConstraint.Get(columnSchema.GetWireType())) { + THROW_ERROR_EXCEPTION("Aggregate function %Qv result type set %Qlv differs from column %v type %Qlv", + *columnSchema.Aggregate(), + resultConstraint, + columnSchema.GetDiagnosticNameString(), + columnSchema.GetWireType()); + } + } else { + THROW_ERROR_EXCEPTION("Unknown aggregate function %Qv at column %v", + *columnSchema.Aggregate(), + columnSchema.GetDiagnosticNameString()); + } + } + } +} + +void ValidateComputedColumns(const TTableSchema& schema, bool isTableDynamic) +{ + for (int index = 0; index < schema.GetColumnCount(); ++index) { + const auto& columnSchema = schema.Columns()[index]; + // TODO(levysotsky): Use early continue. 
+ if (columnSchema.Expression()) { + if (index >= schema.GetKeyColumnCount() && isTableDynamic) { + THROW_ERROR_EXCEPTION("Non-key column %v cannot be computed", columnSchema.GetDiagnosticNameString()); + } + THashSet<TString> references; + auto expr = PrepareExpression(*columnSchema.Expression(), schema, GetBuiltinTypeInferrers(), &references); + if (*columnSchema.LogicalType() != *expr->LogicalType) { + THROW_ERROR_EXCEPTION( + "Computed column %v type mismatch: declared type is %Qlv but expression type is %Qlv", + columnSchema.GetDiagnosticNameString(), + *columnSchema.LogicalType(), + *expr->LogicalType); + } + + for (const auto& ref : references) { + const auto* refColumn = schema.FindColumn(ref); + if (!refColumn) { + THROW_ERROR_EXCEPTION("Computed column %v depends on unknown column %Qv", + columnSchema.GetDiagnosticNameString(), + ref); + } + if (!refColumn->SortOrder() && isTableDynamic) { + THROW_ERROR_EXCEPTION("Computed column %v depends on a non-key column %v", + columnSchema.GetDiagnosticNameString(), + refColumn->GetDiagnosticNameString()); + } + if (refColumn->Expression()) { + THROW_ERROR_EXCEPTION("Computed column %v depends on a computed column %v", + columnSchema.GetDiagnosticNameString(), + refColumn->GetDiagnosticNameString()); + } + } + } + } +} + +//! TODO(max42): document this functions somewhere (see also https://st.yandex-team.ru/YT-1433). +void ValidateTableSchemaUpdateInternal( + const TTableSchema& oldSchema, + const TTableSchema& newSchema, + TSchemaUpdateEnabledFeatures enabledFeatures, + bool isTableDynamic, + bool isTableEmpty) +{ + try { + ValidateTableSchemaHeavy(newSchema, isTableDynamic); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::InvalidSchemaValue, "New table schema is not valid") + << TErrorAttribute("old_schema", oldSchema) + << TErrorAttribute("new_schema", newSchema) + << ex; + } + + try { + if (isTableEmpty) { + // Any valid schema is allowed to be set for an empty table. 
+ return; + } + + if (isTableDynamic && oldSchema.IsSorted() != newSchema.IsSorted()) { + THROW_ERROR_EXCEPTION("Cannot change dynamic table type from sorted to ordered or vice versa"); + } + + if (oldSchema.GetKeyColumnCount() == 0 && newSchema.GetKeyColumnCount() > 0) { + THROW_ERROR_EXCEPTION("Cannot change schema from unsorted to sorted"); + } + if (!oldSchema.GetStrict() && newSchema.GetStrict()) { + THROW_ERROR_EXCEPTION("Changing \"strict\" from \"false\" to \"true\" is not allowed"); + } + if (!oldSchema.GetUniqueKeys() && newSchema.GetUniqueKeys()) { + THROW_ERROR_EXCEPTION("Changing \"unique_keys\" from \"false\" to \"true\" is not allowed"); + } + + if (oldSchema.GetStrict() && !newSchema.GetStrict()) { + if (oldSchema.Columns() != newSchema.Columns()) { + THROW_ERROR_EXCEPTION("Changing columns is not allowed while changing \"strict\" from \"true\" to \"false\""); + } + return; + } + + if (oldSchema.GetStrict()) { + ValidateColumnRemoval(oldSchema, newSchema, enabledFeatures, isTableDynamic); + } else { + ValidateColumnsNotInserted(oldSchema, newSchema); + } + ValidateColumnsMatch(oldSchema, newSchema); + + // We allow adding computed columns only on creation of the table. 
+ if (!oldSchema.Columns().empty() || !isTableEmpty) { + for (const auto& newColumn : newSchema.Columns()) { + if (newColumn.Expression() && !oldSchema.FindColumnByStableName(newColumn.StableName())) { + THROW_ERROR_EXCEPTION("Cannot introduce a new computed column %v after creation", + newColumn.GetDiagnosticNameString()); + } + } + } + + ValidateNoRequiredColumnsAdded(oldSchema, newSchema); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::IncompatibleSchemas, "Table schemas are incompatible") + << TErrorAttribute("old_schema", oldSchema) + << TErrorAttribute("new_schema", newSchema) + << ex; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +void ValidateTableSchemaUpdate( + const TTableSchema& oldSchema, + const TTableSchema& newSchema, + bool isTableDynamic, + bool isTableEmpty) +{ + ValidateTableSchemaUpdateInternal( + oldSchema, + newSchema, + TSchemaUpdateEnabledFeatures{ + false, /* EnableStaticTableDropColumn */ + false /* EnableDynamicTableDropColumn */ + }, + isTableDynamic, + isTableEmpty + ); +} + +//////////////////////////////////////////////////////////////////////////////// + + +void ValidateTableSchemaHeavy( + const TTableSchema& schema, + bool isTableDynamic) +{ + ValidateTableSchema(schema, isTableDynamic); + ValidateComputedColumns(schema, isTableDynamic); + ValidateAggregatedColumns(schema); +} + +//////////////////////////////////////////////////////////////////////////////// + +TError ValidateComputedColumnsCompatibility( + const TTableSchema& inputSchema, + const TTableSchema& outputSchema) +{ + try { + for (const auto& outputColumn : outputSchema.Columns()) { + if (!outputColumn.Expression()) { + continue; + } + const auto* inputColumn = inputSchema.FindColumn(outputColumn.Name()); + if (!inputColumn) { + THROW_ERROR_EXCEPTION("Computed column %v is missing in input schema", + outputColumn.GetDiagnosticNameString()); + } + if (outputColumn.Expression() != 
inputColumn->Expression()) { + THROW_ERROR_EXCEPTION("Computed column %v has different expressions in input " + "and output schemas", + outputColumn.GetDiagnosticNameString()) + << TErrorAttribute("input_schema_expression", inputColumn->Expression()) + << TErrorAttribute("output_schema_expression", outputColumn.Expression()); + } + if (*outputColumn.LogicalType() != *inputColumn->LogicalType()) { + THROW_ERROR_EXCEPTION("Computed column %v type in the input table %Qlv " + "differs from the type in the output table %Qlv", + outputColumn.GetDiagnosticNameString(), + *inputColumn->LogicalType(), + *outputColumn.LogicalType()); + } + } + } catch (const TErrorException& exception) { + return exception.Error() + << TErrorAttribute("input_table_schema", inputSchema) + << TErrorAttribute("output_table_schema", outputSchema); + } + + return TError(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NTableClient diff --git a/yt/yt/library/heavy_schema_validation/schema_validation.h b/yt/yt/library/heavy_schema_validation/schema_validation.h new file mode 100644 index 0000000000..fd4cf37100 --- /dev/null +++ b/yt/yt/library/heavy_schema_validation/schema_validation.h @@ -0,0 +1,60 @@ +#pragma once + +#include <yt/yt/client/table_client/public.h> + +namespace NYT::NTableClient { + +//////////////////////////////////////////////////////////////////////////////// + +struct TSchemaUpdateEnabledFeatures +{ + bool EnableStaticTableDropColumn = false; + bool EnableDynamicTableDropColumn = false; +}; + +//////////////////////////////////////////////////////////////////////////////// + +void ValidateColumnSchemaUpdate( + const TColumnSchema& oldColumn, + const TColumnSchema& newColumn); + +void ValidateTableSchemaUpdateInternal( + const TTableSchema& oldSchema, + const TTableSchema& newSchema, + TSchemaUpdateEnabledFeatures enabledFeatures, + bool isTableDynamic = false, + bool isTableEmpty = false); + +void 
ValidateTableSchemaUpdate( + const TTableSchema& oldSchema, + const TTableSchema& newSchema, + bool isTableDynamic = false, + bool isTableEmpty = false); + +//! Compared to #ValidateTableSchema, additionally validates +//! aggregated and computed columns (this involves calling some heavy QL-related +//! stuff which is missing in yt/client). +void ValidateTableSchemaHeavy( + const TTableSchema& schema, + bool isTableDynamic); + +//! Validates computed columns. +//! +//! Validates that: +//! - Type of a computed column matches the type of its expression. +//! - All referenced columns appear in schema and are not computed. +//! For dynamic tables, additionally validates that all computed and referenced +//! columns are key columns. +void ValidateComputedColumns( + const TTableSchema& schema, + bool isTableDynamic); + +//! Validates that all computed columns in the outputSchema are present in the +//! inputSchema and have the same expression. +TError ValidateComputedColumnsCompatibility( + const TTableSchema& inputSchema, + const TTableSchema& outputSchema); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NTableClient diff --git a/yt/yt/library/heavy_schema_validation/ya.make b/yt/yt/library/heavy_schema_validation/ya.make new file mode 100644 index 0000000000..f40f543d87 --- /dev/null +++ b/yt/yt/library/heavy_schema_validation/ya.make @@ -0,0 +1,14 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + schema_validation.cpp +) + +PEERDIR( + yt/yt/client + yt/yt/library/query/engine_api +) + +END() diff --git a/yt/yt/library/monitoring/http_integration.cpp b/yt/yt/library/monitoring/http_integration.cpp new file mode 100644 index 0000000000..a526d2ede6 --- /dev/null +++ b/yt/yt/library/monitoring/http_integration.cpp @@ -0,0 +1,203 @@ +#include "http_integration.h" + +#include "monitoring_manager.h" + +#include <yt/yt/build/build.h> + +#include <yt/yt/core/json/config.h> +#include 
<yt/yt/core/json/json_writer.h> + +#include <yt/yt/core/ytree/fluent.h> + +#include <yt/yt/core/yson/parser.h> +#include <yt/yt/core/yson/consumer.h> + +#include <yt/yt/core/concurrency/scheduler.h> + +#include <yt/yt/core/ytree/helpers.h> +#include <yt/yt/core/ytree/virtual.h> +#include <yt/yt/core/ytree/ypath_detail.h> +#include <yt/yt/core/ytree/ypath_proxy.h> + +#include <yt/yt/core/http/http.h> +#include <yt/yt/core/http/helpers.h> +#include <yt/yt/core/http/server.h> + +#include <yt/yt/core/misc/ref_counted_tracker_statistics_producer.h> + +#include <yt/yt/library/profiling/solomon/exporter.h> + +#ifdef _linux_ +#include <yt/yt/library/ytprof/http/handler.h> +#include <yt/yt/library/ytprof/build_info.h> + +#include <yt/yt/library/backtrace_introspector/http/handler.h> +#endif + +#include <library/cpp/cgiparam/cgiparam.h> + +#include <util/string/vector.h> + +namespace NYT::NMonitoring { + +using namespace NYTree; +using namespace NYson; +using namespace NHttp; +using namespace NConcurrency; +using namespace NJson; + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(EVerb, + (Get) + (List) +); + +//////////////////////////////////////////////////////////////////////////////// + +void Initialize( + const NHttp::IServerPtr& monitoringServer, + const NProfiling::TSolomonExporterConfigPtr& config, + TMonitoringManagerPtr* monitoringManager, + NYTree::IMapNodePtr* orchidRoot) +{ + *monitoringManager = New<TMonitoringManager>(); + (*monitoringManager)->Register("/ref_counted", CreateRefCountedTrackerStatisticsProducer()); + (*monitoringManager)->Register("/solomon", BIND([] (NYson::IYsonConsumer* consumer) { + auto tags = NProfiling::TSolomonRegistry::Get()->GetDynamicTags(); + + BuildYsonFluently(consumer) + .BeginMap() + .Item("dynamic_tags").Value(THashMap<TString, TString>(tags.begin(), tags.end())) + .EndMap(); + })); + (*monitoringManager)->Start(); + + *orchidRoot = 
NYTree::GetEphemeralNodeFactory(true)->CreateMap(); + SetNodeByYPath( + *orchidRoot, + "/monitoring", + CreateVirtualNode((*monitoringManager)->GetService())); + +#ifdef _linux_ + auto buildInfo = NYTProf::TBuildInfo::GetDefault(); + buildInfo.BinaryVersion = GetVersion(); + + SetNodeByYPath( + *orchidRoot, + "/build_info", + NYTree::BuildYsonNodeFluently() + .BeginMap() + .Item("arc_revision").Value(buildInfo.ArcRevision) + .Item("binary_version").Value(buildInfo.BinaryVersion) + .Item("build_type").Value(buildInfo.BuildType) + .EndMap()); +#endif + + if (monitoringServer) { + auto exporter = New<NProfiling::TSolomonExporter>(config); + exporter->Register("/solomon", monitoringServer); + exporter->Start(); + + SetNodeByYPath( + *orchidRoot, + "/sensors", + CreateVirtualNode(exporter->GetSensorService())); + +#ifdef _linux_ + NYTProf::Register(monitoringServer, "/ytprof", buildInfo); + NBacktraceIntrospector::Register(monitoringServer, "/backtrace"); +#endif + monitoringServer->AddHandler( + "/orchid/", + GetOrchidYPathHttpHandler(*orchidRoot)); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +class TYPathHttpHandler + : public IHttpHandler +{ +public: + explicit TYPathHttpHandler(IYPathServicePtr service) + : Service_(std::move(service)) + { } + + void HandleRequest( + const IRequestPtr& req, + const IResponseWriterPtr& rsp) override + { + const TStringBuf orchidPrefix = "/orchid"; + + TString path{req->GetUrl().Path}; + if (!path.StartsWith(orchidPrefix)) { + THROW_ERROR_EXCEPTION("HTTP request must start with %Qv prefix", + orchidPrefix) + << TErrorAttribute("path", path); + } + + path = path.substr(orchidPrefix.size(), TString::npos); + TCgiParameters params(req->GetUrl().RawQuery); + + auto verb = EVerb::Get; + + auto options = CreateEphemeralAttributes(); + for (const auto& param : params) { + if (param.first == "verb") { + verb = ParseEnum<EVerb>(param.second); + } else { + // Just a check, IAttributeDictionary 
takes raw YSON anyway. + try { + ValidateYson(TYsonString(param.second), DefaultYsonParserNestingLevelLimit); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Error parsing value of query parameter %Qv", + param.first) + << ex; + } + + options->SetYson(param.first, TYsonString(param.second)); + } + } + + TYsonString result; + switch (verb) { + case EVerb::Get: { + auto ypathReq = TYPathProxy::Get(path); + ToProto(ypathReq->mutable_options(), *options); + auto ypathRsp = WaitFor(ExecuteVerb(Service_, ypathReq)) + .ValueOrThrow(); + result = TYsonString(ypathRsp->value()); + break; + } + case EVerb::List: { + auto ypathReq = TYPathProxy::List(path); + auto ypathRsp = WaitFor(ExecuteVerb(Service_, ypathReq)) + .ValueOrThrow(); + result = TYsonString(ypathRsp->value()); + break; + } + default: + YT_ABORT(); + } + + rsp->SetStatus(EStatusCode::OK); + NHttp::ReplyJson(rsp, [&] (NYson::IYsonConsumer* writer) { + Serialize(result, writer); + }); + WaitFor(rsp->Close()) + .ThrowOnError(); + } + +private: + const IYPathServicePtr Service_; +}; + +IHttpHandlerPtr GetOrchidYPathHttpHandler(const IYPathServicePtr& service) +{ + return WrapYTException(New<TYPathHttpHandler>(service)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NMonitoring diff --git a/yt/yt/library/monitoring/http_integration.h b/yt/yt/library/monitoring/http_integration.h new file mode 100644 index 0000000000..48c12ca8a8 --- /dev/null +++ b/yt/yt/library/monitoring/http_integration.h @@ -0,0 +1,28 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/ypath_service.h> + +#include <yt/yt/core/http/public.h> + +#include <yt/yt/library/profiling/solomon/public.h> + +namespace NYT::NMonitoring { + +//////////////////////////////////////////////////////////////////////////////// + +void Initialize( + const NHttp::IServerPtr& monitoringServer, + const NProfiling::TSolomonExporterConfigPtr& solomonExporterConfig, + 
TMonitoringManagerPtr* monitoringManager, + NYTree::IMapNodePtr* orchidRoot); + +NHttp::IHttpHandlerPtr CreateTracingHttpHandler(); + +NHttp::IHttpHandlerPtr GetOrchidYPathHttpHandler( + const NYTree::IYPathServicePtr& service); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NMonitoring diff --git a/yt/yt/library/monitoring/monitoring_manager.cpp b/yt/yt/library/monitoring/monitoring_manager.cpp new file mode 100644 index 0000000000..ef642034a4 --- /dev/null +++ b/yt/yt/library/monitoring/monitoring_manager.cpp @@ -0,0 +1,177 @@ +#include "monitoring_manager.h" +#include "private.h" + +#include <yt/yt/core/concurrency/action_queue.h> +#include <yt/yt/core/concurrency/periodic_executor.h> + +#include <yt/yt/core/ytree/convert.h> +#include <yt/yt/core/ytree/ephemeral_node_factory.h> +#include <yt/yt/core/ytree/node.h> +#include <yt/yt/core/ytree/tree_visitor.h> +#include <yt/yt/core/ytree/ypath_detail.h> +#include <yt/yt/core/ytree/ypath_client.h> + +#include <yt/yt/library/profiling/sensor.h> + +namespace NYT::NMonitoring { + +using namespace NYTree; +using namespace NYPath; +using namespace NYson; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +static const auto& Logger = MonitoringLogger; + +static const auto UpdatePeriod = TDuration::Seconds(3); +static const auto EmptyRoot = GetEphemeralNodeFactory()->CreateMap(); + +//////////////////////////////////////////////////////////////////////////////// + +class TMonitoringManager::TImpl + : public TRefCounted +{ +public: + void Register(const TYPath& path, TYsonProducer producer) + { + auto guard = Guard(SpinLock_); + YT_VERIFY(PathToProducer_.emplace(path, producer).second); + } + + void Unregister(const TYPath& path) + { + auto guard = Guard(SpinLock_); + YT_VERIFY(PathToProducer_.erase(path) == 1); + } + + IYPathServicePtr GetService() + { + return New<TYPathService>(this); + } + + void 
Start() + { + auto guard = Guard(SpinLock_); + + YT_VERIFY(!Started_); + + PeriodicExecutor_ = New<TPeriodicExecutor>( + ActionQueue_->GetInvoker(), + BIND(&TImpl::Update, MakeWeak(this)), + UpdatePeriod); + PeriodicExecutor_->Start(); + + Started_ = true; + } + + void Stop() + { + auto guard = Guard(SpinLock_); + + if (!Started_) + return; + + Started_ = false; + YT_UNUSED_FUTURE(PeriodicExecutor_->Stop()); + Root_.Reset(); + } + +private: + class TYPathService + : public TYPathServiceBase + { + public: + explicit TYPathService(TIntrusivePtr<TImpl> owner) + : Owner_(std::move(owner)) + { } + + TResolveResult Resolve(const TYPath& path, const IYPathServiceContextPtr& /*context*/) override + { + return TResolveResultThere{Owner_->GetRoot(), path}; + } + + private: + const TIntrusivePtr<TImpl> Owner_; + + }; + + bool Started_ = false; + TActionQueuePtr ActionQueue_ = New<TActionQueue>("Monitoring"); + TPeriodicExecutorPtr PeriodicExecutor_; + + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, SpinLock_); + THashMap<TString, NYson::TYsonProducer> PathToProducer_; + IMapNodePtr Root_; + + void Update() + { + YT_LOG_DEBUG("Started updating monitoring state"); + + YT_PROFILE_TIMING("/monitoring/update_time") { + auto newRoot = GetEphemeralNodeFactory()->CreateMap(); + + THashMap<TString, NYson::TYsonProducer> pathToProducer; + { + auto guard = Guard(SpinLock_); + pathToProducer = PathToProducer_; + } + + for (const auto& [path, producer] : pathToProducer) { + auto value = ConvertToYsonString(producer); + SyncYPathSet(newRoot, path, value); + } + + // Started_ is written under SpinLock_ by Start()/Stop(); check it and publish the new root under the same lock. + if (auto guard = Guard(SpinLock_); Started_) { + std::swap(Root_, newRoot); + } + } + YT_LOG_DEBUG("Finished updating monitoring state"); + } + + IMapNodePtr GetRoot() + { + auto guard = Guard(SpinLock_); + return Root_ ? 
Root_ : EmptyRoot; + } +}; + +DEFINE_REFCOUNTED_TYPE(TMonitoringManager) + +//////////////////////////////////////////////////////////////////////////////// + +TMonitoringManager::TMonitoringManager() + : Impl_(New<TImpl>()) +{ } + +TMonitoringManager::~TMonitoringManager() = default; + +void TMonitoringManager::Register(const TYPath& path, TYsonProducer producer) +{ + Impl_->Register(path, producer); +} + +void TMonitoringManager::Unregister(const TYPath& path) +{ + Impl_->Unregister(path); +} + +IYPathServicePtr TMonitoringManager::GetService() +{ + return Impl_->GetService(); +} + +void TMonitoringManager::Start() +{ + Impl_->Start(); +} + +void TMonitoringManager::Stop() +{ + Impl_->Stop(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NMonitoring diff --git a/yt/yt/library/monitoring/monitoring_manager.h b/yt/yt/library/monitoring/monitoring_manager.h new file mode 100644 index 0000000000..fc5c3de6c7 --- /dev/null +++ b/yt/yt/library/monitoring/monitoring_manager.h @@ -0,0 +1,54 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/yson/consumer.h> + +#include <yt/yt/core/ypath/public.h> + +#include <yt/yt/core/ytree/public.h> + +namespace NYT::NMonitoring { + +//////////////////////////////////////////////////////////////////////////////// + +//! Exposes a tree assembled from results returned by a set of +//! registered NYson::TYsonProducer-s. +/*! + * \note + * The results are cached and periodically updated. + */ +class TMonitoringManager + : public TRefCounted +{ +public: + TMonitoringManager(); + ~TMonitoringManager(); + + //! Registers a new #producer for a given #path. + void Register(const NYPath::TYPath& path, NYson::TYsonProducer producer); + + //! Unregisters an existing producer for the specified #path. + void Unregister(const NYPath::TYPath& path); + + //! Returns the service representing the whole tree. + /*! + * \note The service is thread-safe. 
+ */ + NYTree::IYPathServicePtr GetService(); + + //! Starts periodic updates. + void Start(); + + //! Stops periodic updates. + void Stop(); + +private: + class TImpl; + TIntrusivePtr<TImpl> Impl_; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NMonitoring diff --git a/yt/yt/library/monitoring/private.h b/yt/yt/library/monitoring/private.h new file mode 100644 index 0000000000..e2bfb31c78 --- /dev/null +++ b/yt/yt/library/monitoring/private.h @@ -0,0 +1,15 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NMonitoring { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger MonitoringLogger("Monitoring"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NMonitoring diff --git a/yt/yt/library/monitoring/public.h b/yt/yt/library/monitoring/public.h new file mode 100644 index 0000000000..3514bdd858 --- /dev/null +++ b/yt/yt/library/monitoring/public.h @@ -0,0 +1,13 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +namespace NYT::NMonitoring { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_CLASS(TMonitoringManager) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NMonitoring diff --git a/yt/yt/library/monitoring/ya.make b/yt/yt/library/monitoring/ya.make new file mode 100644 index 0000000000..c2fccd99ac --- /dev/null +++ b/yt/yt/library/monitoring/ya.make @@ -0,0 +1,27 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + http_integration.cpp + monitoring_manager.cpp +) + +PEERDIR( + yt/yt/core + yt/yt/build + yt/yt/library/profiling + yt/yt/library/profiling/solomon + library/cpp/cgiparam +) + +IF (OS_LINUX) + PEERDIR( + yt/yt/library/ytprof + yt/yt/library/ytprof/http + + 
yt/yt/library/backtrace_introspector/http + ) +ENDIF() + +END() diff --git a/yt/yt/library/numeric/fixed_point_number-inl.h b/yt/yt/library/numeric/fixed_point_number-inl.h new file mode 100644 index 0000000000..6d76a8ce8a --- /dev/null +++ b/yt/yt/library/numeric/fixed_point_number-inl.h @@ -0,0 +1,103 @@ +#ifndef FIXED_POINT_NUMBER_INL_H_ +#error "Direct inclusion of this file is not allowed, include fixed_point_number.h" +// For the sake of sane code completion. +#include "fixed_point_number.h" +#endif + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <typename T> +constexpr T ComputePower(T base, int exponent) +{ + return exponent == 0 ? 1 : base * ComputePower(base, exponent - 1); +} + +//////////////////////////////////////////////////////////////////////////////// + +template <typename U, int P> +TFixedPointNumber<U, P>::TFixedPointNumber() + : Value_() +{ } + +template <typename U, int P> +TFixedPointNumber<U, P>::TFixedPointNumber(i64 value) + : Value_(value * ScalingFactor) +{ } + +template <typename U, int P> +TFixedPointNumber<U, P>::TFixedPointNumber(double value) + : Value_(std::round(value * ScalingFactor)) +{ } + +template <typename U, int P> +TFixedPointNumber<U, P>::operator i64 () const +{ + return Value_ / ScalingFactor; +} + +template <typename U, int P> +TFixedPointNumber<U, P>::operator double () const +{ + return static_cast<double>(Value_) / ScalingFactor; +} + +template <typename U, int P> +U TFixedPointNumber<U, P>::GetUnderlyingValue() const +{ + return Value_; +} + +template <typename U, int P> +void TFixedPointNumber<U, P>::SetUnderlyingValue(U value) +{ + Value_ = value; +} + +template <typename U, int P> +TFixedPointNumber<U, P>& TFixedPointNumber<U, P>::operator += (const TFixedPointNumber& rhs) +{ + Value_ += rhs.Value_; + return *this; +} + +template <typename U, int P> +TFixedPointNumber<U, P>& TFixedPointNumber<U, P>::operator -= (const TFixedPointNumber<U, P>& 
rhs) +{ + Value_ -= rhs.Value_; + return *this; +} + +template <typename U, int P> +template <typename T> +TFixedPointNumber<U, P>& TFixedPointNumber<U, P>::operator *= (const T& value) +{ + Value_ *= value; + return *this; +} + +template <typename U, int P> +template <typename T> +TFixedPointNumber<U, P>& TFixedPointNumber<U, P>::operator /= (const T& value) +{ + Value_ /= value; + return *this; +} + +template <typename U, int P> +TFixedPointNumber<U, P>& TFixedPointNumber<U, P>::operator *= (const double& value) +{ + Value_ = std::round(Value_ * value); + return *this; +} + +template <typename U, int P> +NYT::TFixedPointNumber<U, P> round(const NYT::TFixedPointNumber<U, P>& number) +{ + return number; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/numeric/fixed_point_number.h b/yt/yt/library/numeric/fixed_point_number.h new file mode 100644 index 0000000000..0f7532577d --- /dev/null +++ b/yt/yt/library/numeric/fixed_point_number.h @@ -0,0 +1,147 @@ +#pragma once + +#include <util/system/defaults.h> + +#include <type_traits> +#include <cmath> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <typename T> +constexpr T ComputePower(T base, int exponent); + +//////////////////////////////////////////////////////////////////////////////// + +//! Stores fixed point number of form X.YYY with DecimalPrecision decimal digits after the point +//! using one object of Underlying integer type as a storage. +//! 
Details can be found at https://en.wikipedia.org/wiki/Fixed-point_arithmetic +template <typename Underlying, int DecimalPrecision> +class TFixedPointNumber +{ + static_assert(std::is_integral<Underlying>::value, "Underlying type should be integral"); + static_assert(DecimalPrecision >= 0 && DecimalPrecision <= std::numeric_limits<Underlying>::digits10, + "Underlying type should be able to represent specified number of decimal places"); + +public: + static constexpr Underlying ScalingFactor = ComputePower<Underlying>(10, DecimalPrecision); + + TFixedPointNumber(); + + TFixedPointNumber(i64 value); + + explicit TFixedPointNumber(double value); + + explicit operator i64 () const; + explicit operator double () const; + + Underlying GetUnderlyingValue() const; + void SetUnderlyingValue(Underlying value); + + TFixedPointNumber& operator += (const TFixedPointNumber& rhs); + TFixedPointNumber& operator -= (const TFixedPointNumber& rhs); + + template <typename T> + TFixedPointNumber& operator *= (const T& value); + + TFixedPointNumber& operator *= (const double& value); + + template <typename T> + TFixedPointNumber& operator /= (const T& value); + + friend TFixedPointNumber operator + (TFixedPointNumber lhs, const TFixedPointNumber& rhs) + { + lhs += rhs; + return lhs; + } + + friend TFixedPointNumber operator - (TFixedPointNumber lhs, const TFixedPointNumber& rhs) + { + lhs -= rhs; + return lhs; + } + + template <typename T> + friend TFixedPointNumber operator * (TFixedPointNumber lhs, T value) + { + lhs *= value; + return lhs; + } + + template <typename T> + friend TFixedPointNumber operator / (TFixedPointNumber lhs, T value) + { + lhs /= value; + return lhs; + } + + friend TFixedPointNumber operator - (TFixedPointNumber lhs) + { + lhs.Value_ = -lhs.Value_; + return lhs; + } + + friend bool operator == (const TFixedPointNumber& lhs, const TFixedPointNumber& rhs) + { + return lhs.Value_ == rhs.Value_; + } + + friend bool operator != (const TFixedPointNumber& lhs, const 
TFixedPointNumber& rhs) + { + return lhs.Value_ != rhs.Value_; + } + + friend bool operator < (const TFixedPointNumber& lhs, const TFixedPointNumber& rhs) + { + return lhs.Value_ < rhs.Value_; + } + + friend bool operator <= (const TFixedPointNumber& lhs, const TFixedPointNumber& rhs) + { + return lhs.Value_ <= rhs.Value_; + } + + friend bool operator > (const TFixedPointNumber& lhs, const TFixedPointNumber& rhs) + { + return lhs.Value_ > rhs.Value_; + } + + friend bool operator >= (const TFixedPointNumber& lhs, const TFixedPointNumber& rhs) + { + return lhs.Value_ >= rhs.Value_; + } + +private: + Underlying Value_; + +}; + +template <typename U, int P> +NYT::TFixedPointNumber<U, P> round(const NYT::TFixedPointNumber<U, P>& number); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT + +//////////////////////////////////////////////////////////////////////////////// + +namespace std { + +template <typename U, int P> +class numeric_limits<NYT::TFixedPointNumber<U, P>> +{ +public: + static NYT::TFixedPointNumber<U, P> max() + { + return numeric_limits<U>::max() / NYT::TFixedPointNumber<U, P>::ScalingFactor; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace std + +#define FIXED_POINT_NUMBER_INL_H_ +#include "fixed_point_number-inl.h" +#undef FIXED_POINT_NUMBER_INL_H_ diff --git a/yt/yt/library/numeric/serialize/fixed_point_number.h b/yt/yt/library/numeric/serialize/fixed_point_number.h new file mode 100644 index 0000000000..9eca67f7a4 --- /dev/null +++ b/yt/yt/library/numeric/serialize/fixed_point_number.h @@ -0,0 +1,65 @@ +#pragma once + +#include <yt/yt/core/misc/serialize.h> + +#include <yt/yt/core/ytree/serialize.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <typename U, int P> +void Serialize(const TFixedPointNumber<U, P>& number, NYson::IYsonConsumer* consumer) +{ + 
NYTree::Serialize(static_cast<double>(number), consumer); +} + +template <typename U, int P> +void Deserialize(TFixedPointNumber<U, P>& number, NYTree::INodePtr node) +{ + double doubleValue; + Deserialize(doubleValue, std::move(node)); + number = TFixedPointNumber<U, P>(doubleValue); +} + +template <typename U, int P> +void Deserialize(TFixedPointNumber<U, P>& number, NYson::TYsonPullParserCursor* cursor) +{ + auto doubleValue = ExtractTo<double>(cursor); + number = TFixedPointNumber<U, P>(doubleValue); +} + +template <typename U, int P> +TString ToString(const TFixedPointNumber<U, P>& number) +{ + return ToString(static_cast<double>(number)); +} + +//////////////////////////////////////////////////////////////////////////////// + +struct TFixedPointNumberSerializer +{ + template <class TNumber, class C> + static void Save(C& context, const TNumber& value) + { + NYT::Save(context, value.GetUnderlyingValue()); + } + + template <class TNumber, class C> + static void Load(C& context, TNumber& value) + { + typename std::remove_const<decltype(TNumber::ScalingFactor)>::type underlyingValue; + NYT::Load(context, underlyingValue); + value.SetUnderlyingValue(underlyingValue); + } +}; + +template <class U, int P, class C> +struct TSerializerTraits<TFixedPointNumber<U, P>, C, void> +{ + using TSerializer = TFixedPointNumberSerializer; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/process/io_dispatcher.cpp b/yt/yt/library/process/io_dispatcher.cpp new file mode 100644 index 0000000000..7da757658d --- /dev/null +++ b/yt/yt/library/process/io_dispatcher.cpp @@ -0,0 +1,37 @@ +#include "io_dispatcher.h" + +#include <yt/yt/core/concurrency/thread_pool_poller.h> +#include <yt/yt/core/concurrency/poller.h> + +#include <yt/yt/core/misc/singleton.h> + +namespace NYT::NPipes { + +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + 
+// Poller_ is initialized with a callback that calls CreateThreadPoolPoller(1, "Pipes").
+// NOTE(review): presumably the callback runs lazily on the first Poller_.Value() call
+// and the arguments are (thread count, thread name prefix) -- confirm against lazy_ptr.h
+// and thread_pool_poller.h.
+TIODispatcher::TIODispatcher()
+    : Poller_(BIND([] { return CreateThreadPoolPoller(1, "Pipes"); }))
+{ }
+
+TIODispatcher::~TIODispatcher() = default;
+
+// Returns the instance provided by Singleton<TIODispatcher>().
+TIODispatcher* TIODispatcher::Get()
+{
+    return Singleton<TIODispatcher>();
+}
+
+// Returns the invoker of the poller held in Poller_.
+IInvokerPtr TIODispatcher::GetInvoker()
+{
+    return Poller_.Value()->GetInvoker();
+}
+
+// Returns the poller held in Poller_.
+IPollerPtr TIODispatcher::GetPoller()
+{
+    return Poller_.Value();
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NPipes
diff --git a/yt/yt/library/process/io_dispatcher.h b/yt/yt/library/process/io_dispatcher.h
new file mode 100644
index 0000000000..2db1b34386
--- /dev/null
+++ b/yt/yt/library/process/io_dispatcher.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include "public.h"
+
+#include <yt/yt/core/concurrency/public.h>
+
+#include <yt/yt/core/misc/lazy_ptr.h>
+
+namespace NYT::NPipes {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Singleton that hands out the poller (and its invoker) used by the pipe
+//! connections created in pipe.cpp.
+class TIODispatcher
+{
+public:
+    ~TIODispatcher();
+
+    //! Returns the singleton instance.
+    static TIODispatcher* Get();
+
+    //! Returns the invoker of the underlying poller.
+    IInvokerPtr GetInvoker();
+
+    //! Returns the underlying poller.
+    NConcurrency::IPollerPtr GetPoller();
+
+private:
+    // The constructor is private; instances are obtained via Get().
+    TIODispatcher();
+
+    Y_DECLARE_SINGLETON_FRIEND()
+
+    TLazyIntrusivePtr<NConcurrency::IThreadPoolPoller> Poller_;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NPipes
diff --git a/yt/yt/library/process/pipe.cpp b/yt/yt/library/process/pipe.cpp
new file mode 100644
index 0000000000..f51d043f22
--- /dev/null
+++ b/yt/yt/library/process/pipe.cpp
@@ -0,0 +1,256 @@
+#include "pipe.h"
+#include "private.h"
+#include "io_dispatcher.h"
+
+#include <yt/yt/core/net/connection.h>
+
+#include <yt/yt/core/misc/proc.h>
+#include <yt/yt/core/misc/fs.h>
+
+#include <sys/types.h>
+#include <sys/stat.h>
+
+namespace NYT::NPipes {
+
+using namespace NNet;
+
+////////////////////////////////////////////////////////////////////////////////
+
+// File-local alias of PipesLogger picked up by the YT_LOG_* macros below.
+static const auto& Logger = PipesLogger;
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Stores the path and the ownership flag; does not touch the file system.
+TNamedPipe::TNamedPipe(const TString& path, bool owning)
+    : Path_(path)
+    , Owning_(owning)
+{ }
+
+// Unlinks the pipe path if this object owns it. An unlink failure is only logged.
+TNamedPipe::~TNamedPipe()
+{
+    if (!Owning_) {
+        return;
+    }
+
+    if (unlink(Path_.c_str()) == -1) {
+        YT_LOG_INFO(TError::FromSystem(), "Failed to unlink pipe %v", Path_);
+    }
+}
+
+// Creates a FIFO at #path with the given #permissions and returns an owning
+// handle that unlinks the path on destruction.
+// Throws if mkfifo fails (see Open).
+TNamedPipePtr TNamedPipe::Create(const TString& path, int permissions)
+{
+    // Run mkfifo through a non-owning handle first. If it throws (e.g. the path
+    // already exists), no owning object has been constructed yet, so no destructor
+    // unlinks a file that this call did not create.
+    New<TNamedPipe>(path, /* owning */ false)->Open(permissions);
+
+    auto pipe = New<TNamedPipe>(path, /* owning */ true);
+    YT_LOG_DEBUG("Named pipe created (Path: %v, Permissions: %v)", path, permissions);
+    return pipe;
+}
+
+// Wraps an existing path into a non-owning handle; the path is not unlinked on destruction.
+TNamedPipePtr TNamedPipe::FromPath(const TString& path)
+{
+    return New<TNamedPipe>(path, /* owning */ false);
+}
+
+// Calls mkfifo on Path_; throws an error carrying the system error on failure.
+void TNamedPipe::Open(int permissions)
+{
+    if (mkfifo(Path_.c_str(), permissions) == -1) {
+        THROW_ERROR_EXCEPTION("Failed to create named pipe %v", Path_)
+            << TError::FromSystem();
+    }
+}
+
+// Creates an input connection for Path_ on the TIODispatcher poller.
+// A strong reference to this pipe is passed along with the connection.
+IConnectionReaderPtr TNamedPipe::CreateAsyncReader()
+{
+    YT_VERIFY(!Path_.empty());
+    return CreateInputConnectionFromPath(Path_, TIODispatcher::Get()->GetPoller(), MakeStrong(this));
+}
+
+// Creates an output connection for Path_ on the TIODispatcher poller.
+// A strong reference to this pipe is passed along with the connection.
+IConnectionWriterPtr TNamedPipe::CreateAsyncWriter()
+{
+    YT_VERIFY(!Path_.empty());
+    return CreateOutputConnectionFromPath(Path_, TIODispatcher::Get()->GetPoller(), MakeStrong(this));
+}
+
+TString TNamedPipe::GetPath() const
+{
+    return Path_;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Builds a config object with the three fields set from the arguments.
+TNamedPipeConfigPtr TNamedPipeConfig::Create(TString path, int fd, bool write)
+{
+    auto result = New<TNamedPipeConfig>();
+    result->Path = std::move(path);
+    result->FD = fd;
+    result->Write = write;
+
+    return result;
+}
+
+// Registers the "path", "fd" (default 0) and "write" (default false) parameters.
+void TNamedPipeConfig::Register(TRegistrar registrar)
+{
+    registrar.Parameter("path", &TThis::Path)
+        .Default();
+
+    registrar.Parameter("fd", &TThis::FD)
+        .Default(0);
+
+    registrar.Parameter("write", &TThis::Write)
+        .Default(false);
+}
+
+DEFINE_REFCOUNTED_TYPE(TNamedPipeConfig)
+
+//////////////////////////////////////////////////////////////////////////////// + +TPipe::TPipe() +{ } + +TPipe::TPipe(TPipe&& pipe) +{ + Init(std::move(pipe)); +} + +TPipe::TPipe(int fd[2]) + : ReadFD_(fd[0]) + , WriteFD_(fd[1]) +{ } + +void TPipe::Init(TPipe&& other) +{ + ReadFD_ = other.ReadFD_; + WriteFD_ = other.WriteFD_; + other.ReadFD_ = InvalidFD; + other.WriteFD_ = InvalidFD; +} + +TPipe::~TPipe() +{ + if (ReadFD_ != InvalidFD) { + YT_VERIFY(TryClose(ReadFD_, false)); + } + + if (WriteFD_ != InvalidFD) { + YT_VERIFY(TryClose(WriteFD_, false)); + } +} + +void TPipe::operator=(TPipe&& other) +{ + if (this == &other) { + return; + } + + Init(std::move(other)); +} + +IConnectionWriterPtr TPipe::CreateAsyncWriter() +{ + YT_VERIFY(WriteFD_ != InvalidFD); + SafeMakeNonblocking(WriteFD_); + return CreateConnectionFromFD(ReleaseWriteFD(), {}, {}, TIODispatcher::Get()->GetPoller()); +} + +IConnectionReaderPtr TPipe::CreateAsyncReader() +{ + YT_VERIFY(ReadFD_ != InvalidFD); + SafeMakeNonblocking(ReadFD_); + return CreateConnectionFromFD(ReleaseReadFD(), {}, {}, TIODispatcher::Get()->GetPoller()); +} + +int TPipe::ReleaseReadFD() +{ + YT_VERIFY(ReadFD_ != InvalidFD); + auto fd = ReadFD_; + ReadFD_ = InvalidFD; + return fd; +} + +int TPipe::ReleaseWriteFD() +{ + YT_VERIFY(WriteFD_ != InvalidFD); + auto fd = WriteFD_; + WriteFD_ = InvalidFD; + return fd; +} + +int TPipe::GetReadFD() const +{ + YT_VERIFY(ReadFD_ != InvalidFD); + return ReadFD_; +} + +int TPipe::GetWriteFD() const +{ + YT_VERIFY(WriteFD_ != InvalidFD); + return WriteFD_; +} + +void TPipe::CloseReadFD() +{ + if (ReadFD_ == InvalidFD) { + return; + } + auto fd = ReadFD_; + ReadFD_ = InvalidFD; + SafeClose(fd, false); +} + +void TPipe::CloseWriteFD() +{ + if (WriteFD_ == InvalidFD) { + return; + } + auto fd = WriteFD_; + WriteFD_ = InvalidFD; + SafeClose(fd, false); +} + +//////////////////////////////////////////////////////////////////////////////// + +TString ToString(const TPipe& pipe) +{ + return 
Format("{ReadFD: %v, WriteFD: %v}", + pipe.GetReadFD(), + pipe.GetWriteFD()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TPipeFactory::TPipeFactory(int minFD) + : MinFD_(minFD) +{ } + +TPipeFactory::~TPipeFactory() +{ + for (int fd : ReservedFDs_) { + YT_VERIFY(TryClose(fd, false)); + } +} + +TPipe TPipeFactory::Create() +{ + while (true) { + int fd[2]; + SafePipe(fd); + if (fd[0] >= MinFD_ && fd[1] >= MinFD_) { + TPipe pipe(fd); + return pipe; + } else { + ReservedFDs_.push_back(fd[0]); + ReservedFDs_.push_back(fd[1]); + } + } +} + +void TPipeFactory::Clear() +{ + for (int& fd : ReservedFDs_) { + YT_VERIFY(TryClose(fd, false)); + fd = TPipe::InvalidFD; + } + ReservedFDs_.clear(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/pipe.h b/yt/yt/library/process/pipe.h new file mode 100644 index 0000000000..10da81cc8a --- /dev/null +++ b/yt/yt/library/process/pipe.h @@ -0,0 +1,114 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/net/public.h> + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NPipes { + +//////////////////////////////////////////////////////////////////////////////// + +class TNamedPipe + : public TRefCounted +{ +public: + ~TNamedPipe(); + static TNamedPipePtr Create(const TString& path, int permissions = 0660); + static TNamedPipePtr FromPath(const TString& path); + + NNet::IConnectionReaderPtr CreateAsyncReader(); + NNet::IConnectionWriterPtr CreateAsyncWriter(); + + TString GetPath() const; + +private: + const TString Path_; + + //! Whether pipe was created by this class + //! and should be removed in destructor. 
+ const bool Owning_; + + explicit TNamedPipe(const TString& path, bool owning); + void Open(int permissions); + DECLARE_NEW_FRIEND() +}; + +DEFINE_REFCOUNTED_TYPE(TNamedPipe) + +//////////////////////////////////////////////////////////////////////////////// + +class TNamedPipeConfig + : public NYTree::TYsonStruct +{ +public: + TString Path; + int FD = 0; + bool Write = false; + + static TNamedPipeConfigPtr Create(TString path, int fd, bool write); + + REGISTER_YSON_STRUCT(TNamedPipeConfig); + + static void Register(TRegistrar registrar); +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TPipe + : public TNonCopyable +{ +public: + static const int InvalidFD = -1; + + TPipe(); + TPipe(TPipe&& pipe); + ~TPipe(); + + void operator=(TPipe&& other); + + void CloseReadFD(); + void CloseWriteFD(); + + NNet::IConnectionReaderPtr CreateAsyncReader(); + NNet::IConnectionWriterPtr CreateAsyncWriter(); + + int ReleaseReadFD(); + int ReleaseWriteFD(); + + int GetReadFD() const; + int GetWriteFD() const; + +private: + int ReadFD_ = InvalidFD; + int WriteFD_ = InvalidFD; + + TPipe(int fd[2]); + void Init(TPipe&& other); + + friend class TPipeFactory; +}; + +TString ToString(const TPipe& pipe); + +//////////////////////////////////////////////////////////////////////////////// + +class TPipeFactory +{ +public: + explicit TPipeFactory(int minFD = 0); + ~TPipeFactory(); + + TPipe Create(); + + void Clear(); + +private: + const int MinFD_; + std::vector<int> ReservedFDs_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/private.h b/yt/yt/library/process/private.h new file mode 100644 index 0000000000..95b2ffb0f5 --- /dev/null +++ b/yt/yt/library/process/private.h @@ -0,0 +1,14 @@ +#pragma once + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NPipes { + 
+////////////////////////////////////////////////////////////////////////////////
+
+// Loggers of the process library, constructed with the names "Pipes" and "Pty".
+inline const NLogging::TLogger PipesLogger("Pipes");
+inline const NLogging::TLogger PtyLogger("Pty");
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NPipes
diff --git a/yt/yt/library/process/process.cpp b/yt/yt/library/process/process.cpp
new file mode 100644
index 0000000000..809a50ed9a
--- /dev/null
+++ b/yt/yt/library/process/process.cpp
@@ -0,0 +1,697 @@
+#include "process.h"
+#include "pipe.h"
+
+#include <yt/yt/core/misc/proc.h>
+
+#include <yt/yt/core/logging/log.h>
+
+#include <yt/yt/core/misc/error.h>
+#include <yt/yt/core/misc/fs.h>
+#include <yt/yt/core/misc/finally.h>
+
+#include <yt/yt/core/concurrency/periodic_executor.h>
+#include <yt/yt/core/concurrency/delayed_executor.h>
+
+#include <library/cpp/yt/system/handle_eintr.h>
+
+#include <util/folder/dirut.h>
+
+#include <util/generic/guid.h>
+
+#include <util/string/ascii.h>
+
+#include <util/string/util.h>
+
+#include <util/system/env.h>
+#include <util/system/execpath.h>
+#include <util/system/maxlen.h>
+#include <util/system/shellcommand.h>
+
+#ifdef _unix_
+    #include <unistd.h>
+    #include <errno.h>
+    #include <sys/wait.h>
+    #include <sys/resource.h>
+#endif
+
+// On Darwin `environ` is defined here in terms of _NSGetEnviron().
+#ifdef _darwin_
+    #include <crt_externs.h>
+    #define environ (*_NSGetEnviron())
+#endif
+
+namespace NYT {
+
+using namespace NPipes;
+using namespace NNet;
+using namespace NConcurrency;
+
+////////////////////////////////////////////////////////////////////////////////
+
+static inline const NLogging::TLogger Logger("Process");
+
+// Value of ProcessId_ while no child process is attached.
+static constexpr pid_t InvalidProcessId = -1;
+
+// Number of execve attempts in the child and the pause between them.
+static constexpr int ExecveRetryCount = 5;
+static constexpr auto ExecveRetryTimeout = TDuration::Seconds(1);
+
+// Retry budget for resolving the binary path in Spawn and the pause between attempts.
+static constexpr int ResolveRetryCount = 5;
+static constexpr auto ResolveRetryTimeout = TDuration::Seconds(1);
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Finds a path at which #binary is readable and executable.
+// Probing order: #binary itself, then <directory of the current executable>/#binary,
+// then every entry of the PATH environment variable. Absolute (and empty) names are
+// only probed as given.
+// Returns the first path that passes, or a CannotResolveBinary error whose inner
+// errors describe every failed probe.
+TErrorOr<TString> ResolveBinaryPath(const TString& binary)
+{
+    // Shadows the file-level logger so that every message below carries the binary tag.
+    auto Logger = NYT::Logger
+        .WithTag("Binary: %v", binary);
+
+    YT_LOG_DEBUG("Resolving binary path");
+
+    std::vector<TError> accumulatedErrors;
+
+    // Checks a single candidate; a failure is recorded in accumulatedErrors.
+    auto test = [&] (const char* path) {
+        YT_LOG_DEBUG("Probing path (Path: %v)", path);
+        if (access(path, R_OK | X_OK) == 0) {
+            return true;
+        } else {
+            auto error = TError("Cannot run %Qlv", path) << TError::FromSystem();
+            accumulatedErrors.push_back(std::move(error));
+            return false;
+        }
+    };
+
+    // Builds the final error, moving all recorded probe failures into it.
+    auto failure = [&] {
+        auto error = TError(
+            EProcessErrorCode::CannotResolveBinary,
+            "Cannot resolve binary %Qlv",
+            binary);
+        error.MutableInnerErrors()->swap(accumulatedErrors);
+        YT_LOG_DEBUG(error, "Error resolving binary path");
+        return error;
+    };
+
+    auto success = [&] (const TString& path) {
+        YT_LOG_DEBUG("Binary resolved (Path: %v)", path);
+        return path;
+    };
+
+    if (test(binary.c_str())) {
+        return success(binary);
+    }
+
+    // If this is an absolute path, stop here.
+    if (binary.empty() || binary[0] == '/') {
+        return failure();
+    }
+
+    // XXX(sandello): Sometimes we drop PATH from environment when spawning isolated processes.
+    // In this case, try to locate somewhere nearby.
+    {
+        auto execPathDirName = GetDirName(GetExecPath());
+        YT_LOG_DEBUG("Looking in our exec path directory (ExecPathDir: %v)", execPathDirName);
+        auto probe = TString::Join(execPathDirName, "/", binary);
+        if (test(probe.c_str())) {
+            return success(probe);
+        }
+    }
+
+    // Scratch buffer for "<PATH item>/<binary>\0".
+    std::array<char, MAX_PATH> buffer;
+
+    auto envPathStr = GetEnv("PATH");
+    TStringBuf envPath(envPathStr);
+    TStringBuf envPathItem;
+
+    YT_LOG_DEBUG("Looking for binary in PATH (Path: %v)", envPathStr);
+
+    while (envPath.NextTok(':', envPathItem)) {
+        // Skip candidates that do not fit into the buffer
+        // (two extra bytes: the '/' separator and the terminating zero).
+        if (buffer.size() < 2 + envPathItem.size() + binary.size()) {
+            continue;
+        }
+
+        size_t index = 0;
+        std::copy(envPathItem.begin(), envPathItem.end(), buffer.begin() + index);
+        index += envPathItem.size();
+        buffer[index] = '/';
+        index += 1;
+        std::copy(binary.begin(), binary.end(), buffer.begin() + index);
+        index += binary.size();
+        buffer[index] = 0;
+
+        if (test(buffer.data())) {
+            return success(TString(buffer.data(), index));
+        }
+    }
+
+    return failure();
+}
+
+// Sends #signal to #pid via kill(2).
+// Returns false if kill fails for any reason other than ESRCH; throws on non-unix platforms.
+bool TryKillProcessByPid(int pid, int signal)
+{
+#ifdef _unix_
+    YT_VERIFY(pid != -1);
+    int result = ::kill(pid, signal);
+    // Ignore ESRCH because process may have died just before TryKillProcessByPid.
+    if (result < 0 && errno != ESRCH) {
+        return false;
+    }
+    return true;
+#else
+    THROW_ERROR_EXCEPTION("Unsupported platform");
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace {
+
+#ifdef _unix_
+
+// Calls waitid (retrying via HandleEintr) and reports whether a child was actually
+// found in a waitable state. #infop may be null, in which case a local siginfo_t is used.
+bool TryWaitid(idtype_t idtype, id_t id, siginfo_t *infop, int options)
+{
+    if (infop != nullptr) {
+        // See comment below.
+        infop->si_pid = 0;
+    }
+
+    siginfo_t info;
+    ::memset(&info, 0, sizeof(info));
+    auto res = HandleEintr(::waitid, idtype, id, infop != nullptr ? infop : &info, options);
+
+    if (res == 0) {
+        // According to man wait.
+        // If WNOHANG was specified in options and there were
+        // no children in a waitable state, then waitid() returns 0 immediately.
+        // To distinguish this case from that where a child
+        // was in a waitable state, zero out the si_pid field
+        // before the call and check for a nonzero value in this field after
+        // the call returns.
+        if (infop && infop->si_pid == 0) {
+            return false;
+        }
+        return true;
+    }
+
+    return false;
+}
+
+// Calls wait4 (retrying via HandleEintr); a failure is logged at fatal level.
+void Wait4OrDie(pid_t id, int* status, int options, rusage* rusage)
+{
+    auto res = HandleEintr(::wait4, id, status, options, rusage);
+    if (res == -1) {
+        YT_LOG_FATAL(TError::FromSystem(), "Wait4 failed");
+    }
+}
+
+// Sends signal 9 to #pid and waits for it to exit; both steps must succeed.
+void Cleanup(int pid)
+{
+    YT_VERIFY(pid > 0);
+
+    YT_VERIFY(TryKillProcessByPid(pid, 9));
+    YT_VERIFY(TryWaitid(P_PID, pid, nullptr, WEXITED));
+}
+
+// Replaces the calling thread's signal mask with #sigmask, storing the previous
+// mask in #oldSigmask if it is not null.
+// Returns false on failure with errno set to the error code.
+bool TrySetSignalMask(const sigset_t* sigmask, sigset_t* oldSigmask)
+{
+    int error = pthread_sigmask(SIG_SETMASK, sigmask, oldSigmask);
+    if (error != 0) {
+        // pthread_sigmask returns the error number instead of setting errno,
+        // while callers report the failure through errno.
+        errno = error;
+        return false;
+    }
+    return true;
+}
+
+// Resets the disposition of every signal in [1, NSIG) to SIG_DFL. Always returns true.
+bool TryResetSignals()
+{
+    for (int sig = 1; sig < NSIG; ++sig) {
+        // Ignore invalid signal errors.
+        ::signal(sig, SIG_DFL);
+    }
+    return true;
+}
+
+#endif
+
+} // namespace
+
+////////////////////////////////////////////////////////////////////////////////
+
+TProcessBase::TProcessBase(const TString& path)
+    : Path_(path)
+    , ProcessId_(InvalidProcessId)
+{ }
+
+// Appends a command-line argument. May only be called before the process is spawned.
+void TProcessBase::AddArgument(TStringBuf arg)
+{
+    YT_VERIFY(ProcessId_ == InvalidProcessId && !Finished_);
+
+    Args_.push_back(Capture(arg));
+}
+
+// Appends an environment entry. May only be called before the process is spawned.
+void TProcessBase::AddEnvVar(TStringBuf var)
+{
+    YT_VERIFY(ProcessId_ == InvalidProcessId && !Finished_);
+
+    Env_.push_back(Capture(var));
+}
+
+// Appends every argument of #args in order.
+void TProcessBase::AddArguments(std::initializer_list<TStringBuf> args)
+{
+    for (auto arg : args) {
+        AddArgument(arg);
+    }
+}
+
+// Appends every argument of #args in order.
+void TProcessBase::AddArguments(const std::vector<TString>& args)
+{
+    for (const auto& arg : args) {
+        AddArgument(arg);
+    }
+}
+
+// Remembers the working directory for the process to be spawned.
+void TProcessBase::SetWorkingDirectory(const TString& path)
+{
+    WorkingDirectory_ = path;
+}
+
+// Requests that the process be spawned with the process group flag set.
+void TProcessBase::CreateProcessGroup()
+{
+    CreateProcessGroup_ = true;
+}
+
+//////////////////////////////////////////////////////////////////////////////// + +TSimpleProcess::TSimpleProcess(const TString& path, bool copyEnv, TDuration pollPeriod) + // TString is guaranteed to be zero-terminated. + // https://wiki.yandex-team.ru/Development/Poisk/arcadia/util/TStringAndTStringBuf#sobstvennosimvoly + : TProcessBase(path) + , PollPeriod_(pollPeriod) + , PipeFactory_(3) +{ + AddArgument(path); + + if (copyEnv) { + for (char** envIt = environ; *envIt; ++envIt) { + Env_.push_back(Capture(*envIt)); + } + } +} + +void TSimpleProcess::AddDup2FileAction(int oldFD, int newFD) +{ + TSpawnAction action{ + std::bind(TryDup2, oldFD, newFD), + Format("Error duplicating %v file descriptor to %v in child process", oldFD, newFD) + }; + + MaxSpawnActionFD_ = std::max(MaxSpawnActionFD_, newFD); + SpawnActions_.push_back(action); +} + +IConnectionReaderPtr TSimpleProcess::GetStdOutReader() +{ + auto& pipe = StdPipes_[STDOUT_FILENO]; + pipe = PipeFactory_.Create(); + AddDup2FileAction(pipe.GetWriteFD(), STDOUT_FILENO); + return pipe.CreateAsyncReader(); +} + +IConnectionReaderPtr TSimpleProcess::GetStdErrReader() +{ + auto& pipe = StdPipes_[STDERR_FILENO]; + pipe = PipeFactory_.Create(); + AddDup2FileAction(pipe.GetWriteFD(), STDERR_FILENO); + return pipe.CreateAsyncReader(); +} + +IConnectionWriterPtr TSimpleProcess::GetStdInWriter() +{ + auto& pipe = StdPipes_[STDIN_FILENO]; + pipe = PipeFactory_.Create(); + AddDup2FileAction(pipe.GetReadFD(), STDIN_FILENO); + return pipe.CreateAsyncWriter(); +} + +TFuture<void> TProcessBase::Spawn() +{ + try { + // Resolve binary path. 
+ std::vector<TError> innerErrors; + for (int retryIndex = ResolveRetryCount; retryIndex >= 0; --retryIndex) { + auto errorOrPath = ResolveBinaryPath(Path_); + if (errorOrPath.IsOK()) { + ResolvedPath_ = errorOrPath.Value(); + break; + } + + innerErrors.push_back(errorOrPath); + + if (retryIndex == 0) { + THROW_ERROR_EXCEPTION("Failed to resolve binary path %v", Path_) + << innerErrors; + } + + TDelayedExecutor::WaitForDuration(ResolveRetryTimeout); + } + + DoSpawn(); + } catch (const std::exception& ex) { + FinishedPromise_.TrySet(ex); + } + return FinishedPromise_; +} + +void TSimpleProcess::DoSpawn() +{ +#ifdef _unix_ + auto finally = Finally([&] () { + StdPipes_[STDIN_FILENO].CloseReadFD(); + StdPipes_[STDOUT_FILENO].CloseWriteFD(); + StdPipes_[STDERR_FILENO].CloseWriteFD(); + PipeFactory_.Clear(); + }); + + YT_VERIFY(ProcessId_ == InvalidProcessId && !Finished_); + + // Make sure no spawn action closes Pipe_.WriteFD + TPipeFactory pipeFactory(MaxSpawnActionFD_ + 1); + Pipe_ = pipeFactory.Create(); + pipeFactory.Clear(); + + YT_LOG_DEBUG("Spawning new process (Path: %v, ErrorPipe: %v, Arguments: %v, Environment: %v)", + ResolvedPath_, + Pipe_, + Args_, + Env_); + + Env_.push_back(nullptr); + Args_.push_back(nullptr); + + // Block all signals around vfork; see http://ewontfix.com/7/ + + // As the child may run in the same address space as the parent until + // the actual execve() system call, any (custom) signal handlers that + // the parent has might alter parent's memory if invoked in the child, + // with undefined results. So we block all signals in the parent before + // vfork(), which will cause them to be blocked in the child as well (we + // rely on the fact that Linux, just like all sane implementations, only + // clones the calling thread). 
Then, in the child, we reset all signals + // to their default dispositions (while still blocked), and unblock them + // (so the exec()ed process inherits the parent's signal mask) + + sigset_t allBlocked; + sigfillset(&allBlocked); + sigset_t oldSignals; + + if (!TrySetSignalMask(&allBlocked, &oldSignals)) { + THROW_ERROR_EXCEPTION("Failed to block all signals") + << TError::FromSystem(); + } + + SpawnActions_.push_back(TSpawnAction{ + TryResetSignals, + "Error resetting signals to default disposition in child process: signal failed" + }); + + SpawnActions_.push_back(TSpawnAction{ + std::bind(TrySetSignalMask, &oldSignals, nullptr), + "Error unblocking signals in child process: pthread_sigmask failed" + }); + + if (!WorkingDirectory_.empty()) { + SpawnActions_.push_back(TSpawnAction{ + [&] () { + NFs::SetCurrentWorkingDirectory(WorkingDirectory_); + return true; + }, + "Error changing working directory" + }); + } + + if (CreateProcessGroup_) { + SpawnActions_.push_back(TSpawnAction{ + [&] () { + setpgrp(); + return true; + }, + "Error creating process group" + }); + } + + SpawnActions_.push_back(TSpawnAction{ + [this] { + for (int retryIndex = 0; retryIndex < ExecveRetryCount; ++retryIndex) { + // Execve may fail, if called binary is being updated, e.g. during yandex-yt package update. + // So we'd better retry several times. + // For example see YT-6352. + TryExecve(ResolvedPath_.c_str(), Args_.data(), Env_.data()); + if (retryIndex < ExecveRetryCount - 1) { + Sleep(ExecveRetryTimeout); + } + } + // If we are still here, return failure. + return false; + }, + "Error starting child process: execve failed" + }); + + SpawnChild(); + + // This should not fail ever. 
+ YT_VERIFY(TrySetSignalMask(&oldSignals, nullptr)); + + Pipe_.CloseWriteFD(); + + ValidateSpawnResult(); + + AsyncWaitExecutor_ = New<TPeriodicExecutor>( + GetSyncInvoker(), + BIND(&TSimpleProcess::AsyncPeriodicTryWait, MakeStrong(this)), + PollPeriod_); + + AsyncWaitExecutor_->Start(); +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +void TSimpleProcess::SpawnChild() +{ + // NB: fork() will cause data corruption when run concurrently with + // Disk IO on O_DIRECT file descriptor. Seems like vfork don't suffer from the same issue. + +#ifdef _unix_ + int pid = vfork(); + + if (pid < 0) { + THROW_ERROR_EXCEPTION("Error starting child process: vfork failed") + << TErrorAttribute("path", ResolvedPath_) + << TError::FromSystem(); + } + + if (pid == 0) { + try { + Child(); + } catch (...) { + YT_ABORT(); + } + } + + ProcessId_ = pid; + Started_ = true; +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +void TSimpleProcess::ValidateSpawnResult() +{ +#ifdef _unix_ + int data[2]; + ssize_t res; + res = HandleEintr(::read, Pipe_.GetReadFD(), &data, sizeof(data)); + Pipe_.CloseReadFD(); + + if (res == 0) { + // Child successfully spawned or was killed by a signal. + // But there is no way to distinguish between these two cases: + // * child killed by signal before exec + // * child killed by signal after exec + // So we treat kill-before-exec the same way as kill-after-exec. 
+ YT_LOG_DEBUG("Child process spawned successfully (Pid: %v)", ProcessId_); + return; + } + + YT_VERIFY(res == sizeof(data)); + Finished_ = true; + + Cleanup(ProcessId_); + ProcessId_ = InvalidProcessId; + + int actionIndex = data[0]; + int errorCode = data[1]; + + YT_VERIFY(0 <= actionIndex && actionIndex < std::ssize(SpawnActions_)); + const auto& action = SpawnActions_[actionIndex]; + THROW_ERROR_EXCEPTION("%v", action.ErrorMessage) + << TError::FromSystem(errorCode); +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +#ifdef _unix_ +void TSimpleProcess::AsyncPeriodicTryWait() +{ + siginfo_t processInfo; + memset(&processInfo, 0, sizeof(siginfo_t)); + + // Note WNOWAIT flag. + // This call just waits for a process to be finished but does not clear zombie flag. + + if (!TryWaitid(P_PID, ProcessId_, &processInfo, WEXITED | WNOWAIT | WNOHANG) || + processInfo.si_pid != ProcessId_) + { + return; + } + + YT_UNUSED_FUTURE(AsyncWaitExecutor_->Stop()); + AsyncWaitExecutor_ = nullptr; + + // This call just should return immediately + // because we have already waited for this process with WNOHANG + rusage rusage; + Wait4OrDie(ProcessId_, nullptr, WNOHANG, &rusage); + + Finished_ = true; + auto error = ProcessInfoToError(processInfo); + YT_LOG_DEBUG("Process finished (Pid: %v, MajFaults: %d, Error: %v)", ProcessId_, rusage.ru_majflt, error); + + FinishedPromise_.Set(error); +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +void TSimpleProcess::Kill(int signal) +{ +#ifdef _unix_ + if (!Started_) { + THROW_ERROR_EXCEPTION("Process is not started yet"); + } + + if (Finished_) { + return; + } + + YT_LOG_DEBUG("Killing child process (Pid: %v)", ProcessId_); + + bool result = false; + if (!CreateProcessGroup_) { + result = TryKillProcessByPid(ProcessId_, signal); + } else { + result = TryKillProcessByPid(-1 * ProcessId_, signal); + } + + if (!result) { + THROW_ERROR_EXCEPTION("Failed to kill child process %v", ProcessId_) + << 
TError::FromSystem(); + } + return; +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +TString TProcessBase::GetPath() const +{ + return Path_; +} + +int TProcessBase::GetProcessId() const +{ + return ProcessId_; +} + +bool TProcessBase::IsStarted() const +{ + return Started_; +} + +bool TProcessBase::IsFinished() const +{ + return Finished_; +} + +TString TProcessBase::GetCommandLine() const +{ + TStringBuilder builder; + builder.AppendString(Path_); + + bool first = true; + for (const auto& arg_ : Args_) { + TStringBuf arg(arg_); + if (first) { + first = false; + } else { + if (arg) { + builder.AppendChar(' '); + bool needQuote = false; + for (size_t i = 0; i < arg.length(); ++i) { + if (!IsAsciiAlnum(arg[i]) && + arg[i] != '-' && arg[i] != '_' && arg[i] != '=' && arg[i] != '/') + { + needQuote = true; + break; + } + } + if (needQuote) { + builder.AppendChar('"'); + TStringBuf left, right; + while (arg.TrySplit('"', left, right)) { + builder.AppendString(left); + builder.AppendString("\\\""); + arg = right; + } + builder.AppendString(arg); + builder.AppendChar('"'); + } else { + builder.AppendString(arg); + } + } + } + } + + return builder.Flush(); +} + +const char* TProcessBase::Capture(TStringBuf arg) +{ + StringHolders_.push_back(TString(arg)); + return StringHolders_.back().c_str(); +} + +void TSimpleProcess::Child() +{ +#ifdef _unix_ + for (int actionIndex = 0; actionIndex < std::ssize(SpawnActions_); ++actionIndex) { + auto& action = SpawnActions_[actionIndex]; + if (!action.Callback()) { + // Report error through the pipe. + int data[] = { + actionIndex, + errno + }; + + // According to pipe(7) write of small buffer is atomic. 
//! Common base for spawnable child processes: collects the executable path,
//! arguments, environment and working directory, and exposes the process
//! state (pid, started/finished flags, command line).
/*!
 *  Concrete subclasses provide the standard stream endpoints, #Kill and #DoSpawn.
 */
class TProcessBase
    : public TRefCounted
{
public:
    explicit TProcessBase(const TString& path);

    void AddArgument(TStringBuf arg);
    void AddEnvVar(TStringBuf var);

    void AddArguments(std::initializer_list<TStringBuf> args);
    void AddArguments(const std::vector<TString>& args);

    void SetWorkingDirectory(const TString& path);
    void CreateProcessGroup();

    // Endpoints of the child's standard streams; implemented by subclasses.
    virtual NNet::IConnectionWriterPtr GetStdInWriter() = 0;
    virtual NNet::IConnectionReaderPtr GetStdOutReader() = 0;
    virtual NNet::IConnectionReaderPtr GetStdErrReader() = 0;

    TFuture<void> Spawn();
    virtual void Kill(int signal) = 0;

    TString GetPath() const;
    int GetProcessId() const;
    bool IsStarted() const;
    bool IsFinished() const;

    //! Returns the path followed by the (quoted where needed) arguments.
    TString GetCommandLine() const;

protected:
    const TString Path_;

    // NOTE(review): no in-class initializer here; presumably set by the constructor — confirm.
    int ProcessId_;
    std::atomic<bool> Started_ = {false};
    std::atomic<bool> Finished_ = {false};
    int MaxSpawnActionFD_ = - 1;
    NPipes::TPipe Pipe_;
    // Container for owning string data. Use std::deque because it never moves contained objects.
    // NOTE(review): <deque> is not among the includes visible in this header —
    // confirm it is pulled in transitively.
    std::deque<std::string> StringHolders_;
    // Raw pointers handed to the child.
    // NOTE(review): presumably these point into #StringHolders_ via #Capture — confirm.
    std::vector<const char*> Args_;
    std::vector<const char*> Env_;
    TString ResolvedPath_;
    TString WorkingDirectory_;
    bool CreateProcessGroup_ = false;
    TPromise<void> FinishedPromise_ = NewPromise<void>();

    virtual void DoSpawn() = 0;
    //! Stores a copy of #arg in #StringHolders_ and returns a pointer to the stored copy.
    const char* Capture(TStringBuf arg);

private:
    // NOTE(review): the same four methods are declared again in TSimpleProcess;
    // confirm whether these base-class declarations are still needed.
    void SpawnChild();
    void ValidateSpawnResult();
    void Child();
    void AsyncPeriodicTryWait();
};
<yt/yt/core/misc/common.h> +#include <yt/yt/core/misc/proc.h> + +#include <yt/yt/core/net/connection.h> + +namespace NYT::NPipes { + +using namespace NNet; + +//////////////////////////////////////////////////////////////////////////////// + +TPty::TPty(int height, int width) +{ + SafeOpenPty(&MasterFD_, &SlaveFD_, height, width); +} + +TPty::~TPty() +{ + if (MasterFD_ != InvalidFD) { + YT_VERIFY(TryClose(MasterFD_, false)); + } + + if (SlaveFD_ != InvalidFD) { + YT_VERIFY(TryClose(SlaveFD_, false)); + } +} + +IConnectionWriterPtr TPty::CreateMasterAsyncWriter() +{ + YT_VERIFY(MasterFD_ != InvalidFD); + int fd = SafeDup(MasterFD_); + SafeSetCloexec(fd); + SafeMakeNonblocking(fd); + return CreateConnectionFromFD(fd, {}, {}, TIODispatcher::Get()->GetPoller()); +} + +IConnectionReaderPtr TPty::CreateMasterAsyncReader() +{ + YT_VERIFY(MasterFD_ != InvalidFD); + int fd = SafeDup(MasterFD_); + SafeSetCloexec(fd); + SafeMakeNonblocking(fd); + return CreateConnectionFromFD(fd, {}, {}, TIODispatcher::Get()->GetPoller()); +} + +int TPty::GetMasterFD() const +{ + YT_VERIFY(MasterFD_ != InvalidFD); + return MasterFD_; +} + +int TPty::GetSlaveFD() const +{ + YT_VERIFY(SlaveFD_ != InvalidFD); + return SlaveFD_; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/pty.h b/yt/yt/library/process/pty.h new file mode 100644 index 0000000000..b585782d12 --- /dev/null +++ b/yt/yt/library/process/pty.h @@ -0,0 +1,33 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/net/public.h> + +namespace NYT::NPipes { + +//////////////////////////////////////////////////////////////////////////////// + +class TPty + : public TNonCopyable +{ +public: + static const int InvalidFD = -1; + + TPty(int height, int width); + ~TPty(); + + NNet::IConnectionReaderPtr CreateMasterAsyncReader(); + NNet::IConnectionWriterPtr CreateMasterAsyncWriter(); + + int GetMasterFD() const; + int GetSlaveFD() 
const; + +private: + int MasterFD_ = InvalidFD; + int SlaveFD_ = InvalidFD; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/public.h b/yt/yt/library/process/public.h new file mode 100644 index 0000000000..0fa1d3d0a9 --- /dev/null +++ b/yt/yt/library/process/public.h @@ -0,0 +1,14 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +namespace NYT::NPipes { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_CLASS(TNamedPipe) +DECLARE_REFCOUNTED_CLASS(TNamedPipeConfig) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NPipes diff --git a/yt/yt/library/process/subprocess.cpp b/yt/yt/library/process/subprocess.cpp new file mode 100644 index 0000000000..02555b0c9b --- /dev/null +++ b/yt/yt/library/process/subprocess.cpp @@ -0,0 +1,153 @@ +#include "subprocess.h" + +#include <yt/yt/core/misc/blob.h> +#include <yt/yt/core/misc/proc.h> +#include <yt/yt/core/misc/finally.h> + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/net/connection.h> + +#include <util/system/execpath.h> + +#include <array> + +namespace NYT { + +using namespace NConcurrency; +using namespace NPipes; + +//////////////////////////////////////////////////////////////////////////////// + +const static size_t PipeBlockSize = 64 * 1024; +static NLogging::TLogger Logger("Subprocess"); + +//////////////////////////////////////////////////////////////////////////////// + +TSubprocess::TSubprocess(const TString& path, bool copyEnv) + : Process_(New<TSimpleProcess>(path, copyEnv)) +{ } + +TSubprocess TSubprocess::CreateCurrentProcessSpawner() +{ + return TSubprocess(GetExecPath()); +} + +void TSubprocess::AddArgument(TStringBuf arg) +{ + Process_->AddArgument(arg); +} + +void TSubprocess::AddArguments(std::initializer_list<TStringBuf> args) +{ + Process_->AddArguments(args); +} + 
+TSubprocessResult TSubprocess::Execute(const TSharedRef& input) +{ +#ifdef _unix_ + auto inputStream = Process_->GetStdInWriter(); + auto outputStream = Process_->GetStdOutReader(); + auto errorStream = Process_->GetStdErrReader(); + auto finished = Process_->Spawn(); + + auto readIntoBlob = [] (IAsyncInputStreamPtr stream) { + TBlob output; + auto buffer = TSharedMutableRef::Allocate(PipeBlockSize, {.InitializeStorage = false}); + while (true) { + auto size = WaitFor(stream->Read(buffer)) + .ValueOrThrow(); + + if (size == 0) + break; + + // ToDo(psushin): eliminate copying. + output.Append(buffer.Begin(), size); + } + return TSharedRef::FromBlob(std::move(output)); + }; + + auto writeStdin = BIND([=] { + if (input.Size() > 0) { + WaitFor(inputStream->Write(input)) + .ThrowOnError(); + } + + WaitFor(inputStream->Close()) + .ThrowOnError(); + + //! Return dummy ref, so later we cat put Future into vector + //! along with stdout and stderr. + return TSharedRef::MakeEmpty(); + }); + + std::vector<TFuture<TSharedRef>> futures = { + BIND(readIntoBlob, outputStream).AsyncVia(GetCurrentInvoker()).Run(), + BIND(readIntoBlob, errorStream).AsyncVia(GetCurrentInvoker()).Run(), + writeStdin.AsyncVia(GetCurrentInvoker()).Run(), + }; + + try { + auto outputsOrError = WaitFor(AllSucceeded(futures)); + THROW_ERROR_EXCEPTION_IF_FAILED( + outputsOrError, + "IO error occurred during subprocess call"); + + const auto& outputs = outputsOrError.Value(); + YT_VERIFY(outputs.size() == 3); + + // This can block indefinitely. + auto exitCode = WaitFor(finished); + return TSubprocessResult{outputs[0], outputs[1], exitCode}; + } catch (...) { + try { + Process_->Kill(SIGKILL); + } catch (...) 
{ } + Y_UNUSED(WaitFor(finished)); + throw; + } +#else + THROW_ERROR_EXCEPTION("Unsupported platform"); +#endif +} + +void TSubprocess::Kill(int signal) +{ + Process_->Kill(signal); +} + +TString TSubprocess::GetCommandLine() const +{ + return Process_->GetCommandLine(); +} + +TProcessBasePtr TSubprocess::GetProcess() const +{ + return Process_; +} + +//////////////////////////////////////////////////////////////////////////////// + +void RunSubprocess(const std::vector<TString>& cmd) +{ + if (cmd.empty()) { + THROW_ERROR_EXCEPTION("Command can't be empty"); + } + + auto process = TSubprocess(cmd[0]); + for (int index = 1; index < std::ssize(cmd); ++index) { + process.AddArgument(cmd[index]); + } + + auto result = process.Execute(); + if (!result.Status.IsOK()) { + THROW_ERROR_EXCEPTION("Failed to run %v", cmd[0]) + << result.Status + << TErrorAttribute("command_line", process.GetCommandLine()) + << TErrorAttribute("error", TString(result.Error.Begin(), result.Error.End())); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/process/subprocess.h b/yt/yt/library/process/subprocess.h new file mode 100644 index 0000000000..223db533f6 --- /dev/null +++ b/yt/yt/library/process/subprocess.h @@ -0,0 +1,48 @@ +#pragma once + +#include "public.h" +#include "process.h" + +#include <library/cpp/yt/memory/ref.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +struct TSubprocessResult +{ + TSharedRef Output; + TSharedRef Error; + TError Status; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TSubprocess +{ +public: + explicit TSubprocess(const TString& path, bool copyEnv = true); + + static TSubprocess CreateCurrentProcessSpawner(); + + void AddArgument(TStringBuf arg); + void AddArguments(std::initializer_list<TStringBuf> args); + + TSubprocessResult Execute(const TSharedRef& input = 
TSharedRef::MakeEmpty()); + void Kill(int signal); + + TString GetCommandLine() const; + + TProcessBasePtr GetProcess() const; + +private: + const TProcessBasePtr Process_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +void RunSubprocess(const std::vector<TString>& cmd); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/process/ya.make b/yt/yt/library/process/ya.make new file mode 100644 index 0000000000..79763c7267 --- /dev/null +++ b/yt/yt/library/process/ya.make @@ -0,0 +1,22 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + io_dispatcher.cpp + pipe.cpp + process.cpp + pty.cpp + subprocess.cpp +) + +PEERDIR( + yt/yt/core + contrib/libs/re2 +) + +END() + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/program/build_attributes.cpp b/yt/yt/library/program/build_attributes.cpp new file mode 100644 index 0000000000..38caf57997 --- /dev/null +++ b/yt/yt/library/program/build_attributes.cpp @@ -0,0 +1,107 @@ +#include "build_attributes.h" + +#include <yt/yt/build/build.h> + +#include <yt/yt/core/ytree/fluent.h> +#include <yt/yt/core/ytree/ypath_client.h> + +#include <yt/yt/core/misc/error_code.h> + +namespace NYT { + +using namespace NYTree; +using namespace NYson; + +static const NLogging::TLogger Logger("Build"); + +//////////////////////////////////////////////////////////////////////////////// + +void TBuildInfo::Register(TRegistrar registrar) +{ + registrar.Parameter("name", &TThis::Name) + .Default(); + + registrar.Parameter("version", &TThis::Version) + .Default(GetVersion()); + + registrar.Parameter("build_host", &TThis::BuildHost) + .Default(GetBuildHost()); + + registrar.Parameter("build_time", &TThis::BuildTime) + .Default(ParseBuildTime()); + + registrar.Parameter("start_time", &TThis::StartTime) + .Default(TInstant::Now()); +} + +std::optional<TInstant> TBuildInfo::ParseBuildTime() +{ + TString 
rawBuildTime(GetBuildTime()); + + // Build time may be empty if code is building + // without -DBUILD_DATE (for example, in opensource build). + if (rawBuildTime.empty()) { + return std::nullopt; + } + + try { + return TInstant::ParseIso8601(rawBuildTime); + } catch (const std::exception& ex) { + YT_LOG_ERROR(ex, "Error parsing build time"); + return std::nullopt; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +TBuildInfoPtr BuildBuildAttributes(const char* serviceName) +{ + auto info = New<TBuildInfo>(); + if (serviceName) { + info->Name = serviceName; + } + return info; +} + +void SetBuildAttributes(IYPathServicePtr orchidRoot, const char* serviceName) +{ + SyncYPathSet( + orchidRoot, + "/service", + BuildYsonStringFluently() + .BeginAttributes() + .Item("opaque").Value(true) + .EndAttributes() + .Value(BuildBuildAttributes(serviceName))); + SyncYPathSet( + orchidRoot, + "/error_codes", + BuildYsonStringFluently() + .BeginAttributes() + .Item("opaque").Value(true) + .EndAttributes() + .DoMapFor(TErrorCodeRegistry::Get()->GetAllErrorCodes(), [] (TFluentMap fluent, const auto& pair) { + fluent + .Item(ToString(pair.first)).BeginMap() + .Item("cpp_literal").Value(ToString(pair.second)) + .EndMap(); + })); + SyncYPathSet( + orchidRoot, + "/error_code_ranges", + BuildYsonStringFluently() + .BeginAttributes() + .Item("opaque").Value(true) + .EndAttributes() + .DoMapFor(TErrorCodeRegistry::Get()->GetAllErrorCodeRanges(), [] (TFluentMap fluent, const TErrorCodeRegistry::TErrorCodeRangeInfo& range) { + fluent + .Item(ToString(range)).BeginMap() + .Item("cpp_enum").Value(range.Namespace) + .EndMap(); + })); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT + diff --git a/yt/yt/library/program/build_attributes.h b/yt/yt/library/program/build_attributes.h new file mode 100644 index 0000000000..e02f86b351 --- /dev/null +++ b/yt/yt/library/program/build_attributes.h @@ -0,0 
//! Build and startup information about the running binary, exposed as a YSON struct.
class TBuildInfo
    : public NYTree::TYsonStruct
{
public:
    //! Service name ("name"); left unset unless supplied by the caller.
    std::optional<TString> Name;
    //! "version"; defaults to GetVersion().
    TString Version;
    //! "build_host"; defaults to GetBuildHost().
    TString BuildHost;
    //! "build_time"; defaults to the result of #ParseBuildTime, which may be null.
    std::optional<TInstant> BuildTime;
    //! "start_time"; defaults to TInstant::Now() as evaluated in #Register.
    TInstant StartTime;

    REGISTER_YSON_STRUCT(TBuildInfo);

    static void Register(TRegistrar registrar);

private:
    //! Parses GetBuildTime(); returns null if it is empty or malformed.
    static std::optional<TInstant> ParseBuildTime();
};
//! Registers the tcmalloc tuning parameters together with their defaults.
void TTCMallocConfig::Register(TRegistrar registrar)
{
    registrar.Parameter("background_release_rate", &TThis::BackgroundReleaseRate)
        .Default(32_MB);
    registrar.Parameter("max_per_cpu_cache_size", &TThis::MaxPerCpuCacheSize)
        .Default(3_MB);

    // Absolute threshold (bytes) and an optional ratio-based alternative.
    registrar.Parameter("aggressive_release_threshold", &TThis::AggressiveReleaseThreshold)
        .Default(20_GB);
    registrar.Parameter("aggressive_release_threshold_ratio", &TThis::AggressiveReleaseThresholdRatio)
        .Optional();

    registrar.Parameter("aggressive_release_size", &TThis::AggressiveReleaseSize)
        .Default(128_MB);
    registrar.Parameter("aggressive_release_period", &TThis::AggressiveReleasePeriod)
        .Default(TDuration::MilliSeconds(100));
    // NOTE(review): the field is declared as an optional, yet a non-null default is set here —
    // confirm that an explicit null is the intended way to switch sampling off.
    registrar.Parameter("guarded_sampling_rate", &TThis::GuardedSamplingRate)
        .Default(128_MB);

    // A default-constructed limit is created when the section is missing.
    registrar.Parameter("heap_size_limit", &TThis::HeapSizeLimit)
        .DefaultNew();
}
//! Registers the dynamically reconfigurable subset of the singleton settings.
/*!
 *  Parameters marked optional stay null unless specified; the remaining
 *  sections are default-constructed when missing.
 */
void TSingletonsDynamicConfig::Register(TRegistrar registrar)
{
    // NOTE(review): the key here is "spin_lock_..." while the static config registers the same
    // field as "spin_wait_slow_path_logging_threshold" — confirm whether the mismatch is
    // intentional (e.g. kept for compatibility) before renaming anything.
    registrar.Parameter("spin_lock_slow_path_logging_threshold", &TThis::SpinWaitSlowPathLoggingThreshold)
        .Optional();
    registrar.Parameter("yt_alloc", &TThis::YTAlloc)
        .Optional();
    registrar.Parameter("tcp_dispatcher", &TThis::TcpDispatcher)
        .DefaultNew();
    registrar.Parameter("rpc_dispatcher", &TThis::RpcDispatcher)
        .DefaultNew();
    registrar.Parameter("logging", &TThis::Logging)
        .DefaultNew();
    registrar.Parameter("jaeger", &TThis::Jaeger)
        .DefaultNew();
    registrar.Parameter("rpc", &TThis::Rpc)
        .DefaultNew();
    registrar.Parameter("tcmalloc", &TThis::TCMalloc)
        .Optional();
}
//! Heap size limit for tcmalloc, expressed relative to the container memory.
class THeapSizeLimit
    : public virtual NYTree::TYsonStruct
{
public:
    //! Limit program memory in terms of container memory.
    //! If program heap size exceeds the limit tcmalloc is instructed to release memory to the kernel.
    //! Registered as the optional "container_memory_ratio" parameter.
    std::optional<double> ContainerMemoryRatio;

    //! If true tcmalloc crashes when system allocates more memory than #ContainerMemoryRatio.
    //! Registered as "is_hard"; defaults to false.
    bool IsHard;

    REGISTER_YSON_STRUCT(THeapSizeLimit);

    static void Register(TRegistrar registrar);
};
NLogging::TLogManagerConfigPtr Logging; + NTracing::TJaegerTracerConfigPtr Jaeger; + TRpcConfigPtr Rpc; + TTCMallocConfigPtr TCMalloc; + TStockpileConfigPtr Stockpile; + bool EnableRefCountedTrackerProfiling; + bool EnableResourceTracker; + bool EnablePortoResourceTracker; + std::optional<double> ResourceTrackerVCpuFactor; + NContainers::TPodSpecConfigPtr PodSpec; + + REGISTER_YSON_STRUCT(TSingletonsConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TSingletonsConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TSingletonsDynamicConfig + : public virtual NYTree::TYsonStruct +{ +public: + std::optional<TDuration> SpinWaitSlowPathLoggingThreshold; + NYTAlloc::TYTAllocConfigPtr YTAlloc; + NBus::TTcpDispatcherDynamicConfigPtr TcpDispatcher; + NRpc::TDispatcherDynamicConfigPtr RpcDispatcher; + NLogging::TLogManagerDynamicConfigPtr Logging; + NTracing::TJaegerTracerDynamicConfigPtr Jaeger; + TRpcConfigPtr Rpc; + TTCMallocConfigPtr TCMalloc; + + REGISTER_YSON_STRUCT(TSingletonsDynamicConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TSingletonsDynamicConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TDiagnosticDumpConfig + : public virtual NYTree::TYsonStruct +{ +public: + std::optional<TDuration> YTAllocDumpPeriod; + std::optional<TDuration> RefCountedTrackerDumpPeriod; + + REGISTER_YSON_STRUCT(TDiagnosticDumpConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TDiagnosticDumpConfig) + +//////////////////////////////////////////////////////////////////////////////// + +// NB: These functions should not be called from bootstrap +// config validator since logger is not set up yet. 
+void WarnForUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonStructPtr& config); + +void WarnForUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonSerializablePtr& config); + +void AbortOnUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonStructPtr& config); + +void AbortOnUnrecognizedOptions( + const NLogging::TLogger& logger, + const NYTree::TYsonSerializablePtr& config); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/helpers.cpp b/yt/yt/library/program/helpers.cpp new file mode 100644 index 0000000000..5c7ff29db1 --- /dev/null +++ b/yt/yt/library/program/helpers.cpp @@ -0,0 +1,335 @@ +#include "helpers.h" +#include "config.h" +#include "private.h" + +#include <yt/yt/core/ytalloc/bindings.h> + +#include <yt/yt/core/misc/lazy_ptr.h> +#include <yt/yt/core/misc/ref_counted_tracker.h> +#include <yt/yt/core/misc/ref_counted_tracker_profiler.h> + +#include <yt/yt/core/bus/tcp/dispatcher.h> + +#include <yt/yt/library/tracing/jaeger/tracer.h> + +#include <yt/yt/library/profiling/perf/counters.h> + +#include <yt/yt/library/profiling/resource_tracker/resource_tracker.h> + +#include <yt/yt/library/containers/config.h> +#include <yt/yt/library/containers/porto_resource_tracker.h> + +#include <yt/yt/core/logging/log_manager.h> + +#include <yt/yt/core/concurrency/execution_stack.h> +#include <yt/yt/core/concurrency/periodic_executor.h> +#include <yt/yt/core/concurrency/private.h> + +#include <tcmalloc/malloc_extension.h> + +#include <yt/yt/core/net/address.h> +#include <yt/yt/core/net/local_address.h> + +#include <yt/yt/core/rpc/dispatcher.h> +#include <yt/yt/core/rpc/grpc/dispatcher.h> + +#include <yt/yt/core/service_discovery/yp/service_discovery.h> + +#include <yt/yt/core/threading/spin_wait_slow_path_logger.h> + +#include <library/cpp/yt/threading/spin_wait_hook.h> + +#include 
<library/cpp/yt/memory/atomic_intrusive_ptr.h> + +#include <util/string/split.h> +#include <util/system/thread.h> + +#include <mutex> +#include <thread> + +namespace NYT { + +using namespace NConcurrency; +using namespace NThreading; + +//////////////////////////////////////////////////////////////////////////////// + +static std::once_flag InitAggressiveReleaseThread; +static auto& Logger = ProgramLogger; + +//////////////////////////////////////////////////////////////////////////////// + +class TCMallocLimitsAdjuster +{ +public: + void Adjust(const TTCMallocConfigPtr& config) + { + i64 totalMemory = GetContainerMemoryLimit(); + AdjustPageHeapLimit(totalMemory, config); + AdjustAggressiveReleaseThreshold(totalMemory, config); + } + + i64 GetAggressiveReleaseThreshold() + { + return AggressiveReleaseThreshold_; + } + +private: + using TAllocatorMemoryLimit = tcmalloc::MallocExtension::MemoryLimit; + + TAllocatorMemoryLimit AppliedLimit_; + i64 AggressiveReleaseThreshold_ = 0; + + + void AdjustPageHeapLimit(i64 totalMemory, const TTCMallocConfigPtr& config) + { + auto proposed = ProposeHeapMemoryLimit(totalMemory, config); + + if (proposed.limit == AppliedLimit_.limit && proposed.hard == AppliedLimit_.hard) { + // Already applied + return; + } + + YT_LOG_INFO("Changing tcmalloc memory limit (Limit: %v, IsHard: %v)", + proposed.limit, + proposed.hard); + + tcmalloc::MallocExtension::SetMemoryLimit(proposed); + AppliedLimit_ = proposed; + } + + void AdjustAggressiveReleaseThreshold(i64 totalMemory, const TTCMallocConfigPtr& config) + { + if (totalMemory && config->AggressiveReleaseThresholdRatio) { + AggressiveReleaseThreshold_ = *config->AggressiveReleaseThresholdRatio * totalMemory; + } else { + AggressiveReleaseThreshold_ = config->AggressiveReleaseThreshold; + } + } + + i64 GetContainerMemoryLimit() const + { + auto resourceTracker = NProfiling::GetResourceTracker(); + if (!resourceTracker) { + return 0; + } + + return resourceTracker->GetTotalMemoryLimit(); + } 
+ + TAllocatorMemoryLimit ProposeHeapMemoryLimit(i64 totalMemory, const TTCMallocConfigPtr& config) const + { + const auto& heapLimitConfig = config->HeapSizeLimit; + + if (totalMemory == 0 || !heapLimitConfig->ContainerMemoryRatio) { + return {}; + } + + TAllocatorMemoryLimit proposed; + proposed.limit = *heapLimitConfig->ContainerMemoryRatio * totalMemory; + proposed.hard = heapLimitConfig->IsHard; + + return proposed; + } +}; + +void ConfigureTCMalloc(const TTCMallocConfigPtr& config) +{ + tcmalloc::MallocExtension::SetBackgroundReleaseRate( + tcmalloc::MallocExtension::BytesPerSecond{static_cast<size_t>(config->BackgroundReleaseRate)}); + + tcmalloc::MallocExtension::SetMaxPerCpuCacheSize(config->MaxPerCpuCacheSize); + + if (config->GuardedSamplingRate) { + tcmalloc::MallocExtension::SetGuardedSamplingRate(*config->GuardedSamplingRate); + tcmalloc::MallocExtension::ActivateGuardedSampling(); + } + + struct TConfigSingleton + { + TAtomicIntrusivePtr<TTCMallocConfig> Config; + }; + + LeakySingleton<TConfigSingleton>()->Config.Store(config); + + if (tcmalloc::MallocExtension::NeedsProcessBackgroundActions()) { + std::call_once(InitAggressiveReleaseThread, [] { + std::thread([] { + ::TThread::SetCurrentThreadName("TCAllocYT"); + + TCMallocLimitsAdjuster limitsAdjuster; + + while (true) { + auto config = LeakySingleton<TConfigSingleton>()->Config.Acquire(); + limitsAdjuster.Adjust(config); + + auto freeBytes = tcmalloc::MallocExtension::GetNumericProperty("tcmalloc.page_heap_free"); + YT_VERIFY(freeBytes); + + if (static_cast<i64>(*freeBytes) > limitsAdjuster.GetAggressiveReleaseThreshold()) { + + YT_LOG_DEBUG("Aggressively releasing memory (FreeBytes: %v, Threshold: %v)", + static_cast<i64>(*freeBytes), + limitsAdjuster.GetAggressiveReleaseThreshold()); + + tcmalloc::MallocExtension::ReleaseMemoryToSystem(config->AggressiveReleaseSize); + } + + Sleep(config->AggressiveReleasePeriod); + } + }).detach(); + }); + } +} + +template <class TConfig> +void 
ConfigureSingletonsImpl(const TConfig& config) +{ + SetSpinWaitSlowPathLoggingThreshold(config->SpinWaitSlowPathLoggingThreshold); + + if (!NYTAlloc::ConfigureFromEnv()) { + NYTAlloc::Configure(config->YTAlloc); + } + + for (const auto& [kind, size] : config->FiberStackPoolSizes) { + NConcurrency::SetFiberStackPoolSize(ParseEnum<NConcurrency::EExecutionStackKind>(kind), size); + } + + NLogging::TLogManager::Get()->EnableReopenOnSighup(); + if (!NLogging::TLogManager::Get()->IsConfiguredFromEnv()) { + NLogging::TLogManager::Get()->Configure(config->Logging); + } + + NNet::TAddressResolver::Get()->Configure(config->AddressResolver); + // By default, server components must have a reasonable FQDN. + // Failure to do so may result in issues like YT-4561. + NNet::TAddressResolver::Get()->EnsureLocalHostName(); + + NBus::TTcpDispatcher::Get()->Configure(config->TcpDispatcher); + + NRpc::TDispatcher::Get()->Configure(config->RpcDispatcher); + + NRpc::NGrpc::TDispatcher::Get()->Configure(config->GrpcDispatcher); + + NRpc::TDispatcher::Get()->SetServiceDiscovery( + NServiceDiscovery::NYP::CreateServiceDiscovery(config->YPServiceDiscovery)); + + NTracing::SetGlobalTracer(New<NTracing::TJaegerTracer>(config->Jaeger)); + + NProfiling::EnablePerfCounters(); + + if (auto tracingConfig = config->Rpc->Tracing) { + NTracing::SetTracingConfig(tracingConfig); + } + + ConfigureTCMalloc(config->TCMalloc); + + ConfigureStockpile(*config->Stockpile); + + if (config->EnableRefCountedTrackerProfiling) { + EnableRefCountedTrackerProfiling(); + } + + if (config->EnableResourceTracker) { + NProfiling::EnableResourceTracker(); + if (config->ResourceTrackerVCpuFactor.has_value()) { + NProfiling::SetVCpuFactor(config->ResourceTrackerVCpuFactor.value()); + } + } + + if (config->EnablePortoResourceTracker) { + NContainers::EnablePortoResourceTracker(config->PodSpec); + } +} + +void ConfigureSingletons(const TSingletonsConfigPtr& config) +{ + ConfigureSingletonsImpl(config); +} + +template <class 
TStaticConfig, class TDynamicConfig> +void ReconfigureSingletonsImpl(const TStaticConfig& config, const TDynamicConfig& dynamicConfig) +{ + SetSpinWaitSlowPathLoggingThreshold(dynamicConfig->SpinWaitSlowPathLoggingThreshold.value_or(config->SpinWaitSlowPathLoggingThreshold)); + + if (!NYTAlloc::IsConfiguredFromEnv()) { + NYTAlloc::Configure(dynamicConfig->YTAlloc ? dynamicConfig->YTAlloc : config->YTAlloc); + } + + if (!NLogging::TLogManager::Get()->IsConfiguredFromEnv()) { + NLogging::TLogManager::Get()->Configure( + config->Logging->ApplyDynamic(dynamicConfig->Logging), + /*sync*/ false); + } + + auto tracer = NTracing::GetGlobalTracer(); + if (auto jaeger = DynamicPointerCast<NTracing::TJaegerTracer>(tracer); jaeger) { + jaeger->Configure(config->Jaeger->ApplyDynamic(dynamicConfig->Jaeger)); + } + + NBus::TTcpDispatcher::Get()->Configure(config->TcpDispatcher->ApplyDynamic(dynamicConfig->TcpDispatcher)); + + NRpc::TDispatcher::Get()->Configure(config->RpcDispatcher->ApplyDynamic(dynamicConfig->RpcDispatcher)); + + if (dynamicConfig->Rpc->Tracing) { + NTracing::SetTracingConfig(dynamicConfig->Rpc->Tracing); + } else if (config->Rpc->Tracing) { + NTracing::SetTracingConfig(config->Rpc->Tracing); + } + + if (dynamicConfig->TCMalloc) { + ConfigureTCMalloc(dynamicConfig->TCMalloc); + } else if (config->TCMalloc) { + ConfigureTCMalloc(config->TCMalloc); + } +} + +void ReconfigureSingletons(const TSingletonsConfigPtr& config, const TSingletonsDynamicConfigPtr& dynamicConfig) +{ + ReconfigureSingletonsImpl(config, dynamicConfig); +} + +template <class TConfig> +void StartDiagnosticDumpImpl(const TConfig& config) +{ + static NLogging::TLogger Logger("DiagDump"); + + auto logDumpString = [&] (TStringBuf banner, const TString& str) { + for (const auto& line : StringSplitter(str).Split('\n')) { + YT_LOG_DEBUG("%v %v", banner, line.Token()); + } + }; + + if (config->YTAllocDumpPeriod) { + static const TLazyIntrusivePtr<TPeriodicExecutor> Executor(BIND([&] { + return 
New<TPeriodicExecutor>( + NRpc::TDispatcher::Get()->GetHeavyInvoker(), + BIND([&] { + logDumpString("YTAlloc", NYTAlloc::FormatAllocationCounters()); + })); + })); + Executor->SetPeriod(config->YTAllocDumpPeriod); + Executor->Start(); + } + + if (config->RefCountedTrackerDumpPeriod) { + static const TLazyIntrusivePtr<TPeriodicExecutor> Executor(BIND([&] { + return New<TPeriodicExecutor>( + NRpc::TDispatcher::Get()->GetHeavyInvoker(), + BIND([&] { + logDumpString("RCT", TRefCountedTracker::Get()->GetDebugInfo()); + })); + })); + Executor->SetPeriod(config->RefCountedTrackerDumpPeriod); + Executor->Start(); + } +} + +void StartDiagnosticDump(const TDiagnosticDumpConfigPtr& config) +{ + StartDiagnosticDumpImpl(config); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/helpers.h b/yt/yt/library/program/helpers.h new file mode 100644 index 0000000000..be09ec889c --- /dev/null +++ b/yt/yt/library/program/helpers.h @@ -0,0 +1,18 @@ +#pragma once + +#include "public.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +void ConfigureSingletons(const TSingletonsConfigPtr& config); +void ReconfigureSingletons( + const TSingletonsConfigPtr& config, + const TSingletonsDynamicConfigPtr& dynamicConfig); + +void StartDiagnosticDump(const TDiagnosticDumpConfigPtr& config); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/private.h b/yt/yt/library/program/private.h new file mode 100644 index 0000000000..e328f30667 --- /dev/null +++ b/yt/yt/library/program/private.h @@ -0,0 +1,15 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/logging/log.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger ProgramLogger("Program"); + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program.cpp b/yt/yt/library/program/program.cpp new file mode 100644 index 0000000000..621f3d65b6 --- /dev/null +++ b/yt/yt/library/program/program.cpp @@ -0,0 +1,383 @@ +#include "program.h" + +#include "build_attributes.h" + +#include <yt/yt/build/build.h> + +#include <yt/yt/core/misc/crash_handler.h> +#include <yt/yt/core/misc/signal_registry.h> +#include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/shutdown.h> + +#include <yt/yt/core/ytalloc/bindings.h> + +#include <yt/yt/core/yson/writer.h> +#include <yt/yt/core/yson/null_consumer.h> + +#include <yt/yt/core/logging/log_manager.h> + +#include <yt/yt/library/ytprof/heap_profiler.h> + +#include <yt/yt/library/profiling/tcmalloc/profiler.h> + +#include <library/cpp/ytalloc/api/ytalloc.h> + +#include <library/cpp/yt/mlock/mlock.h> +#include <library/cpp/yt/stockpile/stockpile.h> + +#include <tcmalloc/malloc_extension.h> + +#include <absl/debugging/stacktrace.h> + +#include <util/system/thread.h> +#include <util/system/sigset.h> + +#include <util/string/subst.h> + +#include <thread> + +#include <stdlib.h> + +#ifdef _unix_ +#include <unistd.h> +#include <sys/types.h> +#include <sys/stat.h> +#endif + +#ifdef _linux_ +#include <grp.h> +#include <sys/prctl.h> +#endif + +#if defined(_linux_) && defined(CLANG_COVERAGE) +extern "C" int __llvm_profile_write_file(void); +extern "C" void __llvm_profile_set_filename(const char* name); +#endif + +namespace NYT { + +using namespace NYson; + +//////////////////////////////////////////////////////////////////////////////// + +class TProgram::TOptsParseResult + : public NLastGetopt::TOptsParseResult +{ +public: + TOptsParseResult(TProgram* owner, int argc, const char** argv) + : Owner_(owner) + { + Init(&Owner_->Opts_, argc, argv); + } + + void HandleError() const override + { + Owner_->OnError(CurrentExceptionMessage()); + Cerr << Endl << 
"Try running '" << Owner_->Argv0_ << " --help' for more information." << Endl; + Owner_->Exit(EProgramExitCode::OptionsError); + } + +private: + TProgram* const Owner_; +}; + +TProgram::TProgram() +{ + Opts_.AddHelpOption(); + Opts_.AddLongOption("yt-version", "print YT version and exit") + .NoArgument() + .StoreValue(&PrintYTVersion_, true); + Opts_.AddLongOption("version", "print version and exit") + .NoArgument() + .StoreValue(&PrintVersion_, true); + Opts_.AddLongOption("yson", "print build information in YSON") + .NoArgument() + .StoreValue(&UseYson_, true); + Opts_.AddLongOption("build", "print build information and exit") + .NoArgument() + .StoreValue(&PrintBuild_, true); + Opts_.SetFreeArgsNum(0); + + ConfigureCoverageOutput(); +} + +void TProgram::SetCrashOnError() +{ + CrashOnError_ = true; +} + +TProgram::~TProgram() = default; + +void TProgram::HandleVersionAndBuild() +{ + if (PrintVersion_) { + PrintVersionAndExit(); + } + if (PrintYTVersion_) { + PrintYTVersionAndExit(); + } + if (PrintBuild_) { + PrintBuildAndExit(); + } +} + +int TProgram::Run(int argc, const char** argv) +{ + ::srand(time(nullptr)); + + auto run = [&] { + Argv0_ = TString(argv[0]); + TOptsParseResult result(this, argc, argv); + + HandleVersionAndBuild(); + + DoRun(result); + }; + + if (!CrashOnError_) { + try { + run(); + Exit(EProgramExitCode::OK); + } catch (...) { + OnError(CurrentExceptionMessage()); + Exit(EProgramExitCode::ProgramError); + } + } else { + run(); + Exit(EProgramExitCode::OK); + } + + // Cannot reach this due to #Exit calls above. 
+ YT_ABORT(); +} + +void TProgram::Abort(EProgramExitCode code) noexcept +{ + Abort(static_cast<int>(code)); +} + +void TProgram::Abort(int code) noexcept +{ + NLogging::TLogManager::Get()->Shutdown(); + + ::_exit(code); +} + +void TProgram::Exit(EProgramExitCode code) noexcept +{ + Exit(static_cast<int>(code)); +} + +void TProgram::Exit(int code) noexcept +{ +#if defined(_linux_) && defined(CLANG_COVERAGE) + __llvm_profile_write_file(); +#endif + + // This explicit call may become obsolete some day; + // cf. the comment section for NYT::Shutdown. + Shutdown({ + .AbortOnHang = ShouldAbortOnHungShutdown(), + .HungExitCode = code + }); + + ::exit(code); +} + +bool TProgram::ShouldAbortOnHungShutdown() noexcept +{ + return true; +} + +void TProgram::OnError(const TString& message) noexcept +{ + try { + Cerr << message << Endl; + } catch (...) { + // Just ignore it; STDERR might be closed already, + // and write() would result in EPIPE. + } +} + +void TProgram::PrintYTVersionAndExit() +{ + if (UseYson_) { + THROW_ERROR_EXCEPTION("--yson is not supported when printing version"); + } + Cout << GetVersion() << Endl; + Exit(0); +} + +void TProgram::PrintBuildAndExit() +{ + if (UseYson_) { + TYsonWriter writer(&Cout, EYsonFormat::Pretty); + Serialize(BuildBuildAttributes(), &writer); + Cout << Endl; + } else { + Cout << "Build Time: " << GetBuildTime() << Endl; + Cout << "Build Host: " << GetBuildHost() << Endl; + } + Exit(0); +} + +void TProgram::PrintVersionAndExit() +{ + PrintYTVersionAndExit(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TProgramException::TProgramException(TString what) + : What_(std::move(what)) +{ } + +const char* TProgramException::what() const noexcept +{ + return What_.c_str(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TString CheckPathExistsArgMapper(const TString& arg) +{ + if (!NFS::Exists(arg)) { + throw TProgramException(Format("File %v does not 
exist", arg)); + } + return arg; +} + +TGuid CheckGuidArgMapper(const TString& arg) +{ + TGuid result; + if (!TGuid::FromString(arg, &result)) { + throw TProgramException(Format("Error parsing guid %Qv", arg)); + } + return result; +} + +NYson::TYsonString CheckYsonArgMapper(const TString& arg) +{ + ParseYsonStringBuffer(arg, EYsonType::Node, GetNullYsonConsumer()); + return NYson::TYsonString(arg); +} + +void ConfigureUids() +{ +#ifdef _unix_ + uid_t ruid, euid; +#ifdef _linux_ + uid_t suid; + YT_VERIFY(getresuid(&ruid, &euid, &suid) == 0); +#else + ruid = getuid(); + euid = geteuid(); +#endif + if (euid == 0) { + // if real uid is already root do not set root as supplementary ids. + if (ruid != 0) { + YT_VERIFY(setgroups(0, nullptr) == 0); + } + // if effective uid == 0 (e. g. set-uid-root), alter saved = effective, effective = real. +#ifdef _linux_ + YT_VERIFY(setresuid(ruid, ruid, euid) == 0); + // Make server suid_dumpable = 1. + YT_VERIFY(prctl(PR_SET_DUMPABLE, 1) == 0); +#else + YT_VERIFY(setuid(euid) == 0); + YT_VERIFY(seteuid(ruid) == 0); + YT_VERIFY(setruid(ruid) == 0); +#endif + } + umask(0000); +#endif +} + +void ConfigureCoverageOutput() +{ +#if defined(_linux_) && defined(CLANG_COVERAGE) + // YT tests use pid namespaces. We can't use process id as unique identifier for output file. 
+ if (auto profileFile = getenv("LLVM_PROFILE_FILE")) { + TString fixedProfile{profileFile}; + SubstGlobal(fixedProfile, "%e", "ytserver-all"); + SubstGlobal(fixedProfile, "%p", ToString(TInstant::Now().NanoSeconds())); + __llvm_profile_set_filename(fixedProfile.c_str()); + } +#endif +} + +void ConfigureIgnoreSigpipe() +{ +#ifdef _unix_ + signal(SIGPIPE, SIG_IGN); +#endif +} + +void ConfigureCrashHandler() +{ + TSignalRegistry::Get()->PushCallback(AllCrashSignals, CrashSignalHandler); + TSignalRegistry::Get()->PushDefaultSignalHandler(AllCrashSignals); +} + +namespace { + +void ExitZero(int /*unused*/) +{ +#if defined(_linux_) && defined(CLANG_COVERAGE) + __llvm_profile_write_file(); +#endif + // TODO(babenko): replace with pure "exit" some day. + // Currently this causes some RPC requests to master to be replied with "Promise abandoned" error, + // which is not retriable. + _exit(0); +} + +} // namespace + +void ConfigureExitZeroOnSigterm() +{ +#ifdef _unix_ + signal(SIGTERM, ExitZero); +#endif +} + +void ConfigureAllocator(const TAllocatorOptions& options) +{ + NYT::MlockFileMappings(); + +#ifdef _linux_ + NYTAlloc::EnableYTLogging(); + NYTAlloc::EnableYTProfiling(); + NYTAlloc::InitializeLibunwindInterop(); + NYTAlloc::SetEnableEagerMemoryRelease(options.YTAllocEagerMemoryRelease); + + if (tcmalloc::MallocExtension::NeedsProcessBackgroundActions()) { + std::thread backgroundThread([] { + TThread::SetCurrentThreadName("TCAllocBack"); + tcmalloc::MallocExtension::ProcessBackgroundActions(); + YT_ABORT(); + }); + backgroundThread.detach(); + } + + NProfiling::EnableTCMallocProfiler(); + NYTProf::EnableMemoryProfilingTags(); + absl::SetStackUnwinder(NYTProf::AbslStackUnwinder); + // TODO(prime@): tune parameters. 
+ tcmalloc::MallocExtension::SetProfileSamplingRate(2_MB); + if (options.TCMallocGuardedSamplingRate) { + tcmalloc::MallocExtension::SetGuardedSamplingRate(*options.TCMallocGuardedSamplingRate); + tcmalloc::MallocExtension::ActivateGuardedSampling(); + } + tcmalloc::MallocExtension::SetMaxPerCpuCacheSize(3_MB); + tcmalloc::MallocExtension::SetMaxTotalThreadCacheBytes(24_MB); + tcmalloc::MallocExtension::SetBackgroundReleaseRate(tcmalloc::MallocExtension::BytesPerSecond{32_MB}); + tcmalloc::MallocExtension::EnableForkSupport(); +#else + Y_UNUSED(options); +#endif +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program.h b/yt/yt/library/program/program.h new file mode 100644 index 0000000000..3f690f16c9 --- /dev/null +++ b/yt/yt/library/program/program.h @@ -0,0 +1,146 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +#include <library/cpp/yt/stockpile/stockpile.h> + +#include <library/cpp/getopt/last_getopt.h> + +#include <yt/yt/core/yson/string.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(EProgramExitCode, + ((OK)(0)) + ((OptionsError)(1)) + ((ProgramError)(2)) +); + +class TProgram +{ +public: + TProgram(); + ~TProgram(); + + TProgram(const TProgram&) = delete; + TProgram(TProgram&&) = delete; + + // This call actually never returns; + // |int| return type is just for the symmetry with |main|. + [[noreturn]] + int Run(int argc, const char** argv); + + //! Handles --version/--yt-version/--build [--yson] if they are present. + void HandleVersionAndBuild(); + + //! Nongracefully aborts the program. + /*! + * Tries to flush logging messages. + * Aborts via |_exit| call. 
+ */ + [[noreturn]] + static void Abort(EProgramExitCode code) noexcept; + [[noreturn]] + static void Abort(int code) noexcept; + +protected: + NLastGetopt::TOpts Opts_; + TString Argv0_; + bool PrintYTVersion_ = false; + bool PrintVersion_ = false; + bool PrintBuild_ = false; + bool UseYson_ = false; + + virtual void DoRun(const NLastGetopt::TOptsParseResult& parseResult) = 0; + + virtual void OnError(const TString& message) noexcept; + + virtual bool ShouldAbortOnHungShutdown() noexcept; + + void SetCrashOnError(); + + //! Handler for --yt-version command argument. + [[noreturn]] + void PrintYTVersionAndExit(); + + //! Handler for --build command argument. + [[noreturn]] + void PrintBuildAndExit(); + + //! Handler for --version command argument. + //! By default, --version and --yt-version work the same way, + //! but some YT components (e.g. CHYT) can override it to provide its own version. + [[noreturn]] + virtual void PrintVersionAndExit(); + + [[noreturn]] + void Exit(EProgramExitCode code) noexcept; + + [[noreturn]] + void Exit(int code) noexcept; + +private: + bool CrashOnError_ = false; + + // Custom handler for option parsing errors. + class TOptsParseResult; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//! The simplest exception possible. +//! Here we refrain from using TErrorException, as it relies on proper configuration of singleton subsystems, +//! which might not be the case during startup. +class TProgramException + : public std::exception +{ +public: + explicit TProgramException(TString what); + + const char* what() const noexcept override; + +private: + const TString What_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//! Helper for TOpt::StoreMappedResult to validate file paths for existance. +TString CheckPathExistsArgMapper(const TString& arg); + +//! Helper for TOpt::StoreMappedResult to parse GUIDs. 
+TGuid CheckGuidArgMapper(const TString& arg); + +//! Helper for TOpt::StoreMappedResult to parse YSON strings. +NYson::TYsonString CheckYsonArgMapper(const TString& arg); + +//! Drop privileges and save them if running with suid-bit. +void ConfigureUids(); + +void ConfigureCoverageOutput(); + +void ConfigureIgnoreSigpipe(); + +//! Intercepts standard crash signals (see signal_registry.h for full list) with a nice handler. +void ConfigureCrashHandler(); + +//! Intercepts SIGTERM and terminates the process immediately with zero exit code. +void ConfigureExitZeroOnSigterm(); + +//////////////////////////////////////////////////////////////////////////////// + +struct TAllocatorOptions +{ + bool YTAllocEagerMemoryRelease = false; + + bool TCMallocOptimizeSize = false; + std::optional<i64> TCMallocGuardedSamplingRate = 128_MB; +}; + +void ConfigureAllocator(const TAllocatorOptions& options = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_config_mixin.cpp b/yt/yt/library/program/program_config_mixin.cpp new file mode 100644 index 0000000000..9ced4de64f --- /dev/null +++ b/yt/yt/library/program/program_config_mixin.cpp @@ -0,0 +1 @@ +#include "program_config_mixin.h" diff --git a/yt/yt/library/program/program_config_mixin.h b/yt/yt/library/program/program_config_mixin.h new file mode 100644 index 0000000000..80f681d06e --- /dev/null +++ b/yt/yt/library/program/program_config_mixin.h @@ -0,0 +1,166 @@ +#pragma once + +#include "program.h" + +#include <library/cpp/yt/string/enum.h> + +#include <yt/yt/core/ytree/convert.h> +#include <yt/yt/core/ytree/yson_serializable.h> + +#include <util/stream/file.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <class TConfig, class TDynamicConfig = void> +class TProgramConfigMixin +{ +protected: + explicit TProgramConfigMixin( + NLastGetopt::TOpts& opts, + bool 
required = true, + const TString& argumentName = "config") + : ArgumentName_(argumentName) + { + auto opt = opts + .AddLongOption(TString(argumentName), Format("path to %v file", argumentName)) + .StoreMappedResult(&ConfigPath_, &CheckPathExistsArgMapper) + .RequiredArgument("FILE"); + if (required) { + opt.Required(); + } else { + opt.Optional(); + } + opts + .AddLongOption( + Format("%v-template", argumentName), + Format("print %v template and exit", argumentName)) + .SetFlag(&ConfigTemplate_); + opts + .AddLongOption( + Format("%v-actual", argumentName), + Format("print actual %v and exit", argumentName)) + .SetFlag(&ConfigActual_); + opts + .AddLongOption( + Format("%v-unrecognized-strategy", argumentName), + Format("configure strategy for unrecognized attributes in %v", argumentName)) + .Handler1T<TStringBuf>([this](TStringBuf value) { + UnrecognizedStrategy_ = ParseEnum<NYTree::EUnrecognizedStrategy>(value); + }); + + if constexpr (std::is_same_v<TDynamicConfig, void>) { + return; + } + + opts + .AddLongOption( + Format("dynamic-%v-template", argumentName), + Format("print dynamic %v template and exit", argumentName)) + .SetFlag(&DynamicConfigTemplate_); + } + + TIntrusivePtr<TConfig> GetConfig(bool returnNullIfNotSupplied = false) + { + if (returnNullIfNotSupplied && !ConfigPath_) { + return nullptr; + } + + if (!Config_) { + LoadConfig(); + } + return Config_; + } + + NYTree::INodePtr GetConfigNode(bool returnNullIfNotSupplied = false) + { + if (returnNullIfNotSupplied && !ConfigPath_) { + return nullptr; + } + + if (!ConfigNode_) { + LoadConfigNode(); + } + return ConfigNode_; + } + + bool HandleConfigOptions() + { + auto print = [] (const auto& config) { + using namespace NYson; + TYsonWriter writer(&Cout, EYsonFormat::Pretty); + config->Save(&writer); + Cout << Flush; + }; + if (ConfigTemplate_) { + print(New<TConfig>()); + return true; + } + if (ConfigActual_) { + print(GetConfig()); + return true; + } + + if constexpr (!std::is_same_v<TDynamicConfig, 
void>) { + if (DynamicConfigTemplate_) { + print(New<TDynamicConfig>()); + return true; + } + } + return false; + } + +private: + void LoadConfigNode() + { + using namespace NYTree; + + if (!ConfigPath_){ + THROW_ERROR_EXCEPTION("Missing --%v option", ArgumentName_); + } + + try { + TIFStream stream(ConfigPath_); + ConfigNode_ = ConvertToNode(&stream); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Error parsing %v file %v", + ArgumentName_, + ConfigPath_) + << ex; + } + } + + void LoadConfig() + { + if (!ConfigNode_) { + LoadConfigNode(); + } + + try { + Config_ = New<TConfig>(); + Config_->SetUnrecognizedStrategy(UnrecognizedStrategy_); + Config_->Load(ConfigNode_); + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Error loading %v file %v", + ArgumentName_, + ConfigPath_) + << ex; + } + } + + const TString ArgumentName_; + + TString ConfigPath_; + bool ConfigTemplate_; + bool ConfigActual_; + bool DynamicConfigTemplate_ = false; + NYTree::EUnrecognizedStrategy UnrecognizedStrategy_ = NYTree::EUnrecognizedStrategy::KeepRecursive; + + TIntrusivePtr<TConfig> Config_; + NYTree::INodePtr ConfigNode_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_pdeathsig_mixin.cpp b/yt/yt/library/program/program_pdeathsig_mixin.cpp new file mode 100644 index 0000000000..34f1f3b9a8 --- /dev/null +++ b/yt/yt/library/program/program_pdeathsig_mixin.cpp @@ -0,0 +1,36 @@ +#include "program_pdeathsig_mixin.h" + +#ifdef _linux_ +#include <sys/prctl.h> +#endif + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +TProgramPdeathsigMixin::TProgramPdeathsigMixin(NLastGetopt::TOpts& opts) +{ + opts.AddLongOption("pdeathsig", "parent death signal") + .StoreResult(&ParentDeathSignal_) + .RequiredArgument("PDEATHSIG"); +} + +bool TProgramPdeathsigMixin::HandlePdeathsigOptions() +{ + if (ParentDeathSignal_ > 0) 
{ +#ifdef _linux_ + // Parent death signal is set by testing framework to avoid dangling processes when test runner crashes. + // Unfortunately, setting pdeathsig in preexec_fn in subprocess call in test runner is not working + // when the program has suid bit (pdeath_sig is reset after exec call in this case) + // More details can be found in + // http://linux.die.net/man/2/prctl + // http://www.isec.pl/vulnerabilities/isec-0024-death-signal.txt + YT_VERIFY(prctl(PR_SET_PDEATHSIG, ParentDeathSignal_) == 0); +#endif + } + return false; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_pdeathsig_mixin.h b/yt/yt/library/program/program_pdeathsig_mixin.h new file mode 100644 index 0000000000..3e4bcfd4a6 --- /dev/null +++ b/yt/yt/library/program/program_pdeathsig_mixin.h @@ -0,0 +1,22 @@ +#pragma once + +#include "program.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TProgramPdeathsigMixin +{ +protected: + explicit TProgramPdeathsigMixin(NLastGetopt::TOpts& opts); + + bool HandlePdeathsigOptions(); + +private: + int ParentDeathSignal_ = -1; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_setsid_mixin.cpp b/yt/yt/library/program/program_setsid_mixin.cpp new file mode 100644 index 0000000000..a745fcd3a2 --- /dev/null +++ b/yt/yt/library/program/program_setsid_mixin.cpp @@ -0,0 +1,30 @@ +#include "program_setsid_mixin.h" + +#ifdef _linux_ +#include <unistd.h> +#endif + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +TProgramSetsidMixin::TProgramSetsidMixin(NLastGetopt::TOpts& opts) +{ + opts.AddLongOption("setsid", "create a new session") + .StoreTrue(&Setsid_) + .Optional(); +} + +bool TProgramSetsidMixin::HandleSetsidOptions() +{ + if (Setsid_) 
{ +#ifdef _linux_ + setsid(); +#endif + } + return false; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/program_setsid_mixin.h b/yt/yt/library/program/program_setsid_mixin.h new file mode 100644 index 0000000000..00b3dff50e --- /dev/null +++ b/yt/yt/library/program/program_setsid_mixin.h @@ -0,0 +1,22 @@ +#pragma once + +#include "program.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TProgramSetsidMixin +{ +protected: + explicit TProgramSetsidMixin(NLastGetopt::TOpts& opts); + + bool HandleSetsidOptions(); + +private: + bool Setsid_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/public.h b/yt/yt/library/program/public.h new file mode 100644 index 0000000000..9f8ad8dbf2 --- /dev/null +++ b/yt/yt/library/program/public.h @@ -0,0 +1,20 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_CLASS(TBuildInfo) +DECLARE_REFCOUNTED_CLASS(TRpcConfig) +DECLARE_REFCOUNTED_CLASS(TTCMallocConfig) +DECLARE_REFCOUNTED_CLASS(TStockpileConfig) +DECLARE_REFCOUNTED_CLASS(TSingletonsConfig) +DECLARE_REFCOUNTED_CLASS(TSingletonsDynamicConfig) +DECLARE_REFCOUNTED_CLASS(TDiagnosticDumpConfig) +DECLARE_REFCOUNTED_CLASS(THeapSizeLimit) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/program/ya.make b/yt/yt/library/program/ya.make new file mode 100644 index 0000000000..5742ce9287 --- /dev/null +++ b/yt/yt/library/program/ya.make @@ -0,0 +1,30 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + build_attributes.cpp + config.cpp + helpers.cpp + program.cpp + program_config_mixin.cpp + 
program_pdeathsig_mixin.cpp + program_setsid_mixin.cpp +) + +PEERDIR( + yt/yt/core + yt/yt/core/service_discovery/yp + yt/yt/library/monitoring + yt/yt/library/containers + yt/yt/library/profiling/solomon + yt/yt/library/profiling/tcmalloc + yt/yt/library/profiling/perf + yt/yt/library/ytprof + yt/yt/library/tracing/jaeger + library/cpp/yt/mlock + library/cpp/yt/stockpile + library/cpp/yt/string +) + +END() diff --git a/yt/yt/library/query/base/ast.cpp b/yt/yt/library/query/base/ast.cpp new file mode 100644 index 0000000000..35ffd6b332 --- /dev/null +++ b/yt/yt/library/query/base/ast.cpp @@ -0,0 +1,687 @@ +#include "ast.h" + +#include <library/cpp/yt/misc/variant.h> + +#include <util/string/escape.h> + +namespace NYT::NQueryClient::NAst { + +//////////////////////////////////////////////////////////////////////////////// + +bool operator == (TNullLiteralValue, TNullLiteralValue) +{ + return true; +} + +bool operator != (TNullLiteralValue, TNullLiteralValue) +{ + return false; +} + +//////////////////////////////////////////////////////////////////////////////// + +TReference::operator size_t() const +{ + size_t result = 0; + HashCombine(result, ColumnName); + HashCombine(result, TableName); + return result; +} + +bool operator == (const TReference& lhs, const TReference& rhs) +{ + return + std::tie(lhs.ColumnName, lhs.TableName) == + std::tie(rhs.ColumnName, rhs.TableName); +} + +bool operator != (const TReference& lhs, const TReference& rhs) +{ + return !(lhs == rhs); +} + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> +bool ExpressionListEqual(const T& lhs, const T& rhs) +{ + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t index = 0; index < lhs.size(); ++index) { + if (*lhs[index] != *rhs[index]) { + return false; + } + } + return true; +} + +bool operator == (const TExpressionList& lhs, const TExpressionList& rhs) +{ + return ExpressionListEqual(lhs, rhs); +} + +bool operator != (const 
TExpressionList& lhs, const TExpressionList& rhs) +{ + return !(lhs == rhs); +} + +bool operator == (const TIdentifierList& lhs, const TIdentifierList& rhs) +{ + return ExpressionListEqual(lhs, rhs); +} + +bool operator != (const TIdentifierList& lhs, const TIdentifierList& rhs) +{ + return !(lhs == rhs); +} + +bool operator == (const TExpression& lhs, const TExpression& rhs) +{ + if (const auto* typedLhs = lhs.As<TLiteralExpression>()) { + const auto* typedRhs = rhs.As<TLiteralExpression>(); + if (!typedRhs) { + return false; + } + return typedLhs->Value == typedRhs->Value; + } else if (const auto* typedLhs = lhs.As<TReferenceExpression>()) { + const auto* typedRhs = rhs.As<TReferenceExpression>(); + if (!typedRhs) { + return false; + } + return typedLhs->Reference == typedRhs->Reference; + } else if (const auto* typedLhs = lhs.As<TAliasExpression>()) { + const auto* typedRhs = rhs.As<TAliasExpression>(); + if (!typedRhs) { + return false; + } + return + typedLhs->Name == typedRhs->Name && + *typedLhs->Expression == *typedRhs->Expression; + } else if (const auto* typedLhs = lhs.As<TFunctionExpression>()) { + const auto* typedRhs = rhs.As<TFunctionExpression>(); + if (!typedRhs) { + return false; + } + return + typedLhs->FunctionName == typedRhs->FunctionName && + typedLhs->Arguments == typedRhs->Arguments; + } else if (const auto* typedLhs = lhs.As<TUnaryOpExpression>()) { + const auto* typedRhs = rhs.As<TUnaryOpExpression>(); + if (!typedRhs) { + return false; + } + return + typedLhs->Opcode == typedRhs->Opcode && + typedLhs->Operand == typedRhs->Operand; + } else if (const auto* typedLhs = lhs.As<TBinaryOpExpression>()) { + const auto* typedRhs = rhs.As<TBinaryOpExpression>(); + if (!typedRhs) { + return false; + } + return + typedLhs->Opcode == typedRhs->Opcode && + typedLhs->Lhs == typedRhs->Lhs && + typedLhs->Rhs == typedRhs->Rhs; + } else if (const auto* typedLhs = lhs.As<TInExpression>()) { + const auto* typedRhs = rhs.As<TInExpression>(); + if (!typedRhs) 
{ + return false; + } + return + typedLhs->Expr == typedRhs->Expr && + typedLhs->Values == typedRhs->Values; + } else if (const auto* typedLhs = lhs.As<TBetweenExpression>()) { + const auto* typedRhs = rhs.As<TBetweenExpression>(); + if (!typedRhs) { + return false; + } + return + typedLhs->Expr == typedRhs->Expr && + typedLhs->Values == typedRhs->Values; + } else if (const auto* typedLhs = lhs.As<TTransformExpression>()) { + const auto* typedRhs = rhs.As<TTransformExpression>(); + if (!typedRhs) { + return false; + } + return + typedLhs->Expr == typedRhs->Expr && + typedLhs->From == typedRhs->From && + typedLhs->To == typedRhs->To && + typedLhs->DefaultExpr == typedRhs->DefaultExpr; + } else { + YT_ABORT(); + } +} + +bool operator != (const TExpression& lhs, const TExpression& rhs) +{ + return !(lhs == rhs); +} + +TStringBuf TExpression::GetSource(TStringBuf source) const +{ + auto begin = SourceLocation.first; + auto end = SourceLocation.second; + + return source.substr(begin, end - begin); +} + +TStringBuf GetSource(TSourceLocation sourceLocation, TStringBuf source) +{ + auto begin = sourceLocation.first; + auto end = sourceLocation.second; + + return source.substr(begin, end - begin); +} + +//! Two table descriptors are equal iff both their paths and their aliases match. +bool operator == (const TTableDescriptor& lhs, const TTableDescriptor& rhs) +{ + // Each side must tie its own fields; tying rhs.Alias on both sides would make the alias comparison vacuous. + return + std::tie(lhs.Path, lhs.Alias) == + std::tie(rhs.Path, rhs.Alias); +} + +bool operator != (const TTableDescriptor& lhs, const TTableDescriptor& rhs) +{ + return !(lhs == rhs); +} + +bool operator == (const TJoin& lhs, const TJoin& rhs) +{ + return + std::tie(lhs.IsLeft, lhs.Table, lhs.Fields, lhs.Lhs, lhs.Rhs, lhs.Predicate) == + std::tie(rhs.IsLeft, rhs.Table, rhs.Fields, rhs.Lhs, rhs.Rhs, rhs.Predicate); +} + +bool operator != (const TJoin& lhs, const TJoin& rhs) +{ + return !(lhs == rhs); +} + +bool operator == (const TQuery& lhs, const TQuery& rhs) +{ + return + std::tie( + lhs.Table, + lhs.Joins, + lhs.SelectExprs, + lhs.WherePredicate, + lhs.GroupExprs, + lhs.HavingPredicate, + 
lhs.OrderExpressions, + lhs.Offset, + lhs.Limit) == + std::tie( + rhs.Table, + rhs.Joins, + rhs.SelectExprs, + rhs.WherePredicate, + rhs.GroupExprs, + rhs.HavingPredicate, + rhs.OrderExpressions, + rhs.Offset, + rhs.Limit); +} + +bool operator != (const TQuery& lhs, const TQuery& rhs) +{ + return !(lhs == rhs); +} + +void FormatLiteralValue(TStringBuilderBase* builder, const TLiteralValue& value) +{ + Visit(value, + [&] (TNullLiteralValue) { + builder->AppendString("null"); + }, + [&] (i64 value) { + builder->AppendFormat("%v", value); + }, + [&] (ui64 value) { + builder->AppendFormat("%vu", value); + }, + [&] (double value) { + builder->AppendFormat("%v", value); + }, + [&] (bool value) { + builder->AppendFormat("%v", value ? "true" : "false"); + }, + [&] (const TString& value) { + builder->AppendChar('"'); + builder->AppendString(EscapeC(value)); + builder->AppendChar('"'); + }); +} + +std::vector<TStringBuf> GetKeywords() +{ + std::vector<TStringBuf> result; + +#define XX(keyword) result.push_back(#keyword); + + XX(from) + XX(where) + XX(having) + XX(offset) + XX(limit) + XX(join) + XX(using) + XX(group) + XX(by) + XX(with) + XX(totals) + XX(order) + XX(by) + XX(asc) + XX(desc) + XX(left) + XX(as) + XX(on) + XX(and) + XX(or) + XX(not) + XX(null) + XX(between) + XX(in) + XX(transform) + XX(false) + XX(true) + +#undef XX + + std::sort(result.begin(), result.end()); + + return result; +} + +bool IsKeyword(TStringBuf str) +{ + static auto keywords = GetKeywords(); + + return std::binary_search(keywords.begin(), keywords.end(), str, [] (TStringBuf str, TStringBuf keyword) { + return std::lexicographical_compare( + str.begin(), + str.end(), + keyword.begin(), + keyword.end(), [] (char a, char b) { + return tolower(a) < tolower(b); + }); + }); +} + +bool IsValidId(TStringBuf str) +{ + if (str.empty()) { + return false; + } + + auto isNum = [] (char ch) { + return + ch >= '0' && ch <= '9'; + }; + + auto isAlpha = [] (char ch) { + return + ch >= 'a' && ch <= 'z' || + ch 
>= 'A' && ch <= 'Z' || + ch == '_'; + }; + + if (!isAlpha(str[0])) { + return false; + } + + for (size_t index = 1; index < str.length(); ++index) { + char ch = str[index]; + if (!isAlpha(ch) && !isNum(ch)) { + return false; + } + } + + if (IsKeyword(str)) { + return false; + } + + return true; +} + +bool AreBackticksNeeded(TStringBuf id) +{ + return id.Contains('[') || id.Contains(']'); +} + +void FormatId(TStringBuilderBase* builder, TStringBuf id, bool isFinal = false) +{ + if (isFinal || IsValidId(id)) { + builder->AppendString(id); + } else { + if (AreBackticksNeeded(id)) { + builder->AppendChar('`'); + builder->AppendString(EscapeC(id)); + builder->AppendChar('`'); + } else { + builder->AppendChar('['); + builder->AppendString(id); + builder->AppendChar(']'); + } + } +} + +void FormatReference(TStringBuilderBase* builder, const TReference& ref, bool isFinal = false) +{ + if (ref.TableName) { + builder->AppendString(*ref.TableName); + builder->AppendChar('.'); + } + FormatId(builder, ref.ColumnName, isFinal); +} + +void FormatTableDescriptor(TStringBuilderBase* builder, const TTableDescriptor& descriptor) +{ + FormatId(builder, descriptor.Path); + if (descriptor.Alias) { + builder->AppendString(" AS "); + FormatId(builder, *descriptor.Alias); + } +} + +void FormatExpressions(TStringBuilderBase* builder, const TExpressionList& exprs, bool expandAliases); +void FormatExpression(TStringBuilderBase* builder, const TExpression& expr, bool expandAliases, bool isFinal = false); +void FormatExpression(TStringBuilderBase* builder, const TExpressionList& expr, bool expandAliases); + +void FormatExpression(TStringBuilderBase* builder, const TExpression& expr, bool expandAliases, bool isFinal) +{ + auto printTuple = [] (TStringBuilderBase* builder, const TLiteralValueTuple& tuple) { + bool needParens = tuple.size() > 1; + if (needParens) { + builder->AppendChar('('); + } + JoinToString( + builder, + tuple.begin(), + tuple.end(), + [] (TStringBuilderBase* builder, const 
TLiteralValue& value) { + builder->AppendString(FormatLiteralValue(value)); + }); + if (needParens) { + builder->AppendChar(')'); + } + }; + + auto printTuples = [&] (TStringBuilderBase* builder, const TLiteralValueTupleList& list) { + JoinToString( + builder, + list.begin(), + list.end(), + printTuple); + }; + + auto printRanges = [&] (TStringBuilderBase* builder, const TLiteralValueRangeList& list) { + JoinToString( + builder, + list.begin(), + list.end(), + [&] (TStringBuilderBase* builder, const std::pair<TLiteralValueTuple, TLiteralValueTuple>& range) { + printTuple(builder, range.first); + builder->AppendString(" AND "); + printTuple(builder, range.second); + }); + }; + + if (auto* typedExpr = expr.As<TLiteralExpression>()) { + builder->AppendString(FormatLiteralValue(typedExpr->Value)); + } else if (auto* typedExpr = expr.As<TReferenceExpression>()) { + FormatReference(builder, typedExpr->Reference, isFinal); + } else if (auto* typedExpr = expr.As<TAliasExpression>()) { + if (expandAliases) { + builder->AppendChar('('); + FormatExpression(builder, *typedExpr->Expression, expandAliases); + builder->AppendString(" as "); + FormatId(builder, typedExpr->Name, isFinal); + builder->AppendChar(')'); + } else { + FormatId(builder, typedExpr->Name, isFinal); + } + } else if (auto* typedExpr = expr.As<TFunctionExpression>()) { + builder->AppendString(typedExpr->FunctionName); + builder->AppendChar('('); + FormatExpressions(builder, typedExpr->Arguments, expandAliases); + builder->AppendChar(')'); + } else if (auto* typedExpr = expr.As<TUnaryOpExpression>()) { + builder->AppendString(GetUnaryOpcodeLexeme(typedExpr->Opcode)); + builder->AppendChar('('); + FormatExpressions(builder, typedExpr->Operand, expandAliases); + builder->AppendChar(')'); + } else if (auto* typedExpr = expr.As<TBinaryOpExpression>()) { + builder->AppendChar('('); + FormatExpressions(builder, typedExpr->Lhs, expandAliases); + builder->AppendChar(')'); + 
builder->AppendString(GetBinaryOpcodeLexeme(typedExpr->Opcode)); + builder->AppendChar('('); + FormatExpressions(builder, typedExpr->Rhs, expandAliases); + builder->AppendChar(')'); + } else if (auto* typedExpr = expr.As<TInExpression>()) { + builder->AppendChar('('); + FormatExpressions(builder, typedExpr->Expr, expandAliases); + builder->AppendString(") IN ("); + printTuples(builder, typedExpr->Values); + builder->AppendChar(')'); + } else if (auto* typedExpr = expr.As<TBetweenExpression>()) { + builder->AppendChar('('); + FormatExpressions(builder, typedExpr->Expr, expandAliases); + builder->AppendString(") BETWEEN ("); + printRanges(builder, typedExpr->Values); + builder->AppendChar(')'); + } else if (auto* typedExpr = expr.As<TTransformExpression>()) { + builder->AppendString("TRANSFORM("); + size_t argumentCount = typedExpr->Expr.size(); + auto needParenthesis = argumentCount > 1; + if (needParenthesis) { + builder->AppendChar('('); + } + FormatExpressions(builder, typedExpr->Expr, expandAliases); + if (needParenthesis) { + builder->AppendChar(')'); + } + builder->AppendString(", ("); + printTuples(builder, typedExpr->From); + builder->AppendString("), ("); + printTuples(builder, typedExpr->To); + builder->AppendChar(')'); + + if (typedExpr->DefaultExpr) { + builder->AppendString(", "); + FormatExpression(builder, *typedExpr->DefaultExpr, expandAliases); + } + + builder->AppendChar(')'); + } else { + YT_ABORT(); + } +} + +void FormatExpression(TStringBuilderBase* builder, const TExpressionList& exprs, bool expandAliases) +{ + YT_VERIFY(exprs.size() > 0); + if (exprs.size() > 1) { + builder->AppendChar('('); + } + FormatExpressions(builder, exprs, expandAliases); + if (exprs.size() > 1) { + builder->AppendChar(')'); + } +} + +void FormatExpressions(TStringBuilderBase* builder, const TExpressionList& exprs, bool expandAliases) +{ + JoinToString( + builder, + exprs.begin(), + exprs.end(), + [&] (TStringBuilderBase* builder, const TExpressionPtr& expr) { + 
FormatExpression(builder, *expr, expandAliases); + }); +} + +void FormatJoin(TStringBuilderBase* builder, const TJoin& join) +{ + if (join.IsLeft) { + builder->AppendString(" LEFT"); + } + builder->AppendString(" JOIN "); + FormatTableDescriptor(builder, join.Table); + if (join.Fields.empty()) { + builder->AppendString(" ON ("); + FormatExpressions(builder, join.Lhs, true); + builder->AppendString(") = ("); + FormatExpressions(builder, join.Rhs, true); + builder->AppendChar(')'); + } else { + builder->AppendString(" USING "); + JoinToString( + builder, + join.Fields.begin(), + join.Fields.end(), + [] (TStringBuilderBase* builder, const TReferenceExpressionPtr& referenceExpr) { + FormatReference(builder, referenceExpr->Reference); + }); + } + if (join.Predicate) { + builder->AppendString(" AND "); + FormatExpression(builder, *join.Predicate, true); + } +} + +void FormatQuery(TStringBuilderBase* builder, const TQuery& query) +{ + if (query.SelectExprs) { + JoinToString( + builder, + query.SelectExprs->begin(), + query.SelectExprs->end(), + [] (TStringBuilderBase* builder, const TExpressionPtr& expr) { + FormatExpression(builder, *expr, true); + }); + } else { + builder->AppendString("*"); + } + + builder->AppendString(" FROM "); + FormatTableDescriptor(builder, query.Table); + + for (const auto& join : query.Joins) { + FormatJoin(builder, join); + } + + if (query.WherePredicate) { + builder->AppendString(" WHERE "); + FormatExpression(builder, *query.WherePredicate, true); + } + + if (query.GroupExprs) { + builder->AppendString(" GROUP BY "); + FormatExpressions(builder, query.GroupExprs->first, true); + if (query.GroupExprs->second == ETotalsMode::BeforeHaving) { + builder->AppendString(" WITH TOTALS"); + } + } + + if (query.HavingPredicate) { + builder->AppendString(" HAVING "); + FormatExpression(builder, *query.HavingPredicate, true); + } + + if (query.GroupExprs && query.GroupExprs->second == ETotalsMode::AfterHaving) { + builder->AppendString(" WITH TOTALS"); + 
} + + if (!query.OrderExpressions.empty()) { + builder->AppendString(" ORDER BY "); + JoinToString( + builder, + query.OrderExpressions.begin(), + query.OrderExpressions.end(), + [] (TStringBuilderBase* builder, const std::pair<TExpressionList, bool>& pair) { + FormatExpression(builder, pair.first, true); + if (pair.second) { + builder->AppendString(" DESC"); + } + }); + } + + if (query.Offset) { + builder->AppendFormat(" OFFSET %v", *query.Offset); + } + + if (query.Limit) { + builder->AppendFormat(" LIMIT %v", *query.Limit); + } +} + +TString FormatLiteralValue(const TLiteralValue& value) +{ + TStringBuilder builder; + FormatLiteralValue(&builder, value); + return builder.Flush(); +} + +TString FormatId(TStringBuf id) +{ + TStringBuilder builder; + FormatId(&builder, id); + return builder.Flush(); +} + +TString FormatReference(const TReference& ref) +{ + TStringBuilder builder; + FormatReference(&builder, ref); + return builder.Flush(); +} + +TString FormatExpression(const TExpression& expr) +{ + TStringBuilder builder; + FormatExpression(&builder, expr, true); + return builder.Flush(); +} + +TString FormatExpression(const TExpressionList& exprs) +{ + TStringBuilder builder; + FormatExpression(&builder, exprs, true); + return builder.Flush(); +} + +TString FormatJoin(const TJoin& join) +{ + TStringBuilder builder; + FormatJoin(&builder, join); + return builder.Flush(); +} + +TString FormatQuery(const TQuery& query) +{ + TStringBuilder builder; + FormatQuery(&builder, query); + return builder.Flush(); +} + +TString InferColumnName(const TExpression& expr) +{ + TStringBuilder builder; + FormatExpression(&builder, expr, false, true); + return builder.Flush(); +} + +TString InferColumnName(const TReference& ref) +{ + TStringBuilder builder; + FormatReference(&builder, ref, true); + return builder.Flush(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient::NAst diff --git 
a/yt/yt/library/query/base/ast.h b/yt/yt/library/query/base/ast.h new file mode 100644 index 0000000000..455d753657 --- /dev/null +++ b/yt/yt/library/query/base/ast.h @@ -0,0 +1,420 @@ +#pragma once + +#include "public.h" +#include "query_common.h" + +#include <yt/yt/library/query/misc/objects_holder.h> +#include <library/cpp/yt/misc/hash.h> + +#include <variant> + +namespace NYT::NQueryClient::NAst { + +//////////////////////////////////////////////////////////////////////////////// + +#define XX(name) \ +struct name; \ +using name ## Ptr = name*; + +XX(TExpression) +XX(TReferenceExpression) +XX(TAliasExpression) +XX(TLiteralExpression) +XX(TFunctionExpression) +XX(TUnaryOpExpression) +XX(TBinaryOpExpression) +XX(TInExpression) +XX(TBetweenExpression) +XX(TTransformExpression) + +#undef XX + + +using TIdentifierList = std::vector<TReferenceExpressionPtr>; +using TExpressionList = std::vector<TExpressionPtr>; +using TNullableExpressionList = std::optional<TExpressionList>; +using TNullableIdentifierList = std::optional<TIdentifierList>; +using TOrderExpressionList = std::vector<std::pair<TExpressionList, bool>>; + +//////////////////////////////////////////////////////////////////////////////// + +struct TNullLiteralValue +{ }; + +bool operator == (TNullLiteralValue, TNullLiteralValue); +bool operator != (TNullLiteralValue, TNullLiteralValue); + +using TLiteralValue = std::variant< + TNullLiteralValue, + i64, + ui64, + double, + bool, + TString +>; + +using TLiteralValueList = std::vector<TLiteralValue>; +using TLiteralValueTuple = std::vector<TLiteralValue>; +using TLiteralValueTupleList = std::vector<TLiteralValueTuple>; +using TLiteralValueRangeList = std::vector<std::pair<TLiteralValueTuple, TLiteralValueTuple>>; + +//////////////////////////////////////////////////////////////////////////////// + +struct TReference +{ + TString ColumnName; + std::optional<TString> TableName; + + TReference() = default; + + TReference(const TString& columnName, const 
std::optional<TString>& tableName = std::nullopt) + : ColumnName(columnName) + , TableName(tableName) + { } + + operator size_t() const; +}; + +bool operator == (const TReference& lhs, const TReference& rhs); +bool operator != (const TReference& lhs, const TReference& rhs); + +//////////////////////////////////////////////////////////////////////////////// + +struct TExpression +{ + TSourceLocation SourceLocation; + + explicit TExpression(const TSourceLocation& sourceLocation) + : SourceLocation(sourceLocation) + { } + + template <class TDerived> + const TDerived* As() const + { + return dynamic_cast<const TDerived*>(this); + } + + template <class TDerived> + TDerived* As() + { + return dynamic_cast<TDerived*>(this); + } + + TStringBuf GetSource(TStringBuf source) const; + + virtual ~TExpression() = default; +}; + +template <class T, class... TArgs> +TExpressionList MakeExpression(TObjectsHolder* holder, TArgs&& ... args) +{ + return TExpressionList(1, holder->Register(new T(std::forward<TArgs>(args)...))); +} + +bool operator == (const TExpression& lhs, const TExpression& rhs); +bool operator != (const TExpression& lhs, const TExpression& rhs); + +//////////////////////////////////////////////////////////////////////////////// + +struct TLiteralExpression + : public TExpression +{ + TLiteralValue Value; + + TLiteralExpression( + const TSourceLocation& sourceLocation, + TLiteralValue value) + : TExpression(sourceLocation) + , Value(std::move(value)) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TReferenceExpression + : public TExpression +{ + TReference Reference; + + TReferenceExpression( + const TSourceLocation& sourceLocation, + const TString& columnName) + : TExpression(sourceLocation) + , Reference(columnName) + { } + + TReferenceExpression( + const TSourceLocation& sourceLocation, + const TString& columnName, + const TString& tableName) + : TExpression(sourceLocation) + , Reference(columnName, 
tableName) + { } + + TReferenceExpression( + const TSourceLocation& sourceLocation, + const TReference& reference) + : TExpression(sourceLocation) + , Reference(reference) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TAliasExpression + : public TExpression +{ + TExpressionPtr Expression; + TString Name; + + TAliasExpression( + const TSourceLocation& sourceLocation, + const TExpressionPtr& expression, + TStringBuf name) + : TExpression(sourceLocation) + , Expression(expression) + , Name(TString(name)) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TFunctionExpression + : public TExpression +{ + TString FunctionName; + TExpressionList Arguments; + + TFunctionExpression( + const TSourceLocation& sourceLocation, + TStringBuf functionName, + TExpressionList arguments) + : TExpression(sourceLocation) + , FunctionName(functionName) + , Arguments(std::move(arguments)) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TUnaryOpExpression + : public TExpression +{ + EUnaryOp Opcode; + TExpressionList Operand; + + TUnaryOpExpression( + const TSourceLocation& sourceLocation, + EUnaryOp opcode, + TExpressionList operand) + : TExpression(sourceLocation) + , Opcode(opcode) + , Operand(std::move(operand)) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TBinaryOpExpression + : public TExpression +{ + EBinaryOp Opcode; + TExpressionList Lhs; + TExpressionList Rhs; + + TBinaryOpExpression( + const TSourceLocation& sourceLocation, + EBinaryOp opcode, + TExpressionList lhs, + TExpressionList rhs) + : TExpression(sourceLocation) + , Opcode(opcode) + , Lhs(std::move(lhs)) + , Rhs(std::move(rhs)) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TInExpression + : public TExpression +{ + TExpressionList Expr; + 
TLiteralValueTupleList Values; + + TInExpression( + const TSourceLocation& sourceLocation, + TExpressionList expression, + TLiteralValueTupleList values) + : TExpression(sourceLocation) + , Expr(std::move(expression)) + , Values(std::move(values)) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TBetweenExpression + : public TExpression +{ + TExpressionList Expr; + TLiteralValueRangeList Values; + + TBetweenExpression( + const TSourceLocation& sourceLocation, + TExpressionList expression, + const TLiteralValueRangeList& values) + : TExpression(sourceLocation) + , Expr(std::move(expression)) + , Values(values) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TTransformExpression + : public TExpression +{ + TExpressionList Expr; + TLiteralValueTupleList From; + TLiteralValueTupleList To; + TNullableExpressionList DefaultExpr; + + TTransformExpression( + const TSourceLocation& sourceLocation, + TExpressionList expression, + const TLiteralValueTupleList& from, + const TLiteralValueTupleList& to, + TNullableExpressionList defaultExpr) + : TExpression(sourceLocation) + , Expr(std::move(expression)) + , From(from) + , To(to) + , DefaultExpr(std::move(defaultExpr)) + { } +}; + +//////////////////////////////////////////////////////////////////////////////// +struct TTableDescriptor +{ + NYPath::TYPath Path; + std::optional<TString> Alias; + + TTableDescriptor() = default; + + explicit TTableDescriptor( + const NYPath::TYPath& path, + const std::optional<TString>& alias = std::nullopt) + : Path(path) + , Alias(alias) + { } +}; + +bool operator == (const TTableDescriptor& lhs, const TTableDescriptor& rhs); +bool operator != (const TTableDescriptor& lhs, const TTableDescriptor& rhs); + +//////////////////////////////////////////////////////////////////////////////// + +struct TJoin +{ + bool IsLeft; + TTableDescriptor Table; + TIdentifierList Fields; + + 
TExpressionList Lhs; + TExpressionList Rhs; + + TNullableExpressionList Predicate; + + TJoin( + bool isLeft, + const TTableDescriptor& table, + const TIdentifierList& fields, + const TNullableExpressionList& predicate) + : IsLeft(isLeft) + , Table(table) + , Fields(fields) + , Predicate(predicate) + { } + + TJoin( + bool isLeft, + const TTableDescriptor& table, + const TExpressionList& lhs, + const TExpressionList& rhs, + const TNullableExpressionList& predicate) + : IsLeft(isLeft) + , Table(table) + , Lhs(lhs) + , Rhs(rhs) + , Predicate(predicate) + { } +}; + +bool operator == (const TJoin& lhs, const TJoin& rhs); +bool operator != (const TJoin& lhs, const TJoin& rhs); + +//////////////////////////////////////////////////////////////////////////////// + +struct TQuery +{ + TTableDescriptor Table; + std::vector<TJoin> Joins; + + TNullableExpressionList SelectExprs; + TNullableExpressionList WherePredicate; + + std::optional<std::pair<TExpressionList, ETotalsMode>> GroupExprs; + TNullableExpressionList HavingPredicate; + + TOrderExpressionList OrderExpressions; + + std::optional<i64> Offset; + std::optional<i64> Limit; +}; + +bool operator == (const TQuery& lhs, const TQuery& rhs); +bool operator != (const TQuery& lhs, const TQuery& rhs); + +//////////////////////////////////////////////////////////////////////////////// + +using TAliasMap = THashMap<TString, TExpressionPtr>; + +struct TAstHead + : public TObjectsHolder +{ + std::variant<TQuery, TExpressionPtr> Ast; + TAliasMap AliasMap; + + static TAstHead MakeQuery() + { + TAstHead result; + result.Ast.emplace<TQuery>(); + return result; + } + + static TAstHead MakeExpression() + { + TAstHead result; + result.Ast.emplace<TExpressionPtr>(); + return result; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +TStringBuf GetSource(TSourceLocation sourceLocation, TStringBuf source); + +TString FormatId(TStringBuf id); +TString FormatLiteralValue(const TLiteralValue& value); 
+TString FormatReference(const TReference& ref); +TString FormatExpression(const TExpression& expr); +TString FormatExpression(const TExpressionList& exprs); +TString FormatJoin(const TJoin& join); +TString FormatQuery(const TQuery& query); +TString InferColumnName(const TExpression& expr); +TString InferColumnName(const TReference& ref); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient::NAst diff --git a/yt/yt/library/query/base/builtin_function_registry.cpp b/yt/yt/library/query/base/builtin_function_registry.cpp new file mode 100644 index 0000000000..6b3952f02a --- /dev/null +++ b/yt/yt/library/query/base/builtin_function_registry.cpp @@ -0,0 +1,352 @@ +#include "builtin_function_registry.h" + +#include "functions.h" + +#include <library/cpp/resource/resource.h> + +namespace NYT::NQueryClient { + +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +void RegisterBuiltinFunctions(IFunctionRegistryBuilder* builder) +{ + builder->RegisterFunction( + "is_substr", + std::vector<TType>{EValueType::String, EValueType::String}, + EValueType::Boolean, + "is_substr", + ECallingConvention::Simple); + + builder->RegisterFunction( + "lower", + std::vector<TType>{EValueType::String}, + EValueType::String, + "lower", + ECallingConvention::Simple); + + builder->RegisterFunction( + "length", + std::vector<TType>{EValueType::String}, + EValueType::Int64, + "length", + ECallingConvention::Simple); + + builder->RegisterFunction( + "yson_length", + std::vector<TType>{EValueType::Any}, + EValueType::Int64, + "yson_length", + ECallingConvention::Simple); + + builder->RegisterFunction( + "concat", + std::vector<TType>{EValueType::String, EValueType::String}, + EValueType::String, + "concat", + ECallingConvention::Simple); + + builder->RegisterFunction( + "sleep", + std::vector<TType>{EValueType::Int64}, + EValueType::Int64, + "sleep", + 
ECallingConvention::Simple); + + builder->RegisterFunction( + "farm_hash", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{}, + TUnionType{ + EValueType::Int64, + EValueType::Uint64, + EValueType::Boolean, + EValueType::String + }, + EValueType::Uint64, + "farm_hash"); + + builder->RegisterFunction( + "bigb_hash", + std::vector<TType>{EValueType::String}, + EValueType::Uint64, + "bigb_hash", + ECallingConvention::Simple); + + builder->RegisterFunction( + "make_map", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{}, + TUnionType{ + EValueType::Int64, + EValueType::Uint64, + EValueType::Boolean, + EValueType::Double, + EValueType::String, + EValueType::Any + }, + EValueType::Any, + "make_map"); + + builder->RegisterFunction( + "numeric_to_string", + std::vector<TType>{ + TUnionType{ + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + }}, + EValueType::String, + "str_conv", + ECallingConvention::UnversionedValue); + + builder->RegisterFunction( + "parse_int64", + std::vector<TType>{EValueType::String}, + EValueType::Int64, + "str_conv", + ECallingConvention::UnversionedValue); + + builder->RegisterFunction( + "parse_uint64", + std::vector<TType>{EValueType::String}, + EValueType::Uint64, + "str_conv", + ECallingConvention::UnversionedValue); + + builder->RegisterFunction( + "parse_double", + std::vector<TType>{EValueType::String}, + EValueType::Double, + "str_conv", + ECallingConvention::UnversionedValue); + + builder->RegisterFunction( + "regex_full_match", + "regex_full_match", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::String, EValueType::String}, + EValueType::Null, + EValueType::Boolean, + "regex", + ECallingConvention::UnversionedValue, + true); + + builder->RegisterFunction( + "regex_partial_match", + "regex_partial_match", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::String, EValueType::String}, + EValueType::Null, 
+ EValueType::Boolean, + "regex", + ECallingConvention::UnversionedValue, + true); + + builder->RegisterFunction( + "regex_replace_first", + "regex_replace_first", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::String, EValueType::String, EValueType::String}, + EValueType::Null, + EValueType::String, + "regex", + ECallingConvention::UnversionedValue, + true); + + builder->RegisterFunction( + "regex_replace_all", + "regex_replace_all", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::String, EValueType::String, EValueType::String}, + EValueType::Null, + EValueType::String, + "regex", + ECallingConvention::UnversionedValue, + true); + + builder->RegisterFunction( + "regex_extract", + "regex_extract", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::String, EValueType::String, EValueType::String}, + EValueType::Null, + EValueType::String, + "regex", + ECallingConvention::UnversionedValue, + true); + + builder->RegisterFunction( + "regex_escape", + "regex_escape", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::String}, + EValueType::Null, + EValueType::String, + "regex", + ECallingConvention::UnversionedValue, + true); + + const TTypeParameter typeParameter = 0; + auto anyConstraints = std::unordered_map<TTypeParameter, TUnionType>(); + anyConstraints[typeParameter] = std::vector<EValueType>{ + EValueType::Int64, + EValueType::Uint64, + EValueType::Boolean, + EValueType::Double, + EValueType::String, + EValueType::Any}; + + builder->RegisterAggregate( + "first", + anyConstraints, + typeParameter, + typeParameter, + typeParameter, + "first", + ECallingConvention::UnversionedValue, + true); + + auto xdeltaConstraints = std::unordered_map<TTypeParameter, TUnionType>(); + xdeltaConstraints[typeParameter] = std::vector<EValueType>{ + EValueType::Null, + EValueType::String}; + builder->RegisterAggregate( + "xdelta", + 
xdeltaConstraints, + typeParameter, + typeParameter, + typeParameter, + "xdelta", + ECallingConvention::UnversionedValue); + + builder->RegisterAggregate( + "avg", + std::unordered_map<TTypeParameter, TUnionType>(), + EValueType::Int64, + EValueType::Double, + EValueType::String, + "avg", + ECallingConvention::UnversionedValue); + + builder->RegisterAggregate( + "cardinality", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<EValueType>{ + EValueType::String, + EValueType::Uint64, + EValueType::Int64, + EValueType::Double, + EValueType::Boolean}, + EValueType::Uint64, + EValueType::String, + "hyperloglog", + ECallingConvention::UnversionedValue); + + builder->RegisterFunction( + "format_timestamp", + std::vector<TType>{EValueType::Int64, EValueType::String}, + EValueType::String, + "dates", + ECallingConvention::Simple); + + std::vector<TString> timestampFloorFunctions = { + "timestamp_floor_hour", + "timestamp_floor_day", + "timestamp_floor_week", + "timestamp_floor_month", + "timestamp_floor_year"}; + + for (const auto& name : timestampFloorFunctions) { + builder->RegisterFunction( + name, + std::vector<TType>{EValueType::Int64}, + EValueType::Int64, + "dates", + ECallingConvention::Simple); + } + + builder->RegisterFunction( + "format_guid", + std::vector<TType>{EValueType::Uint64, EValueType::Uint64}, + EValueType::String, + "format_guid", + ECallingConvention::Simple); + + std::vector<std::pair<TString, EValueType>> ypathGetFunctions = { + {"try_get_int64", EValueType::Int64}, + {"get_int64", EValueType::Int64}, + {"try_get_uint64", EValueType::Uint64}, + {"get_uint64", EValueType::Uint64}, + {"try_get_double", EValueType::Double}, + {"get_double", EValueType::Double}, + {"try_get_boolean", EValueType::Boolean}, + {"get_boolean", EValueType::Boolean}, + {"try_get_string", EValueType::String}, + {"get_string", EValueType::String}, + {"try_get_any", EValueType::Any}, + {"get_any", EValueType::Any}}; + + for (const auto& fns : ypathGetFunctions) 
{ + auto&& name = fns.first; + auto&& type = fns.second; + builder->RegisterFunction( + name, + std::vector<TType>{EValueType::Any, EValueType::String}, + type, + "ypath_get", + ECallingConvention::UnversionedValue); + } + + builder->RegisterFunction( + "to_any", + std::vector<TType>{ + TUnionType{ + EValueType::String, + EValueType::Uint64, + EValueType::Int64, + EValueType::Double, + EValueType::Boolean, + EValueType::Any, + EValueType::Composite}}, + EValueType::Any, + "to_any", + ECallingConvention::UnversionedValue); + + builder->RegisterFunction( + "list_contains", + std::vector<TType>{ + EValueType::Any, + TUnionType{ + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Boolean, + EValueType::String, + }}, + EValueType::Boolean, + "list_contains", + ECallingConvention::UnversionedValue); + + builder->RegisterFunction( + "any_to_yson_string", + std::vector<TType>{EValueType::Any}, + EValueType::String, + "any_to_yson_string", + ECallingConvention::Simple); + + builder->RegisterFunction( + "_yt_has_permissions", + "has_permissions", + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::Any, EValueType::String, EValueType::String}, + EValueType::Null, + EValueType::Boolean, + "has_permissions", + ECallingConvention::UnversionedValue); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/builtin_function_registry.h b/yt/yt/library/query/base/builtin_function_registry.h new file mode 100644 index 0000000000..9e4864488d --- /dev/null +++ b/yt/yt/library/query/base/builtin_function_registry.h @@ -0,0 +1,13 @@ +#pragma once + +#include "functions_builder.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +void RegisterBuiltinFunctions(IFunctionRegistryBuilder* builder); + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/builtin_function_types.cpp b/yt/yt/library/query/base/builtin_function_types.cpp new file mode 100644 index 0000000000..f59b9056b7 --- /dev/null +++ b/yt/yt/library/query/base/builtin_function_types.cpp @@ -0,0 +1,241 @@ +#include "builtin_function_types.h" + +#include "functions_builder.h" +#include "functions.h" + +namespace NYT::NQueryClient { + +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +class TTypeInferrerFunctionRegistryBuilder + : public IFunctionRegistryBuilder +{ +public: + explicit TTypeInferrerFunctionRegistryBuilder(const TTypeInferrerMapPtr& typeInferrers) + : TypeInferrers_(typeInferrers) + { } + + void RegisterFunction( + const TString& functionName, + const TString& /*symbolName*/, + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + std::vector<TType> argumentTypes, + TType repeatedArgType, + TType resultType, + TStringBuf /*implementationFile*/, + ECallingConvention /*callingConvention*/, + bool /*useFunctionContext*/) override + { + TypeInferrers_->emplace(functionName, New<TFunctionTypeInferrer>( + std::move(typeParameterConstraints), + std::move(argumentTypes), + repeatedArgType, + resultType)); + } + + void RegisterFunction( + const TString& functionName, + std::vector<TType> argumentTypes, + TType resultType, + TStringBuf /*implementationFile*/, + ECallingConvention /*callingConvention*/) override + { + TypeInferrers_->emplace(functionName, New<TFunctionTypeInferrer>( + std::unordered_map<TTypeParameter, TUnionType>{}, + std::move(argumentTypes), + EValueType::Null, + resultType)); + } + + void RegisterFunction( + const TString& functionName, + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + std::vector<TType> argumentTypes, + TType repeatedArgType, + TType 
resultType, + TStringBuf /*implementationFile*/) override + { + TypeInferrers_->emplace(functionName, New<TFunctionTypeInferrer>( + std::move(typeParameterConstraints), + std::move(argumentTypes), + repeatedArgType, + resultType)); + } + + void RegisterAggregate( + const TString& aggregateName, + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + TType argumentType, + TType resultType, + TType stateType, + TStringBuf /*implementationFile*/, + ECallingConvention /*callingConvention*/, + bool /*isFirst*/) override + { + TypeInferrers_->emplace(aggregateName, New<TAggregateTypeInferrer>( + typeParameterConstraints, + argumentType, + resultType, + stateType)); + } + +private: + const TTypeInferrerMapPtr TypeInferrers_; +}; + +std::unique_ptr<IFunctionRegistryBuilder> CreateTypeInferrerFunctionRegistryBuilder( + const TTypeInferrerMapPtr& typeInferrers) +{ + return std::make_unique<TTypeInferrerFunctionRegistryBuilder>(typeInferrers); +} + +//////////////////////////////////////////////////////////////////////////////// + +TConstTypeInferrerMapPtr CreateBuiltinTypeInferrers() +{ + auto result = New<TTypeInferrerMap>(); + + const TTypeParameter primitive = 0; + + result->emplace("if", New<TFunctionTypeInferrer>( + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::Boolean, primitive, primitive}, + primitive)); + + result->emplace("is_prefix", New<TFunctionTypeInferrer>( + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{EValueType::String, EValueType::String}, + EValueType::Boolean)); + + result->emplace("is_null", New<TFunctionTypeInferrer>( + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{primitive}, + EValueType::Null, + EValueType::Boolean)); + + result->emplace("is_nan", New<TFunctionTypeInferrer>( + std::vector<TType>{EValueType::Double}, + EValueType::Boolean)); + + const TTypeParameter castable = 1; + auto castConstraints = std::unordered_map<TTypeParameter, 
TUnionType>(); + castConstraints[castable] = std::vector<EValueType>{ + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Any + }; + + result->emplace("int64", New<TFunctionTypeInferrer>( + castConstraints, + std::vector<TType>{castable}, + EValueType::Null, + EValueType::Int64)); + + result->emplace("uint64", New<TFunctionTypeInferrer>( + castConstraints, + std::vector<TType>{castable}, + EValueType::Null, + EValueType::Uint64)); + + result->emplace("double", New<TFunctionTypeInferrer>( + castConstraints, + std::vector<TType>{castable}, + EValueType::Null, + EValueType::Double)); + + result->emplace("boolean", New<TFunctionTypeInferrer>( + std::vector<TType>{EValueType::Any}, + EValueType::Boolean)); + + result->emplace("string", New<TFunctionTypeInferrer>( + std::vector<TType>{EValueType::Any}, + EValueType::String)); + + result->emplace("if_null", New<TFunctionTypeInferrer>( + std::unordered_map<TTypeParameter, TUnionType>(), + std::vector<TType>{primitive, primitive}, + primitive)); + + const TTypeParameter nullable = 2; + + std::unordered_map<TTypeParameter, TUnionType> coalesceConstraints; + coalesceConstraints[nullable] = { + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Boolean, + EValueType::String, + EValueType::Composite, + EValueType::Any + }; + result->emplace("coalesce", New<TFunctionTypeInferrer>( + coalesceConstraints, + std::vector<TType>{}, + nullable, + nullable)); + + const TTypeParameter summable = 3; + auto sumConstraints = std::unordered_map<TTypeParameter, TUnionType>(); + sumConstraints[summable] = std::vector<EValueType>{ + EValueType::Int64, + EValueType::Uint64, + EValueType::Double + }; + + result->emplace("sum", New<TAggregateTypeInferrer>( + sumConstraints, + summable, + summable, + summable)); + + const TTypeParameter comparable = 4; + auto minMaxConstraints = std::unordered_map<TTypeParameter, TUnionType>(); + minMaxConstraints[comparable] = std::vector<EValueType>{ + 
EValueType::Int64, + EValueType::Uint64, + EValueType::Boolean, + EValueType::Double, + EValueType::String + }; + for (const auto& name : {"min", "max"}) { + result->emplace(name, New<TAggregateTypeInferrer>( + minMaxConstraints, + comparable, + comparable, + comparable)); + } + + auto argMinMaxConstraints = std::unordered_map<TTypeParameter, TUnionType>(); + argMinMaxConstraints[comparable] = std::vector<EValueType>{ + EValueType::Int64, + EValueType::Uint64, + EValueType::Boolean, + EValueType::Double, + EValueType::String + }; + for (const auto& name : {"argmin", "argmax"}) { + result->emplace(name, New<TAggregateFunctionTypeInferrer>( + argMinMaxConstraints, + std::vector<TType>{primitive, comparable}, + EValueType::String, + primitive)); + } + + TTypeInferrerFunctionRegistryBuilder builder{result.Get()}; + RegisterBuiltinFunctions(&builder); + + return result; +} + +const TConstTypeInferrerMapPtr GetBuiltinTypeInferrers() +{ + static const auto builtinTypeInferrers = CreateBuiltinTypeInferrers(); + return builtinTypeInferrers; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/builtin_function_types.h b/yt/yt/library/query/base/builtin_function_types.h new file mode 100644 index 0000000000..0b88c29826 --- /dev/null +++ b/yt/yt/library/query/base/builtin_function_types.h @@ -0,0 +1,14 @@ +#pragma once + +#include "builtin_function_registry.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<IFunctionRegistryBuilder> CreateTypeInferrerFunctionRegistryBuilder( + const TTypeInferrerMapPtr& typeInferrers); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/callbacks.h b/yt/yt/library/query/base/callbacks.h new file mode 100644 index 0000000000..9ec71ae8f5 --- 
/dev/null +++ b/yt/yt/library/query/base/callbacks.h @@ -0,0 +1,55 @@ +#pragma once + +#include "public.h" +#include "query_common.h" + +#include <yt/yt/core/ypath/public.h> + +#include <yt/yt/core/rpc/public.h> + +#include <yt/yt/core/actions/future.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +using TExecuteQueryCallback = std::function<TFuture<void>( + const TQueryPtr& query, + TDataSource dataSource, + IUnversionedRowsetWriterPtr writer)>; + +//////////////////////////////////////////////////////////////////////////////// + +struct IExecutor + : public virtual TRefCounted +{ + virtual TFuture<TQueryStatistics> Execute( + TConstQueryPtr query, + TConstExternalCGInfoPtr externalCGInfo, + TDataSource dataSource, + IUnversionedRowsetWriterPtr writer, + const TQueryOptions& options) = 0; + +}; + +DEFINE_REFCOUNTED_TYPE(IExecutor) + +//////////////////////////////////////////////////////////////////////////////// + +struct IPrepareCallbacks +{ + virtual ~IPrepareCallbacks() = default; + + //! Returns the initial split for a given path. 
+ virtual TFuture<TDataSplit> GetInitialSplit(const NYPath::TYPath& path) = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +using TJoinSubqueryEvaluator = std::function<ISchemafulUnversionedReaderPtr(std::vector<TRow>, TRowBufferPtr)>; +using TJoinSubqueryProfiler = std::function<TJoinSubqueryEvaluator(TQueryPtr, TConstJoinClausePtr)>; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + diff --git a/yt/yt/library/query/base/constraints-inl.h b/yt/yt/library/query/base/constraints-inl.h new file mode 100644 index 0000000000..1bfb37fa94 --- /dev/null +++ b/yt/yt/library/query/base/constraints-inl.h @@ -0,0 +1,128 @@ +#ifndef CONSTRAINTS_INL_H_ +#error "Direct inclusion of this file is not allowed, include constraints.h" +// For the sake of sane code completion. +#include "constraints.h" +#endif + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +inline bool operator < (const TValueBound& lhs, const TValueBound& rhs) +{ + return std::tie(lhs.Value, lhs.Flag) < std::tie(rhs.Value, rhs.Flag); +} + +inline bool operator <= (const TValueBound& lhs, const TValueBound& rhs) +{ + return std::tie(lhs.Value, lhs.Flag) <= std::tie(rhs.Value, rhs.Flag); +} + +inline bool operator == (const TValueBound& lhs, const TValueBound& rhs) +{ + return std::tie(lhs.Value, lhs.Flag) == std::tie(rhs.Value, rhs.Flag); +} + +inline bool TestValue(TValue value, const TValueBound& lower, const TValueBound& upper) +{ + return lower < TValueBound{value, true} && TValueBound{value, false} < upper; +} + +//////////////////////////////////////////////////////////////////////////////// + +inline TConstraintRef TConstraintRef::Empty() +{ + TConstraintRef result; + result.ColumnId = 0; + return result; +} + +inline TConstraintRef TConstraintRef::Universal() +{ + return {}; +} + 
+//////////////////////////////////////////////////////////////////////////////// + +inline TValueBound TConstraint::GetLowerBound() const +{ + return {LowerValue, !LowerIncluded}; +} + +inline TValueBound TConstraint::GetUpperBound() const +{ + return {UpperValue, UpperIncluded}; +} + +inline TConstraint TConstraint::Make(TValueBound lowerBound, TValueBound upperBound, TConstraintRef next) +{ + YT_VERIFY(lowerBound < upperBound); + return { + lowerBound.Value, + upperBound.Value, + next, + !lowerBound.Flag, + upperBound.Flag}; +} + +//////////////////////////////////////////////////////////////////////////////// + +inline TValue TColumnConstraint::GetValue() const +{ + YT_ASSERT(IsExact()); + return Lower.Value; +} + +inline bool TColumnConstraint::IsExact() const +{ + return Lower.Value == Upper.Value && !Lower.Flag && Upper.Flag; +} + +inline bool TColumnConstraint::IsRange() const +{ + return Lower.Value < Upper.Value; +} + +inline bool TColumnConstraint::IsUniversal() const +{ + return Lower.Value.Type == EValueType::Min && Upper.Value.Type == EValueType::Max; +} + +//////////////////////////////////////////////////////////////////////////////// + +template <class TOnRange> +void TReadRangesGenerator::GenerateReadRanges(TConstraintRef constraintRef, const TOnRange& onRange, ui64 rangeExpansionLimit) +{ + auto columnId = constraintRef.ColumnId; + if (columnId == SentinelColumnId) { + // Leaf node. 
+ onRange(Row_, rangeExpansionLimit); + return; + } + + auto intervals = MakeRange(Constraints_[columnId]) + .Slice(constraintRef.StartIndex, constraintRef.EndIndex); + + if (rangeExpansionLimit < intervals.Size()) { + Row_[columnId] = TColumnConstraint{intervals.Front().GetLowerBound(), intervals.Back().GetUpperBound()}; + + onRange(Row_, rangeExpansionLimit); + } else if (!intervals.Empty()) { + ui64 nextRangeExpansionLimit = rangeExpansionLimit / intervals.Size(); + YT_VERIFY(nextRangeExpansionLimit > 0); + for (const auto& item : intervals) { + Row_[columnId] = TColumnConstraint{item.GetLowerBound(), item.GetUpperBound()}; + + Row_[columnId].Lower.Value.Id = columnId; + Row_[columnId].Upper.Value.Id = columnId; + + GenerateReadRanges(item.Next, onRange, nextRangeExpansionLimit); + } + } + + Row_[columnId] = UniversalInterval; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/constraints.cpp b/yt/yt/library/query/base/constraints.cpp new file mode 100644 index 0000000000..054ca9c378 --- /dev/null +++ b/yt/yt/library/query/base/constraints.cpp @@ -0,0 +1,401 @@ +#include "constraints.h" +#include "query.h" +#include "functions.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +// Min is included because lower bound is included. +// Max is excluded because upper bound is excluded. +// It allows keeping resulting ranges without augmenting suffix with additional sentinels (e.g. [Max included] -> [Max included, Max excluded]). 
+TValueBound MinBound{MakeUnversionedSentinelValue(EValueType::Min), false}; +TValueBound MaxBound{MakeUnversionedSentinelValue(EValueType::Max), false}; + +TColumnConstraint UniversalInterval{MinBound, MaxBound}; + +//////////////////////////////////////////////////////////////////////////////// + +TConstraintRef TConstraintsHolder::Append(std::initializer_list<TConstraint> constraints, ui32 keyPartIndex) +{ + auto& columnConstraints = (*this)[keyPartIndex]; + + TConstraintRef result; + + result.ColumnId = keyPartIndex; + result.StartIndex = columnConstraints.size(); + result.EndIndex = result.StartIndex + constraints.size(); + + columnConstraints.insert(columnConstraints.end(), constraints.begin(), constraints.end()); + + return result; +} + +TConstraintRef TConstraintsHolder::Interval(TValueBound lower, TValueBound upper, ui32 keyPartIndex) +{ + return Append({TConstraint::Make(lower, upper)}, keyPartIndex); +} + +void VerifyConstraintsAreSorted(const TConstraintsHolder& holder, TConstraintRef ref) +{ + if (ref.ColumnId == SentinelColumnId) { + return; + } + + const auto& columnConstraints = holder[ref.ColumnId]; + + for (ui32 index = ref.StartIndex + 1; index < ref.EndIndex; ++index) { + YT_VERIFY(columnConstraints[index - 1].GetUpperBound() <= columnConstraints[index].GetLowerBound()); + } +} + +TConstraintRef TConstraintsHolder::Intersect(TConstraintRef lhs, TConstraintRef rhs) +{ + VerifyConstraintsAreSorted(*this, lhs); + VerifyConstraintsAreSorted(*this, rhs); + + if (lhs.ColumnId > rhs.ColumnId) { + std::swap(lhs, rhs); + } + + if (lhs.ColumnId < rhs.ColumnId) { + // Intersection of lhs.Next with rhs for lhs bounds. 
+ + auto& columnConstraints = (*this)[lhs.ColumnId]; + + TConstraintRef result; + result.ColumnId = lhs.ColumnId; + result.StartIndex = columnConstraints.size(); + + for (auto lhsIndex = lhs.StartIndex; lhsIndex != lhs.EndIndex; ++lhsIndex) { + const auto& lhsItem = columnConstraints[lhsIndex]; + auto next = Intersect(lhsItem.Next, rhs); + + if (next.ColumnId == SentinelColumnId || next.StartIndex != next.EndIndex) { + columnConstraints.push_back(TConstraint::Make( + lhsItem.GetLowerBound(), + lhsItem.GetUpperBound(), + next)); + } + } + + result.EndIndex = columnConstraints.size(); + return result; + } + + YT_VERIFY(lhs.ColumnId == rhs.ColumnId); + + if (lhs.ColumnId == SentinelColumnId) { + return TConstraintRef::Universal(); + } + + auto& columnConstraints = (*this)[lhs.ColumnId]; + + TConstraintRef result; + result.ColumnId = lhs.ColumnId; + result.StartIndex = columnConstraints.size(); + + auto lhsIndex = lhs.StartIndex; + auto rhsIndex = rhs.StartIndex; + + while (lhsIndex != lhs.EndIndex && rhsIndex != rhs.EndIndex) { + const auto& lhsItem = columnConstraints[lhsIndex]; + const auto& rhsItem = columnConstraints[rhsIndex]; + + // Keep by values because columnConstraints can be reallocated in .push_back method. 
+ auto lhsLower = lhsItem.GetLowerBound(); + auto lhsUpper = lhsItem.GetUpperBound(); + auto rhsLower = rhsItem.GetLowerBound(); + auto rhsUpper = rhsItem.GetUpperBound(); + + auto lhsNext = lhsItem.Next; + auto rhsNext = rhsItem.Next; + + auto intersectionLower = std::max(lhsLower, rhsLower); + auto intersectionUpper = std::min(lhsUpper, rhsUpper); + + if (intersectionLower < intersectionUpper) { + auto next = Intersect(lhsNext, rhsNext); + if (next.ColumnId == SentinelColumnId || next.StartIndex != next.EndIndex) { + columnConstraints.push_back(TConstraint::Make( + intersectionLower, + intersectionUpper, + next)); + } + } + + if (lhsUpper < rhsUpper) { + ++lhsIndex; + } else if (rhsUpper < lhsUpper) { + ++rhsIndex; + } else { + ++lhsIndex; + ++rhsIndex; + } + } + + result.EndIndex = columnConstraints.size(); + return result; +} + +TConstraintRef TConstraintsHolder::Unite(TConstraintRef lhs, TConstraintRef rhs) +{ + VerifyConstraintsAreSorted(*this, lhs); + VerifyConstraintsAreSorted(*this, rhs); + + if (lhs.ColumnId > rhs.ColumnId) { + std::swap(lhs, rhs); + } + + if (lhs.ColumnId < rhs.ColumnId) { + // Union of lhs.Next with rhs for lhs bounds and rhs for complement of lhs bounds. + + // Treat skipped column as universal constraint. 
+ rhs = TConstraintsHolder::Append({ + TConstraint::Make( + MinBound, + MaxBound, + rhs) + }, + lhs.ColumnId); + } + + YT_VERIFY(lhs.ColumnId == rhs.ColumnId); + + if (lhs.ColumnId == SentinelColumnId) { + return TConstraintRef::Universal(); + } + + auto& columnConstraints = (*this)[lhs.ColumnId]; + + TConstraintRef result; + result.ColumnId = lhs.ColumnId; + result.StartIndex = columnConstraints.size(); + + auto lhsIndex = lhs.StartIndex; + auto rhsIndex = rhs.StartIndex; + + TValueBound lastBound{MakeUnversionedSentinelValue(EValueType::Min), false}; + + while (lhsIndex != lhs.EndIndex && rhsIndex != rhs.EndIndex) { + const auto& lhsItem = columnConstraints[lhsIndex]; + const auto& rhsItem = columnConstraints[rhsIndex]; + + // Keep by values because columnConstraints can be reallocated in .push_back method. + auto lhsLower = lhsItem.GetLowerBound(); + auto lhsUpper = lhsItem.GetUpperBound(); + auto rhsLower = rhsItem.GetLowerBound(); + auto rhsUpper = rhsItem.GetUpperBound(); + + auto lhsNext = lhsItem.Next; + auto rhsNext = rhsItem.Next; + + // Unite [a, b] and [c, d] + // Cases: + // a b c d + // [ ] + // [ ] + // a c d b + // [ ] + // [ ] + // a c b d + // [ ] + // [ ] + + // Disjoint. Append lhs. + if (lhsUpper <= rhsLower) { + columnConstraints.push_back(TConstraint::Make( + std::max(lastBound, lhsLower), + lhsUpper, + lhsNext)); + + ++lhsIndex; + continue; + } + + // Disjoint. Append rhs. + if (rhsUpper <= lhsLower) { + columnConstraints.push_back(TConstraint::Make( + std::max(lastBound, rhsLower), + rhsUpper, + rhsNext)); + ++rhsIndex; + continue; + } + + auto intersectionLower = std::max(lhsLower, rhsLower); + auto intersectionUpper = std::min(lhsUpper, rhsUpper); + + auto unionLower = std::max(lastBound, lhsLower < rhsLower ? lhsLower : rhsLower); + lastBound = intersectionUpper; + + if (unionLower < intersectionLower) { + columnConstraints.push_back(TConstraint::Make( + unionLower, + intersectionLower, + lhsLower < rhsLower ? 
lhsNext : rhsNext)); + } + + YT_VERIFY(intersectionLower < intersectionUpper); + + auto next = Unite(lhsNext, rhsNext); + + columnConstraints.push_back( + TConstraint::Make(intersectionLower, intersectionUpper, next)); + + if (lhsUpper < rhsUpper) { + ++lhsIndex; + } else if (rhsUpper < lhsUpper) { + ++rhsIndex; + } else { + ++lhsIndex; + ++rhsIndex; + } + } + + while (lhsIndex != lhs.EndIndex) { + const auto& lhsItem = columnConstraints[lhsIndex]; + + auto lowerBound = std::max(lastBound, lhsItem.GetLowerBound()); + auto upperBound = lhsItem.GetUpperBound(); + auto lhsNext = lhsItem.Next; + + if (lowerBound < upperBound) { + columnConstraints.push_back(TConstraint::Make( + lowerBound, + upperBound, + lhsNext)); + } + ++lhsIndex; + } + + while (rhsIndex != rhs.EndIndex) { + const auto& rhsItem = columnConstraints[rhsIndex]; + + auto lowerBound = std::max(lastBound, rhsItem.GetLowerBound()); + auto upperBound = rhsItem.GetUpperBound(); + auto rhsNext = rhsItem.Next; + + if (lowerBound < upperBound) { + columnConstraints.push_back(TConstraint::Make( + lowerBound, + upperBound, + rhsNext)); + } + ++rhsIndex; + } + + result.EndIndex = columnConstraints.size(); + return result; +} + +TString ToString(const TConstraintsHolder& constraints, TConstraintRef root) +{ + TStringBuilder result; + + result.AppendString("Constraints:"); + + auto addOffset = [&] (int offset) { + for (int i = 0; i < offset; ++i) { + result.AppendString(". 
"); + } + return result; + }; + + std::function<void(TConstraintRef)> printNode = + [&] (TConstraintRef ref) { + if (ref.ColumnId == SentinelColumnId) { + result.AppendString(" <universe>"); + } else { + if (ref.StartIndex == ref.EndIndex) { + result.AppendString(" <empty>"); + return; + } + + for (const auto& item : MakeRange(constraints[ref.ColumnId]).Slice(ref.StartIndex, ref.EndIndex)) { + result.AppendString("\n"); + addOffset(ref.ColumnId); + TColumnConstraint columnConstraint{item.GetLowerBound(), item.GetUpperBound()}; + + if (columnConstraint.IsExact()) { + result.AppendFormat("%kv", columnConstraint.GetValue()); + } else { + result.AppendFormat("%v%kv .. %kv%v", + "[("[columnConstraint.Lower.Flag], + columnConstraint.Lower.Value, + columnConstraint.Upper.Value, + ")]"[columnConstraint.Upper.Flag]); + } + + result.AppendString(":"); + printNode(item.Next); + } + } + }; + + printNode(root); + return result.Flush(); +} + +TReadRangesGenerator::TReadRangesGenerator(const TConstraintsHolder& constraints) + : Constraints_(constraints) + , Row_(Constraints_.size(), UniversalInterval) +{ } + +static void CopyValues(TRange<TValue> source, TMutableRow dest) +{ + std::copy(source.Begin(), source.End(), dest.Begin()); +} + +TMutableRow MakeLowerBound(TRowBuffer* rowBuffer, TRange<TValue> boundPrefix, TValueBound lastBound) +{ + auto prefixSize = boundPrefix.Size(); + + if (lastBound.Value.Type == EValueType::Min) { + auto lowerBound = rowBuffer->AllocateUnversioned(prefixSize); + CopyValues(boundPrefix, lowerBound); + return lowerBound; + } + + // Consider included/excluded bounds. 
+ bool lowerExcluded = lastBound.Flag; + auto lowerBound = rowBuffer->AllocateUnversioned(prefixSize + 1 + lowerExcluded); + CopyValues(boundPrefix, lowerBound); + lowerBound[prefixSize] = lastBound.Value; + if (lowerExcluded) { + lowerBound[prefixSize + 1] = MakeUnversionedSentinelValue(EValueType::Max); + } + return lowerBound; +} + +TMutableRow MakeUpperBound(TRowBuffer* rowBuffer, TRange<TValue> boundPrefix, TValueBound lastBound) +{ + auto prefixSize = boundPrefix.Size(); + + // Consider included/excluded bounds. + bool upperIncluded = lastBound.Flag; + auto upperBound = rowBuffer->AllocateUnversioned(prefixSize + 1 + upperIncluded); + CopyValues(boundPrefix, upperBound); + upperBound[prefixSize] = lastBound.Value; + if (upperIncluded) { + upperBound[prefixSize + 1] = MakeUnversionedSentinelValue(EValueType::Max); + } + return upperBound; +} + +TRowRange RowRangeFromPrefix(TRowBuffer* rowBuffer, TRange<TValue> boundPrefix) +{ + auto prefixSize = boundPrefix.Size(); + + auto lowerBound = rowBuffer->AllocateUnversioned(prefixSize); + CopyValues(boundPrefix, lowerBound); + + auto upperBound = rowBuffer->AllocateUnversioned(prefixSize + 1); + CopyValues(boundPrefix, upperBound); + upperBound[prefixSize] = MakeUnversionedSentinelValue(EValueType::Max); + return std::make_pair(lowerBound, upperBound); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/constraints.h b/yt/yt/library/query/base/constraints.h new file mode 100644 index 0000000000..4322a632ce --- /dev/null +++ b/yt/yt/library/query/base/constraints.h @@ -0,0 +1,137 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/query/engine_api/public.h> + +#include <yt/yt/client/table_client/row_buffer.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +struct 
TValueBound +{ + TValue Value; + // Bounds are located between values. + // Flag denotes whether the bound is before the value or after it. + // For upper bound Flag = Included. + // For lower bound Flag = !Included. + bool Flag; +}; + +extern TValueBound MinBound; +extern TValueBound MaxBound; + +constexpr ui32 SentinelColumnId = std::numeric_limits<ui32>::max(); + +bool operator < (const TValueBound& lhs, const TValueBound& rhs); + +bool operator <= (const TValueBound& lhs, const TValueBound& rhs); + +bool operator == (const TValueBound& lhs, const TValueBound& rhs); + +bool TestValue(TValue value, const TValueBound& lower, const TValueBound& upper); + +struct TConstraintRef +{ + // Universal constraint if ColumnId is sentinel. + ui32 ColumnId = SentinelColumnId; + ui32 StartIndex = 0; + ui32 EndIndex = 0; + + static TConstraintRef Empty(); + + static TConstraintRef Universal(); +}; + +struct TConstraint +{ + // For exact match (key = ...) both values are equal. + TValue LowerValue; + TValue UpperValue; + + TConstraintRef Next; + + // TValueBound is not used here in order to get a tighter layout. 
+ bool LowerIncluded; + bool UpperIncluded; + + TValueBound GetLowerBound() const; + + TValueBound GetUpperBound() const; + + static TConstraint Make(TValueBound lowerBound, TValueBound upperBound, TConstraintRef next = TConstraintRef::Universal()); +}; + +struct TColumnConstraint +{ + TValueBound Lower; + TValueBound Upper; + + TValue GetValue() const; + + bool IsExact() const; + + bool IsRange() const; + + bool IsUniversal() const; +}; + +struct TColumnConstraints + : public std::vector<TConstraint> +{ }; + +extern TColumnConstraint UniversalInterval; + +struct TConstraintsHolder + : public std::vector<TColumnConstraints> +{ + explicit TConstraintsHolder(ui32 columnCount) + : std::vector<TColumnConstraints>(columnCount) + { } + + TConstraintRef Append(std::initializer_list<TConstraint> constraints, ui32 keyPartIndex); + + TConstraintRef Interval(TValueBound lower, TValueBound upper, ui32 keyPartIndex); + + TConstraintRef Intersect(TConstraintRef lhs, TConstraintRef rhs); + + TConstraintRef Unite(TConstraintRef lhs, TConstraintRef rhs); + + TConstraintRef ExtractFromExpression( + const TConstExpressionPtr& expr, + const TKeyColumns& keyColumns, + const TRowBufferPtr& rowBuffer, + const TConstConstraintExtractorMapPtr& constraintExtractors = GetBuiltinConstraintExtractors()); +}; + +TString ToString(const TConstraintsHolder& constraints, TConstraintRef root); + +class TReadRangesGenerator +{ +public: + explicit TReadRangesGenerator(const TConstraintsHolder& constraints); + + template <class TOnRange> + void GenerateReadRanges(TConstraintRef constraintRef, const TOnRange& onRange, ui64 rangeExpansionLimit = std::numeric_limits<ui64>::max()); + +private: + const TConstraintsHolder& Constraints_; + std::vector<TColumnConstraint> Row_; +}; + +TMutableRow MakeLowerBound(TRowBuffer* rowBuffer, TRange<TValue> boundPrefix, TValueBound lastBound); + +TMutableRow MakeUpperBound(TRowBuffer* rowBuffer, TRange<TValue> boundPrefix, TValueBound lastBound); + +TRowRange 
RowRangeFromPrefix(TRowBuffer* rowBuffer, TRange<TValue> boundPrefix); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + +#define CONSTRAINTS_INL_H_ +#include "constraints-inl.h" +#undef CONSTRAINTS_INL_H_ diff --git a/yt/yt/library/query/base/coordination_helpers.cpp b/yt/yt/library/query/base/coordination_helpers.cpp new file mode 100644 index 0000000000..1fe370523d --- /dev/null +++ b/yt/yt/library/query/base/coordination_helpers.cpp @@ -0,0 +1,58 @@ +#include "coordination_helpers.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +TRow WidenKeySuccessor(TRow key, size_t prefix, const TRowBufferPtr& rowBuffer, bool captureValues) +{ + auto wideKey = rowBuffer->AllocateUnversioned(prefix + 1); + + for (ui32 index = 0; index < prefix; ++index) { + wideKey[index] = key[index]; + if (captureValues) { + wideKey[index] = rowBuffer->CaptureValue(wideKey[index]); + } + } + + wideKey[prefix] = MakeUnversionedSentinelValue(EValueType::Max); + + return wideKey; +} + +TRow WidenKeySuccessor(TRow key, const TRowBufferPtr& rowBuffer, bool captureValues) +{ + return WidenKeySuccessor(key, key.GetCount(), rowBuffer, captureValues); +} + +size_t GetSignificantWidth(TRow row) +{ + auto valueIt = row.Begin(); + while (valueIt != row.End() && !IsSentinelType(valueIt->Type)) { + ++valueIt; + } + return std::distance(row.Begin(), valueIt); +} + +size_t Step(size_t current, size_t source, size_t target) +{ + YT_VERIFY(target <= source); + YT_VERIFY(current <= source); + + // Original expression: ((c * t / s + 1) * s + t - 1) / t - c; + auto result = (source - 1 - current * target % source ) / target + 1; + YT_VERIFY(current + result <= source); + YT_VERIFY(result > 0); + return result; +} + +void TRangeFormatter::operator()(TStringBuilderBase* builder, TRowRange source) const +{ + builder->AppendFormat("[%v .. 
%v]", + source.first, + source.second); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/coordination_helpers.h b/yt/yt/library/query/base/coordination_helpers.h new file mode 100644 index 0000000000..4b9e9d2b6f --- /dev/null +++ b/yt/yt/library/query/base/coordination_helpers.h @@ -0,0 +1,627 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/table_client/row_buffer.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +#include <yt/yt/library/numeric/algorithm_helpers.h> + +#include <yt/yt/core/misc/range.h> + +// TODO(lukyan): Checks denoted by YT_QL_CHECK are heavy. Change them to YT_ASSERT after some time. +#define YT_QL_CHECK(expr) YT_VERIFY(expr) + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +TRow WidenKeySuccessor(TRow key, size_t prefix, const TRowBufferPtr& rowBuffer, bool captureValues); + +TRow WidenKeySuccessor(TRow key, const TRowBufferPtr& rowBuffer, bool captureValues); + +size_t GetSignificantWidth(TRow row); + +size_t Step(size_t current, size_t source, size_t target); + +struct TRangeFormatter +{ + void operator()(TStringBuilderBase* builder, TRowRange source) const; +}; + +using TRangeIt = TRange<TRowRange>::iterator; +using TSampleIt = TRange<TRow>::iterator; + +// Ranges must be cropped before using ForEachRange + +template <class T, class TCallback> +void ForEachRange(TRange<std::pair<T, T>> ranges, std::pair<T, T> limit, const TCallback& callback) +{ + YT_VERIFY(!ranges.Empty()); + + auto it = ranges.begin(); + auto lower = limit.first; + + while (true) { + auto next = it; + ++next; + + if (next == ranges.end()) { + break; + } + + auto upper = it->second; + callback(std::make_pair(lower, upper)); + + it = next; + lower = it->first; + } + + auto upper = limit.second; + callback(std::make_pair(lower, upper)); +} + 
+//////////////////////////////////////////////////////////////////////////////// + +template <class TItem, class TLess, class TNotGreater> +TRange<TItem> CropItems(TRange<TItem> items, TLess less, TNotGreater notGreater) +{ + auto itemsBegin = BinarySearch(items.begin(), items.end(), less); + auto itemsEnd = BinarySearch(itemsBegin, items.end(), notGreater); + return TRange<TItem>(itemsBegin, itemsEnd); +} + +// TPredicate implements operator(): + +// operator(itemIt, pivot) ~ itemIt PRECEDES pivot: +// ranges: !(pivot < itemIt->upper) +// points: *itemIt < pivot + +// operator(pivot, itemIt) ~ it FOLLOWS pivot: +// ranges, key: !(pivot < itemIt->lower) +// points: !(pivot < *itemIt) + +// Properties: +// item PRECEDES shard ==> item NOT FOLLOWS shard +// item FOLLOWS shard ==> item NOT PRECEDES shard + + +// SplitByPivots does not repeat items in callbacks. + +// For input: [..) [..) [.|..|..|..) +// OnItems [..) [..) | +// OnShards [.|..|..|..) + +// For input: [..) [..) [.|.) [..) | +// OnItems [..) [..) | +// OnShards [.|.) +// OnItems [..) | + +template <class TItem, class TShard, class TPredicate, class TOnItemsFunctor, class TOnShardsFunctor> +void SplitByPivots( + TRange<TItem> items, + TRange<TShard> shards, + TPredicate pred, + TOnItemsFunctor onItemsFunctor, + TOnShardsFunctor onShardsFunctor) +{ + auto shardIt = shards.begin(); + auto itemIt = items.begin(); + + while (itemIt != items.end()) { + // Run binary search to find the relevant shards. 
+ // First shard such that !Follows(itemIt, shard) + + // First shard: item NOT FOLLOWS shard + shardIt = ExponentialSearch(shardIt, shards.end(), [&] (auto it) { + // For interval: *shardIt <= itemIt->lower + // For points: *shardIt <= *itemIt + return pred(it, itemIt); // item FOLLOWS shard + }); + + // For interval: itemIt->upper <= *shardIt + // For points: *itemIt < shardIt is always true: *shardIt <= *itemIt ~ shardIt > *itemIt + + if (shardIt == shards.end()) { + onItemsFunctor(itemIt, items.end(), shardIt); + return; + } + + // item PRECEDES shard + if (pred(itemIt, shardIt)) { // PRECEDES + // First item: item NOT PRECEDES shard + auto nextItemsIt = ExponentialSearch(itemIt, items.end(), [&] (auto it) { + // For interval: itemIt->upper <= *shardIt + // For points: *itemIt < shardIt + return pred(it, shardIt); // item PRECEDES shard + }); + + onItemsFunctor(itemIt, nextItemsIt, shardIt); + itemIt = nextItemsIt; + } else { + // First shard: item PRECEDES shard + auto endShardIt = ExponentialSearch(shardIt, shards.end(), [&] (auto it) { + // For interval: !(itemIt->upper <= *shardIt) ~ *shardIt < itemIt->upper + // For points: *itemIt < shardIt + return !pred(itemIt, it); // item PRECEDES shard + }); + + onShardsFunctor(shardIt, endShardIt, itemIt); + shardIt = endShardIt; + ++itemIt; + } + } +} + +// GroupByShards does not repeat shards in callbacks. + +// For input: [..|..|..|..) [..) [..) | +// OnShards [..|..|..| ) +// OnItems [ ..) [..) [..) | + +// For input: [..) [..|..|..|..) [..) | +// OnItems [..) [..| ) +// OnShards [ ..|..| ) +// OnItems [ ..) [..) | + +template <class TItem, class TShard, class TPredicate, class TGroupFunctor> +void GroupByShards( + TRange<TItem> items, + TRange<TShard> shards, + TPredicate pred, + TGroupFunctor onGroupFunctor) +{ + auto shardIt = shards.begin(); + auto itemIt = items.begin(); + + while (itemIt != items.end()) { + // Run binary search to find the relevant shards. 
+ + // First shard: item NOT FOLLOWS shard + auto shardItStart = ExponentialSearch(shardIt, shards.end(), [&] (auto it) { + // For interval: *shardIt <= itemIt->lower + // For points: *shardIt <= *itemIt + return pred(it, itemIt); // item FOLLOWS shard + }); + + // pred(shardIt, itemIt) + // !pred(itemIt, shardIt) + // pred(itemIt, shardIt) + // !pred(shardIt, itemIt) + + // For interval: itemIt->upper <= *shardIt + // For points: *itemIt < shardIt is allways true: *shardIt <= *itemIt ~ shardIt > *itemIt + + // First shard: item PRECEDES shard + shardIt = ExponentialSearch(shardItStart, shards.end(), [&] (auto it) { + // For interval: !(itemIt->upper <= *shardIt) ~ *shardIt < itemIt->upper + // For points: *itemIt < shardIt + return !pred(itemIt, it); // item PRECEDES shard + }); + + if (shardIt != shards.end()) { + // First item: item NOT PRECEDES shard + auto itemsItNext = ExponentialSearch(itemIt, items.end(), [&] (auto it) { + // For interval: itemIt->upper <= *shardIt + // For points: *itemIt < shardIt + return pred(it, shardIt); // item PRECEDES shard + }); + +#if 0 + auto itemsItEnd = ExponentialSearch(itemsItNext, items.end(), [&] (auto it) { + return !pred(shardIt, it); // item FOLLOWS shard + }); + + YT_VERIFY(itemsItNext == itemsItEnd || itemsItNext + 1 == itemsItEnd); +#else + auto itemsItEnd = itemsItNext; + if (itemsItEnd != items.end() && !pred(shardIt, itemsItEnd)) { // item FOLLOWS shard + ++itemsItEnd; + } +#endif + + onGroupFunctor(itemIt, itemsItEnd, shardItStart, shardIt); + + itemIt = itemsItNext; + ++shardIt; + + // TODO(lukyan): Reduce comparisons. + // There are three cases for the next iteration: + // 0. [ ) [ | | | ) [ ) // itemsItNext != itemsItEnd + // 1. [ ) [ | ) [ ) [ ) | // itemsItNext != itemsItEnd + // 2. [ ) | | | [ ) // itemsItNext == itemsItEnd + // In cases 0 and 1 no need to call `auto shardItStart = ExponentialSearch`. 
+ + } else { + onGroupFunctor(itemIt, items.end(), shardItStart, shardIt); + return; + } + } +} + +template <class TItem, class TShard, class TPredicate, class TOnItemsFunctor> +void GroupItemsByShards( + TRange<TItem> items, + TRange<TShard> shards, + TPredicate pred, + TOnItemsFunctor onItemsFunctor) +{ + GroupByShards( + items, + shards, + pred, + [&] (auto itemsIt, auto itemsItEnd, auto shardIt, auto shardItEnd) { + YT_VERIFY(itemsIt != itemsItEnd); + // shardItEnd can invalid. + while (shardIt != shardItEnd) { + onItemsFunctor(shardIt++, itemsIt, itemsIt + 1); + } + + onItemsFunctor(shardIt, itemsIt, itemsItEnd); + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> +TRow GetPivotKey(const T& shard); + +template <class T> +TRow GetNextPivotKey(const T& shard); + +template <class T> +TRange<TRow> GetSampleKeys(const T& shard); + +//////////////////////////////////////////////////////////////////////////////// + +template <class T, class TOnItemsFunctor> +void SplitRangesByTablets( + TRange<TRowRange> ranges, + TRange<T> tablets, + TRow lowerCapBound, + TRow upperCapBound, + TOnItemsFunctor onItemsFunctor) +{ + using TShardIt = typename TRange<T>::iterator; + + struct TPredicate + { + TRow GetKey(TShardIt shardIt) const + { + return GetPivotKey(*shardIt); + } + + // itemIt PRECEDES shardIt + bool operator() (const TRowRange* itemIt, TShardIt shardIt) const + { + return itemIt->second <= GetKey(shardIt); + } + + // itemIt FOLLOWS shardIt + bool operator() (TShardIt shardIt, const TRowRange* itemIt) const + { + return GetKey(shardIt) <= itemIt->first; + } + }; + + auto croppedRanges = CropItems( + ranges, + [&] (const TRowRange* itemIt) { + return !(lowerCapBound < itemIt->second); + }, + [&] (const TRowRange* itemIt) { + return !(upperCapBound < itemIt->first); + }); + + YT_VERIFY(!tablets.Empty()); + + GroupItemsByShards( + croppedRanges, + tablets.Slice(1, tablets.size()), + TPredicate{}, + 
onItemsFunctor); +} + +template <class T, class TOnItemsFunctor> +void SplitKeysByTablets( + TRange<TRow> keys, + size_t keyWidth, + size_t fullKeySize, + TRange<T> tablets, + TRow lowerCapBound, + TRow upperCapBound, + TOnItemsFunctor onItemsFunctor) +{ + using TShardIt = typename TRange<T>::iterator; + + struct TPredicate + { + size_t KeySize; + bool IsFullKey; + + TRow GetKey(TShardIt shardIt) const + { + return GetPivotKey(*shardIt); + } + + bool Less(TRow lhs, TRow rhs) const + { + return CompareRows(lhs, rhs, KeySize) < 0; + } + + bool LessOrEqual(TRow lhs, TRow rhs) const + { + return CompareRows(lhs, rhs, KeySize) < IsFullKey; + } + + // itemIt PRECEDES shardIt + bool operator() (const TRow* itemIt, TShardIt shardIt) const + { + return Less(*itemIt, GetKey(shardIt)); + } + + // itemIt FOLLOWS shardIt + bool operator() (TShardIt shardIt, const TRow* itemIt) const + { + // Less? + return LessOrEqual(GetKey(shardIt), *itemIt); + } + }; + + auto croppedKeys = CropItems( + keys, + [&] (const TRow* itemIt) { + return *itemIt < lowerCapBound; + }, + [&] (const TRow* itemIt) { + return !(upperCapBound < *itemIt); + }); + + YT_VERIFY(!tablets.Empty()); + + GroupItemsByShards( + croppedKeys, + tablets.Slice(1, tablets.size()), + TPredicate{keyWidth, keyWidth == fullKeySize}, + onItemsFunctor); +} + +//////////////////////////////////////////////////////////////////////////////// + +template <class T, class TOnGroup> +void GroupRangesByPartition(TRange<TRowRange> ranges, TRange<T> partitions, const TOnGroup& onGroup) +{ + using TShardIt = typename TRange<T>::iterator; + + if (!ranges.Empty()) { + TRow lower = GetPivotKey(partitions.Front()); + TRow upper = GetNextPivotKey(partitions.Back()); + + YT_VERIFY(lower < ranges.Front().second); + YT_VERIFY(ranges.Back().first < upper); + } + + struct TPredicate + { + TRow GetKey(TShardIt shardIt) const + { + return GetNextPivotKey(*shardIt); + } + + // itemIt PRECEDES shardIt + bool operator() (TRangeIt itemIt, TShardIt 
shardIt) const + { + return itemIt->second <= GetKey(shardIt); + } + + // itemIt FOLLOWS shardIt + bool operator() (TShardIt shardIt, TRangeIt itemIt) const + { + return GetKey(shardIt) <= itemIt->first; + } + }; + + GroupItemsByShards(ranges, partitions, TPredicate{}, onGroup); +} + +template <class T> +std::vector<TSharedRange<TRowRange>> SplitTablet( + TRange<T> partitions, + TSharedRange<TRowRange> ranges, + TRowBufferPtr rowBuffer, + size_t maxSubsplitsPerTablet, + bool verboseLogging, + const NLogging::TLogger& Logger) +{ + using TShardIt = typename TRange<T>::iterator; + using TItemIt = TRange<TRowRange>::iterator; + + struct TGroup + { + TShardIt PartitionIt; + TItemIt BeginIt; + TItemIt EndIt; + }; + + std::vector<TGroup> groupedByPartitions; + + GroupRangesByPartition(ranges, MakeRange(partitions), [&] (TShardIt shardIt, TItemIt itemIt, TItemIt itemItEnd) { + YT_VERIFY(itemIt != itemItEnd); + + if (shardIt == partitions.end()) { + YT_VERIFY(itemIt + 1 == ranges.end()); + return; + } + + YT_VERIFY(groupedByPartitions.empty() || groupedByPartitions.back().PartitionIt != shardIt); + groupedByPartitions.push_back(TGroup{shardIt, itemIt, itemItEnd}); + }); + + struct TPredicate + { + TRow GetKey(const TRow* shardIt) const + { + return *shardIt; + } + + // itemIt PRECEDES shardIt + bool operator() (const TRowRange* itemIt, const TRow* shardIt) const + { + return itemIt->second <= GetKey(shardIt); + } + + // itemIt FOLLOWS shardIt + bool operator() (const TRow* shardIt, const TRowRange* itemIt) const + { + return GetKey(shardIt) <= itemIt->first; + } + }; + + size_t allShardCount = 0; + + // Calculate touched shards (partitions an) count. 
+ for (auto [partitionIt, beginIt, endIt] : groupedByPartitions) { + GroupByShards( + MakeRange(beginIt, endIt), + GetSampleKeys(*partitionIt), + TPredicate{}, + [&] (TRangeIt /*rangesIt*/, TRangeIt /*rangesItEnd*/, TSampleIt sampleIt, TSampleIt sampleItEnd) { + allShardCount += 1 + std::distance(sampleIt, sampleItEnd); + }); + } + size_t targetSplitCount = std::min(maxSubsplitsPerTablet, allShardCount); + + YT_VERIFY(targetSplitCount > 0); + + YT_LOG_DEBUG_IF(verboseLogging, "AllShardCount: %v, TargetSplitCount: %v", + allShardCount, + targetSplitCount); + + std::vector<TSharedRange<TRowRange>> groupedSplits; + std::vector<TRowRange> group; + + auto holder = MakeSharedRangeHolder(ranges.GetHolder(), rowBuffer); + + size_t currentShardCount = 0; + size_t lastSampleCount = 0; + auto addGroup = [&] (size_t count) { + YT_VERIFY(!group.empty()); + + size_t nextStep = Step(currentShardCount, allShardCount, targetSplitCount); + + YT_VERIFY(count <= nextStep); + YT_VERIFY(currentShardCount <= allShardCount); + + currentShardCount += count; + + if (count == nextStep) { + YT_LOG_DEBUG_IF(verboseLogging, "(%v, %v) make batch [%v .. %v] from %v ranges", + lastSampleCount, + currentShardCount, + group.front().first, + group.back().second, + group.size()); + + for (size_t i = 0; i + 1 < group.size(); ++i) { + YT_QL_CHECK(group[i].second <= group[i + 1].first); + } + + for (size_t i = 0; i < group.size(); ++i) { + YT_QL_CHECK(group[i].first < group[i].second); + } + + groupedSplits.push_back(MakeSharedRange(std::move(group), holder)); + lastSampleCount = currentShardCount; + } + }; + + for (auto [partitionIt, beginIt, endIt] : groupedByPartitions) { + const auto& partition = *partitionIt; + TRowRange partitionBounds(GetPivotKey(partition), GetNextPivotKey(partition)); + + YT_LOG_DEBUG_IF(verboseLogging, "Iterating over partition %v: [%v .. 
%v]", + partitionBounds, + beginIt - begin(ranges), + endIt - begin(ranges)); + + + auto slice = MakeRange(beginIt, endIt); + + // Do not need to crop. Already cropped in GroupRangesByPartition. + + auto minBound = std::max<TRow>(slice.Front().first, rowBuffer->CaptureRow(GetPivotKey(partition))); + auto maxBound = std::min<TRow>(slice.Back().second, rowBuffer->CaptureRow(GetNextPivotKey(partition))); + + auto samples = GetSampleKeys(partition); + + TRangeIt rangesItLast = nullptr; + + GroupByShards( + slice, + samples, + TPredicate{}, + [&] (TRangeIt rangesIt, TRangeIt rangesItEnd, TSampleIt sampleIt, TSampleIt sampleItEnd) { + YT_VERIFY(rangesIt != rangesItEnd); + + if (sampleIt != sampleItEnd) { + TRow start = minBound; + if (rangesIt == rangesItLast) { + if (sampleIt != samples.begin()) { + start = *(sampleIt - 1); + } + } else { + if (rangesIt != slice.begin()) { + start = rangesIt->first; + } + } + + { + auto upper = rangesIt + 1 == slice.end() ? maxBound : rangesIt->second; + YT_QL_CHECK(*(sampleItEnd - 1) <= upper); + } + + auto currentBound = start; + + while (sampleIt != sampleItEnd) { + size_t nextStep = std::min<size_t>( + Step(currentShardCount, allShardCount, targetSplitCount), + sampleItEnd - sampleIt); + YT_VERIFY(nextStep > 0); + + sampleIt += nextStep - 1; + + auto nextBound = rowBuffer->CaptureRow(*sampleIt); + YT_QL_CHECK(currentBound < nextBound); + group.emplace_back(currentBound, nextBound); + + addGroup(nextStep); + currentBound = nextBound; + ++sampleIt; + } + } + + // TODO: Capture *sampleIt ? + auto lower = sampleIt == samples.begin() ? minBound : *(sampleIt - 1); + auto upper = sampleIt == samples.end() ? 
maxBound : *sampleIt; + + lower = std::max<TRow>(lower, rangesIt->first); + upper = std::min<TRow>(upper, (rangesItEnd - 1)->second); + + ForEachRange(MakeRange(rangesIt, rangesItEnd), TRowRange(lower, upper), [&] (auto item) { + group.push_back(item); + }); + + addGroup(1); + + rangesItLast = rangesItEnd - 1; + }); + } + + YT_VERIFY(currentShardCount == allShardCount); + + return groupedSplits; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/functions.cpp b/yt/yt/library/query/base/functions.cpp new file mode 100644 index 0000000000..b9c918c32f --- /dev/null +++ b/yt/yt/library/query/base/functions.cpp @@ -0,0 +1,186 @@ +#include "functions.h" + +#include <library/cpp/yt/misc/variant.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +int TFunctionTypeInferrer::GetNormalizedConstraints( + std::vector<TTypeSet>* typeConstraints, + std::vector<int>* formalArguments, + std::optional<std::pair<int, bool>>* repeatedType) const +{ + std::unordered_map<TTypeParameter, int> idToIndex; + + auto getIndex = [&] (const TType& type) -> int { + return Visit(type, + [&] (TTypeParameter genericId) -> int { + auto itIndex = idToIndex.find(genericId); + if (itIndex != idToIndex.end()) { + return itIndex->second; + } else { + int index = typeConstraints->size(); + auto it = TypeParameterConstraints_.find(genericId); + if (it == TypeParameterConstraints_.end()) { + typeConstraints->push_back(TTypeSet({ + EValueType::Null, + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Boolean, + EValueType::String, + EValueType::Any})); + } else { + typeConstraints->push_back(TTypeSet(it->second.begin(), it->second.end())); + } + idToIndex.emplace(genericId, index); + return index; + } + }, + [&] (EValueType fixedType) -> int { + int index = typeConstraints->size(); + 
typeConstraints->push_back(TTypeSet({fixedType})); + return index; + }, + [&] (const TUnionType& unionType) -> int { + int index = typeConstraints->size(); + typeConstraints->push_back(TTypeSet(unionType.begin(), unionType.end())); + return index; + }); + }; + + for (const auto& argumentType : ArgumentTypes_) { + formalArguments->push_back(getIndex(argumentType)); + } + + if (!(std::holds_alternative<EValueType>(RepeatedArgumentType_) && + std::get<EValueType>(RepeatedArgumentType_) == EValueType::Null)) + { + *repeatedType = std::make_pair( + getIndex(RepeatedArgumentType_), + std::get_if<TUnionType>(&RepeatedArgumentType_)); + } + + return getIndex(ResultType_); +} + +void TAggregateTypeInferrer::GetNormalizedConstraints( + TTypeSet* constraint, + std::optional<EValueType>* stateType, + std::optional<EValueType>* resultType, + TStringBuf name) const +{ + if (TypeParameterConstraints_.size() > 1) { + THROW_ERROR_EXCEPTION("Too many constraints for aggregate function"); + } + + auto setType = [&] (const TType& targetType, bool allowGeneric) -> std::optional<EValueType> { + if (auto* fixedType = std::get_if<EValueType>(&targetType)) { + return *fixedType; + } + if (allowGeneric) { + if (auto* typeId = std::get_if<TTypeParameter>(&targetType)) { + auto found = TypeParameterConstraints_.find(*typeId); + if (found != TypeParameterConstraints_.end()) { + return std::nullopt; + } + } + } + THROW_ERROR_EXCEPTION("Invalid type constraints for aggregate function %Qv", name); + }; + + Visit(ArgumentType_, + [&] (const TUnionType& unionType) { + *constraint = TTypeSet(unionType.begin(), unionType.end()); + *resultType = setType(ResultType_, false); + *stateType = setType(StateType_, false); + }, + [&] (EValueType fixedType) { + *constraint = TTypeSet({fixedType}); + *resultType = setType(ResultType_, false); + *stateType = setType(StateType_, false); + }, + [&] (TTypeParameter typeId) { + auto found = TypeParameterConstraints_.find(typeId); + if (found == 
TypeParameterConstraints_.end()) { + THROW_ERROR_EXCEPTION("Invalid type constraints for aggregate function %Qv", name); + } + + *constraint = TTypeSet(found->second.begin(), found->second.end()); + *resultType = setType(ResultType_, true); + *stateType = setType(StateType_, true); + }); +} + +std::pair<int, int> TAggregateFunctionTypeInferrer::GetNormalizedConstraints( + std::vector<TTypeSet>* typeConstraints, + std::vector<int>* argumentConstraintIndexes) const +{ + std::unordered_map<TTypeParameter, int> idToIndex; + + auto getIndex = [&] (const TType& type) -> int { + return Visit(type, + [&] (EValueType fixedType) -> int { + typeConstraints->push_back(TTypeSet({fixedType})); + return typeConstraints->size() - 1; + }, + [&] (TTypeParameter genericId) -> int { + auto itIndex = idToIndex.find(genericId); + if (itIndex != idToIndex.end()) { + return itIndex->second; + } else { + int index = typeConstraints->size(); + auto it = TypeParameterConstraints_.find(genericId); + if (it == TypeParameterConstraints_.end()) { + typeConstraints->push_back(TTypeSet({ + EValueType::Null, + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Boolean, + EValueType::String, + EValueType::Any})); + } else { + typeConstraints->push_back(TTypeSet(it->second.begin(), it->second.end())); + } + idToIndex.emplace(genericId, index); + return index; + } + }, + [&] (const TUnionType& unionType) -> int { + typeConstraints->push_back(TTypeSet(unionType.begin(), unionType.end())); + return typeConstraints->size() - 1; + }); + }; + + for (const auto& argumentType : ArgumentTypes_) { + argumentConstraintIndexes->push_back(getIndex(argumentType)); + } + + return std::make_pair(getIndex(StateType_), getIndex(ResultType_)); +} + + +//////////////////////////////////////////////////////////////////////////////// + +const ITypeInferrerPtr& TTypeInferrerMap::GetFunction(const TString& functionName) const +{ + auto found = this->find(functionName); + if (found == this->end()) { 
+ THROW_ERROR_EXCEPTION("Undefined function %Qv", + functionName); + } + return found->second; +} + +//////////////////////////////////////////////////////////////////////////////// + +bool IsUserCastFunction(const TString& name) +{ + return name == "int64" || name == "uint64" || name == "double"; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/functions.h b/yt/yt/library/query/base/functions.h new file mode 100644 index 0000000000..52ddf6c9ff --- /dev/null +++ b/yt/yt/library/query/base/functions.h @@ -0,0 +1,165 @@ +#pragma once + +#include "public.h" + +#include "key_trie.h" +#include "constraints.h" +#include "functions_common.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +class ITypeInferrer + : public virtual TRefCounted +{ +public: + template <class TDerived> + const TDerived* As() const + { + return dynamic_cast<const TDerived*>(this); + } + + template <class TDerived> + TDerived* As() + { + return dynamic_cast<TDerived*>(this); + } +}; + +DEFINE_REFCOUNTED_TYPE(ITypeInferrer) + +//////////////////////////////////////////////////////////////////////////////// + +class TFunctionTypeInferrer + : public ITypeInferrer +{ +public: + TFunctionTypeInferrer( + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + std::vector<TType> argumentTypes, + TType repeatedArgumentType, + TType resultType) + : TypeParameterConstraints_(std::move(typeParameterConstraints)) + , ArgumentTypes_(std::move(argumentTypes)) + , RepeatedArgumentType_(repeatedArgumentType) + , ResultType_(resultType) + { } + + TFunctionTypeInferrer( + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + std::vector<TType> argumentTypes, + TType resultType) + : TFunctionTypeInferrer( + std::move(typeParameterConstraints), + std::move(argumentTypes), + EValueType::Null, + 
resultType) + { } + + TFunctionTypeInferrer( + std::vector<TType> argumentTypes, + TType resultType) + : TFunctionTypeInferrer( + std::unordered_map<TTypeParameter, TUnionType>(), + std::move(argumentTypes), + resultType) + { } + + int GetNormalizedConstraints( + std::vector<TTypeSet>* typeConstraints, + std::vector<int>* formalArguments, + std::optional<std::pair<int, bool>>* repeatedType) const; + +private: + const std::unordered_map<TTypeParameter, TUnionType> TypeParameterConstraints_; + const std::vector<TType> ArgumentTypes_; + const TType RepeatedArgumentType_; + const TType ResultType_; +}; + +class TAggregateTypeInferrer + : public ITypeInferrer +{ +public: + TAggregateTypeInferrer( + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + TType argumentType, + TType resultType, + TType stateType) + : TypeParameterConstraints_(std::move(typeParameterConstraints)) + , ArgumentType_(argumentType) + , ResultType_(resultType) + , StateType_(stateType) + { } + + void GetNormalizedConstraints( + TTypeSet* constraint, + std::optional<EValueType>* stateType, + std::optional<EValueType>* resultType, + TStringBuf name) const; + +private: + const std::unordered_map<TTypeParameter, TUnionType> TypeParameterConstraints_; + const TType ArgumentType_; + const TType ResultType_; + const TType StateType_; +}; + +class TAggregateFunctionTypeInferrer + : public ITypeInferrer +{ +public: + TAggregateFunctionTypeInferrer( + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + std::vector<TType> argumentTypes, + TType stateType, + TType resultType) + : TypeParameterConstraints_(std::move(typeParameterConstraints)) + , ArgumentTypes_(std::move(argumentTypes)) + , StateType_(stateType) + , ResultType_(resultType) + { } + + std::pair<int, int> GetNormalizedConstraints( + std::vector<TTypeSet>* typeConstraints, + std::vector<int>* argumentConstraintIndexes) const; + +private: + const std::unordered_map<TTypeParameter, TUnionType> 
TypeParameterConstraints_; + const std::vector<TType> ArgumentTypes_; + const TType StateType_; + const TType ResultType_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +using TRangeExtractor = std::function<TKeyTriePtr( + const TConstFunctionExpressionPtr& expr, + const TKeyColumns& keyColumns, + const TRowBufferPtr& rowBuffer)>; + +using TConstraintExtractor = std::function<TConstraintRef( + TConstraintsHolder* constraints, + const TConstFunctionExpressionPtr& expr, + const TKeyColumns& keyColumns, + const TRowBufferPtr& rowBuffer)>; + +//////////////////////////////////////////////////////////////////////////////// + +struct TTypeInferrerMap + : public TRefCounted + , public std::unordered_map<TString, ITypeInferrerPtr> +{ + const ITypeInferrerPtr& GetFunction(const TString& functionName) const; +}; + +DEFINE_REFCOUNTED_TYPE(TTypeInferrerMap) + +//////////////////////////////////////////////////////////////////////////////// + +bool IsUserCastFunction(const TString& name); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/functions_builder.h b/yt/yt/library/query/base/functions_builder.h new file mode 100644 index 0000000000..25671b2f27 --- /dev/null +++ b/yt/yt/library/query/base/functions_builder.h @@ -0,0 +1,54 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/query/base/functions_common.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +struct IFunctionRegistryBuilder +{ + virtual ~IFunctionRegistryBuilder() = default; + + virtual void RegisterFunction( + const TString& functionName, + const TString& symbolName, + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + std::vector<TType> argumentTypes, + TType repeatedArgType, + TType resultType, + TStringBuf implementationFile, + ECallingConvention 
callingConvention, + bool useFunctionContext = false) = 0; + + virtual void RegisterFunction( + const TString& functionName, + std::vector<TType> argumentTypes, + TType resultType, + TStringBuf implementationFile, + ECallingConvention callingConvention) = 0; + + virtual void RegisterFunction( + const TString& functionName, + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + std::vector<TType> argumentTypes, + TType repeatedArgType, + TType resultType, + TStringBuf implementationFile) = 0; + + virtual void RegisterAggregate( + const TString& aggregateName, + std::unordered_map<TTypeParameter, TUnionType> typeParameterConstraints, + TType argumentType, + TType resultType, + TType stateType, + TStringBuf implementationFile, + ECallingConvention callingConvention, + bool isFirst = false) = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/functions_common.cpp b/yt/yt/library/query/base/functions_common.cpp new file mode 100644 index 0000000000..54fea31f03 --- /dev/null +++ b/yt/yt/library/query/base/functions_common.cpp @@ -0,0 +1,69 @@ +#include "functions_common.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +EValueType TTypeSet::GetFront() const +{ + YT_VERIFY(!IsEmpty()); + static const int MultiplyDeBruijnBitPosition[64] = { + 0, 1, 2, 53, 3, 7, 54, 27, 4, 38, 41, 8, 34, 55, 48, 28, + 62, 5, 39, 46, 44, 42, 22, 9, 24, 35, 59, 56, 49, 18, 29, 11, + 63, 52, 6, 26, 37, 40, 33, 47, 61, 45, 43, 21, 23, 58, 17, 10, + 51, 25, 36, 32, 60, 20, 57, 16, 50, 31, 19, 15, 30, 14, 13, 12 + }; + + return EValueType(MultiplyDeBruijnBitPosition[((Value_ & -Value_) * 0x022fdd63cc95386d) >> 58]); +} + +size_t TTypeSet::GetSize() const +{ + size_t result = 0; + ui64 mask = 1; + for (size_t index = 0; index < 8 * sizeof(ui64); ++index, mask <<= 1) { + if (Value_ & mask) { + 
++result; + } + } + return result; +} + +TTypeSet operator | (const TTypeSet& lhs, const TTypeSet& rhs) +{ + return TTypeSet(lhs.Value_ | rhs.Value_); +} + +TTypeSet operator & (const TTypeSet& lhs, const TTypeSet& rhs) +{ + return TTypeSet(lhs.Value_ & rhs.Value_); +} + +void FormatValue(TStringBuilderBase* builder, const TTypeSet& typeSet, TStringBuf /*spec*/) +{ + if (typeSet.GetSize() == 1) { + builder->AppendFormat("%lv", typeSet.GetFront()); + } else { + builder->AppendString("one of {"); + bool isFirst = true; + typeSet.ForEach([&] (EValueType type) { + if (!isFirst) { + builder->AppendString(", "); + } else { + isFirst = false; + } + builder->AppendFormat("%lv", type); + + }); + builder->AppendString("}"); + } +} + +TString ToString(const TTypeSet& typeSet) +{ + return ToStringViaBuilder(typeSet); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/functions_common.h b/yt/yt/library/query/base/functions_common.h new file mode 100644 index 0000000000..ded2437af1 --- /dev/null +++ b/yt/yt/library/query/base/functions_common.h @@ -0,0 +1,108 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/table_client/row_base.h> + +namespace NYT::NQueryClient { + +using NTableClient::EValueType; + +//////////////////////////////////////////////////////////////////////////////// + +using TTypeParameter = int; +using TUnionType = std::vector<EValueType>; +using TType = std::variant<EValueType, TTypeParameter, TUnionType>; + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(ECallingConvention, + ((Simple) (0)) + ((UnversionedValue) (1)) +); + +DEFINE_ENUM(ETypeTag, + ((ConcreteType) (0)) + ((TypeParameter) (1)) + ((UnionType) (2)) +); + +//////////////////////////////////////////////////////////////////////////////// + +class TTypeSet +{ +public: + TTypeSet() + : Value_(0) + { } + + explicit TTypeSet(ui64 
value) + : Value_(value) + { } + + TTypeSet(std::initializer_list<EValueType> values) + : Value_(0) + { + Assign(values.begin(), values.end()); + } + + template <class TIterator> + TTypeSet(TIterator begin, TIterator end) + : Value_(0) + { + Assign(begin, end); + } + + template <class TIterator> + void Assign(TIterator begin, TIterator end) + { + Value_ = 0; + for (; begin != end; ++begin) { + Set(*begin); + } + } + + void Set(EValueType type) + { + Value_ |= 1 << ui8(type); + } + + bool Get(EValueType type) const + { + return Value_ & (1 << ui8(type)); + } + + EValueType GetFront() const; + + bool IsEmpty() const + { + return Value_ == 0; + } + + size_t GetSize() const; + + template <class TFunctor> + void ForEach(TFunctor functor) const + { + ui64 mask = 1; + for (size_t index = 0; index < 8 * sizeof(ui64); ++index, mask <<= 1) { + if (Value_ & mask) { + functor(EValueType(index)); + } + } + } + + friend TTypeSet operator | (const TTypeSet& lhs, const TTypeSet& rhs); + friend TTypeSet operator & (const TTypeSet& lhs, const TTypeSet& rhs); + +private: + ui64 Value_ = 0; + +}; + +void FormatValue(TStringBuilderBase* builder, const TTypeSet& typeSet, TStringBuf spec); +TString ToString(const TTypeSet& typeSet); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/key_trie.cpp b/yt/yt/library/query/base/key_trie.cpp new file mode 100644 index 0000000000..7bae64a37f --- /dev/null +++ b/yt/yt/library/query/base/key_trie.cpp @@ -0,0 +1,688 @@ +#include "key_trie.h" +#include "query_helpers.h" + +#include <yt/yt/library/numeric/algorithm_helpers.h> + +#include <deque> +#include <tuple> + +namespace NYT::NQueryClient { + +using namespace NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +TKeyTriePtr ReduceKeyTrie(TKeyTriePtr keyTrie) +{ + // TODO(lukyan): If keyTrie is too big, reduce its size + return keyTrie; +} + 
+struct TKeyTrieComparer +{ + bool operator () (const std::pair<TValue, TKeyTriePtr>& element, TValue pivot) const + { + return element.first < pivot; + } + + bool operator () (TValue pivot, const std::pair<TValue, TKeyTriePtr>& element) const + { + return pivot < element.first; + } + + bool operator () (const std::pair<TValue, TKeyTriePtr>& lhs, const std::pair<TValue, TKeyTriePtr>& rhs) const + { + return lhs.first < rhs.first; + } +}; + +int CompareBound(const TBound& lhs, const TBound& rhs, bool lhsDir, bool rhsDir) +{ + auto rank = [] (bool direction, bool included) { + // < - (false, fasle) + // > - (true, false) + // <= - (false, true) + // >= - (true, true) + + // (< x) < (>= x) < (<= x) < (> x) + return (included ? -1 : 2) * (direction ? 1 : -1); + }; + + int result = CompareRowValues(lhs.Value, rhs.Value); + return result == 0 + ? rank(lhsDir, lhs.Included) - rank(rhsDir, rhs.Included) + : result; +} + +template <class TEachCallback> +void MergeBounds(const std::vector<TBound>& lhs, const std::vector<TBound>& rhs, TEachCallback eachCallback) +{ + auto first = lhs.begin(); + auto second = rhs.begin(); + + bool firstIsOpen = true; + bool secondIsOpen = true; + + while (first != lhs.end() && second != rhs.end()) { + if (CompareBound(*first, *second, firstIsOpen, secondIsOpen) < 0) { + eachCallback(*first, firstIsOpen); + ++first; + firstIsOpen = !firstIsOpen; + } else { + eachCallback(*second, secondIsOpen); + ++second; + secondIsOpen = !secondIsOpen; + } + } + + while (first != lhs.end()) { + eachCallback(*first, firstIsOpen); + ++first; + firstIsOpen = !firstIsOpen; + } + + while (second != rhs.end()) { + eachCallback(*second, secondIsOpen); + ++second; + secondIsOpen = !secondIsOpen; + } +} + +std::vector<TBound> UniteBounds(const std::vector<TBound>& lhs, const std::vector<TBound>& rhs) +{ + int cover = 0; + std::vector<TBound> result; + bool resultIsOpen = false; + + MergeBounds(lhs, rhs, [&] (TBound bound, bool isOpen) { + if ((isOpen ? 
cover++ : --cover) == 0) { + if (result.empty() || !(result.back() == bound && isOpen == resultIsOpen)) { + result.push_back(bound); + resultIsOpen = !resultIsOpen; + } + } + }); + + return result; +} + +void UniteBounds(std::vector<std::vector<TBound>>* bounds) +{ + while (bounds->size() > 1) { + size_t i = 0; + while (2 * i + 1 < bounds->size()) { + (*bounds)[i] = UniteBounds((*bounds)[2 * i], (*bounds)[2 * i + 1]); + ++i; + } + if (2 * i < bounds->size()) { + (*bounds)[i] = (*bounds)[2 * i]; + ++i; + } + bounds->resize(i); + } +} + +std::vector<TBound> IntersectBounds(const std::vector<TBound>& lhs, const std::vector<TBound>& rhs) +{ + int cover = 0; + std::vector<TBound> result; + bool resultIsOpen = false; + + MergeBounds(lhs, rhs, [&] (TBound bound, bool isOpen) { + if ((isOpen ? cover++ : --cover) == 1) { + if (result.empty() || !(result.back() == bound && isOpen == resultIsOpen)) { + result.push_back(bound); + resultIsOpen = !resultIsOpen; + } + } + }); + + return result; +} + +TKeyTriePtr UniteKeyTrie(const std::vector<TKeyTriePtr>& tries) +{ + if (tries.empty()) { + return TKeyTrie::Empty(); + } else if (tries.size() == 1) { + return tries.front(); + } + + std::vector<TKeyTriePtr> maxTries; + size_t offset = 0; + for (const auto& trie : tries) { + if (!trie) { + return TKeyTrie::Universal(); + } + + if (trie->Offset > offset) { + maxTries.clear(); + offset = trie->Offset; + } + + if (trie->Offset == offset) { + maxTries.push_back(trie); + } + } + + std::vector<std::pair<TValue, TKeyTriePtr>> groups; + for (const auto& trie : maxTries) { + for (auto& next : trie->Next) { + groups.push_back(std::move(next)); + } + } + + std::sort(groups.begin(), groups.end(), TKeyTrieComparer()); + + auto result = New<TKeyTrie>(offset); + std::vector<TKeyTriePtr> unique; + + auto it = groups.begin(); + auto end = groups.end(); + while (it != end) { + unique.clear(); + auto same = it; + for (; same != end && same->first == it->first; ++same) { + 
unique.push_back(same->second); + } + result->Next.emplace_back(it->first, UniteKeyTrie(unique)); + it = same; + } + + std::vector<std::vector<TBound>> bounds; + for (const auto& trie : maxTries) { + if (!trie->Bounds.empty()) { + bounds.push_back(std::move(trie->Bounds)); + } + } + + UniteBounds(&bounds); + + YT_VERIFY(bounds.size() <= 1); + if (!bounds.empty()) { + std::vector<TBound> deletedPoints; + + deletedPoints.emplace_back(MakeUnversionedSentinelValue(EValueType::Min), true); + for (const auto& next : result->Next) { + deletedPoints.emplace_back(next.first, false); + deletedPoints.emplace_back(next.first, false); + } + deletedPoints.emplace_back(MakeUnversionedSentinelValue(EValueType::Max), true); + + result->Bounds = IntersectBounds(bounds.front(), deletedPoints); + } + + return result; +} + +TKeyTriePtr UniteKeyTrie(TKeyTriePtr lhs, TKeyTriePtr rhs) +{ + return UniteKeyTrie({lhs, rhs}); +} + +bool Covers(const std::vector<TBound>& bounds, const TValue& point) +{ + YT_VERIFY(!(bounds.size() & 1)); + + auto index = BinarySearch( + 0, + bounds.size() / 2, + [&] (int index) -> bool { + const auto& bound = bounds[index * 2 + 1]; + return bound.Value == point ? !bound.Included : bound.Value < point; + }); + + if (index < bounds.size() / 2) { + const auto& bound = bounds[index * 2]; + return bound.Value == point ? bound.Included : bound.Value < point; + } else { + return false; + } +} + +TKeyTriePtr IntersectKeyTrie(TKeyTriePtr lhs, TKeyTriePtr rhs) +{ + auto lhsOffset = lhs ? lhs->Offset : std::numeric_limits<size_t>::max(); + auto rhsOffset = rhs ? 
// Walks the key trie breadth-first (FIFO queue of states) and appends to
// |result| the row ranges admitted by the trie, clipped to |keyRange|.
//
// Parameters:
//   keyRange        - [lower, upper) pair of rows that every produced range is clipped to.
//   trie            - root of the trie; a null pointer is the unconstrained ("universal") trie.
//   result          - output vector, ranges are appended to it.
//   rowBuffer       - owner of the captured rows that make up the produced ranges.
//   insertUndefined - when set, a key column skipped by a non-null trie node is filled
//                     with a TheBottom sentinel instead of closing the range, and ranges
//                     are emitted without the IsEmpty filter.
//   rangeCountLimit - remaining budget of subranges; a node that would exceed it is
//                     replaced by a single range covering all of its subranges.
//   prefix          - key values already fixed for columns [0, prefix.size()).
//   refineLower     - the prefix coincides with the lower key of |keyRange| so far,
//                     so following columns still have to be clipped from below.
//   refineUpper     - same for the upper key of |keyRange|.
void GetRangesFromTrieWithinRangeImpl(
    const TRowRange& keyRange,
    TKeyTriePtr trie,
    std::vector<TMutableRowRange>* result,
    TRowBufferPtr rowBuffer,
    bool insertUndefined,
    ui64 rangeCountLimit,
    std::vector<TValue> prefix = std::vector<TValue>(),
    bool refineLower = true,
    bool refineUpper = true)
{
    auto lowerBoundSize = keyRange.first.GetCount();
    auto upperBoundSize = keyRange.second.GetCount();

    // One pending trie node together with the context it has to be visited in.
    struct TState
    {
        TKeyTriePtr Trie;
        std::vector<TValue> Prefix;
        bool RefineLower;
        bool RefineUpper;
    };

    // Scratch buffers reused between iterations:
    // (bound, "replace this bound by the rest of the key range bound") pairs and
    // (value, subtrie, refineLower, refineUpper) tuples for the children to visit.
    std::vector<std::tuple<TBound, bool>> resultBounds;
    std::vector<std::tuple<TValue, TKeyTriePtr, bool, bool>> nextValues;

    std::deque<TState> states;
    states.push_back(TState{trie, prefix, refineLower, refineUpper});

    while (!states.empty()) {
        // NB: the locals below deliberately shadow the function parameters.
        auto state = std::move(states.front());
        states.pop_front();
        const auto& trie = state.Trie;
        auto prefix = std::move(state.Prefix);
        auto refineLower = state.RefineLower;
        auto refineUpper = state.RefineUpper;

        // Index of the key column handled in this iteration.
        size_t offset = prefix.size();

        // The lower key is exhausted: nothing is left to clip from below.
        if (offset >= lowerBoundSize) {
            refineLower = false;
        }

        if (refineUpper && offset >= upperBoundSize) {
            // NB: prefix is exactly the upper bound, which is non-inclusive.
            continue;
        }

        YT_VERIFY(!refineLower || offset < lowerBoundSize);
        YT_VERIFY(!refineUpper || offset < upperBoundSize);

        TUnversionedRowBuilder builder(offset);

        // A null trie constrains nothing, which is modelled as an infinite offset.
        auto trieOffset = trie ? trie->Offset : std::numeric_limits<size_t>::max();

        // Returns a copy of |value| tagged with column id |id|.
        auto makeValue = [] (TUnversionedValue value, int id) {
            value.Id = id;
            return value;
        };

        // The node says nothing about the current column.
        if (trieOffset > offset) {
            if (refineLower && refineUpper && keyRange.first[offset] == keyRange.second[offset]) {
                // Both keys agree on this column: fix the value and keep clipping.
                prefix.emplace_back(keyRange.first[offset]);
                states.push_back(TState{trie, std::move(prefix), true, true});
            } else if (trie && insertUndefined) {
                // Mark the skipped column as undefined and go on to the next column.
                prefix.emplace_back(MakeUnversionedSentinelValue(EValueType::TheBottom));
                states.push_back(TState{trie, std::move(prefix), false, false});
            } else {
                // Close the range here: prefix plus the remainder of the key range bounds.
                TMutableRowRange range;
                for (size_t i = 0; i < offset; ++i) {
                    builder.AddValue(makeValue(prefix[i], i));
                }

                if (refineLower) {
                    for (size_t i = offset; i < lowerBoundSize; ++i) {
                        builder.AddValue(makeValue(keyRange.first[i], i));
                    }
                }
                range.first = rowBuffer->CaptureRow(builder.GetRow());
                builder.Reset();

                for (size_t i = 0; i < offset; ++i) {
                    builder.AddValue(makeValue(prefix[i], i));
                }

                if (refineUpper) {
                    for (size_t i = offset; i < upperBoundSize; ++i) {
                        builder.AddValue(makeValue(keyRange.second[i], i));
                    }
                } else {
                    // A trailing Max sentinel makes the upper key cover every row with this prefix.
                    builder.AddValue(MakeUnversionedSentinelValue(EValueType::Max));
                }
                range.second = rowBuffer->CaptureRow(builder.GetRow());
                builder.Reset();

                if (insertUndefined || !IsEmpty(range)) {
                    result->push_back(range);
                }
            }
            continue;
        }

        YT_VERIFY(trie);
        YT_VERIFY(trie->Offset == offset);

        // Bounds come in (lower, upper) pairs.
        YT_VERIFY(!(trie->Bounds.size() & 1));

        resultBounds.clear();
        resultBounds.reserve(trie->Bounds.size());

        // Keep the intervals of this node that intersect the key range and remember
        // which of their ends have to be replaced by the key range's own bound.
        for (size_t i = 0; i + 1 < trie->Bounds.size(); i += 2) {
            auto lower = trie->Bounds[i];
            auto upper = trie->Bounds[i + 1];

            YT_VERIFY(CompareBound(lower, upper, true, false) < 0);

            bool lowerBoundRefined = false;
            bool upperBoundRefined = false;

            if (offset < lowerBoundSize) {
                auto keyRangeLowerBound = TBound(keyRange.first[offset], true);
                if (CompareBound(upper, keyRangeLowerBound, false, true) < 0) {
                    // The interval lies entirely below the key range.
                    continue;
                } else if (refineLower && CompareBound(lower, keyRangeLowerBound, true, true) <= 0) {
                    lowerBoundRefined = true;
                }
            }

            if (offset < upperBoundSize) {
                // The upper key value is included at this column only if more columns follow it.
                auto keyRangeUpperBound = TBound(keyRange.second[offset], offset + 1 < upperBoundSize);
                if (CompareBound(lower, keyRangeUpperBound, true, false) > 0) {
                    // The interval lies entirely above the key range.
                    continue;
                } else if (refineUpper && CompareBound(upper, keyRangeUpperBound, false, false) >= 0) {
                    upperBoundRefined = true;
                }
            }

            resultBounds.emplace_back(lower, lowerBoundRefined);
            resultBounds.emplace_back(upper, upperBoundRefined);
        }

        nextValues.clear();
        nextValues.reserve(trie->Next.size());

        // Keep the children whose value is inside the key range at this column; a child
        // sitting exactly on a key range bound has to keep clipping against that bound.
        for (const auto& next : trie->Next) {
            auto value = next.first;

            bool refineLowerNext = false;
            bool refineUpperNext = false;

            if (refineLower) {
                if (value < keyRange.first[offset]) {
                    continue;
                } else if (value == keyRange.first[offset]) {
                    refineLowerNext = true;
                }
            }

            if (refineUpper) {
                if (value > keyRange.second[offset]) {
                    continue;
                } else if (value == keyRange.second[offset]) {
                    refineUpperNext = true;
                }
            }

            nextValues.emplace_back(value, next.second, refineLowerNext, refineUpperNext);
        }

        ui64 subrangeCount = resultBounds.size() / 2 + nextValues.size();

        // Over budget: emit one range spanning from the smallest lower bound to the
        // largest upper bound of this node and do not descend into its children.
        if (subrangeCount > rangeCountLimit) {
            auto min = TBound(MakeUnversionedSentinelValue(EValueType::Max), false);
            auto max = TBound(MakeUnversionedSentinelValue(EValueType::Min), true);

            auto updateMinMax = [&] (const TBound& lower, const TBound& upper) {
                if (CompareBound(lower, min, true, true) < 0) {
                    min = lower;
                }
                if (CompareBound(upper, max, false, false) > 0) {
                    max = upper;
                }
            };

            for (size_t i = 0; i + 1 < resultBounds.size(); i += 2) {
                auto lower = std::get<0>(resultBounds[i]);
                auto upper = std::get<0>(resultBounds[i + 1]);
                updateMinMax(lower, upper);
            }

            // An exact child value counts as the degenerate interval [value, value].
            for (const auto& next : nextValues) {
                auto value = TBound(std::get<0>(next), true);
                updateMinMax(value, value);
            }

            TMutableRowRange range;

            for (size_t j = 0; j < offset; ++j) {
                builder.AddValue(makeValue(prefix[j], j));
            }

            if (refineLower && min.Included && min.Value == keyRange.first[offset]) {
                for (size_t j = offset; j < lowerBoundSize; ++j) {
                    builder.AddValue(makeValue(keyRange.first[j], j));
                }
            } else {
                builder.AddValue(makeValue(min.Value, offset));

                // An exclusive lower bound starts right after every row having this value.
                if (!min.Included) {
                    builder.AddValue(MakeUnversionedSentinelValue(EValueType::Max));
                }
            }

            range.first = rowBuffer->CaptureRow(builder.GetRow());
            builder.Reset();

            for (size_t j = 0; j < offset; ++j) {
                builder.AddValue(makeValue(prefix[j], j));
            }

            if (refineUpper && max.Included && max.Value == keyRange.second[offset]) {
                for (size_t j = offset; j < upperBoundSize; ++j) {
                    builder.AddValue(makeValue(keyRange.second[j], j));
                }
            } else {
                builder.AddValue(makeValue(max.Value, offset));

                // An inclusive upper bound has to extend past every row having this value.
                if (max.Included) {
                    builder.AddValue(MakeUnversionedSentinelValue(EValueType::Max));
                }
            }

            range.second = rowBuffer->CaptureRow(builder.GetRow());
            builder.Reset();
            result->push_back(range);

            continue;
        }

        rangeCountLimit -= subrangeCount;

        // Emit one range per surviving interval of this node.
        for (size_t i = 0; i + 1 < resultBounds.size(); i += 2) {
            auto lower = std::get<0>(resultBounds[i]);
            auto upper = std::get<0>(resultBounds[i + 1]);
            bool lowerBoundRefined = std::get<1>(resultBounds[i]);
            bool upperBoundRefined = std::get<1>(resultBounds[i + 1]);

            TMutableRowRange range;
            for (size_t j = 0; j < offset; ++j) {
                builder.AddValue(makeValue(prefix[j], j));
            }

            if (lowerBoundRefined) {
                for (size_t j = offset; j < lowerBoundSize; ++j) {
                    builder.AddValue(makeValue(keyRange.first[j], j));
                }
            } else {
                builder.AddValue(makeValue(lower.Value, offset));

                if (!lower.Included) {
                    builder.AddValue(MakeUnversionedSentinelValue(EValueType::Max));
                }
            }

            range.first = rowBuffer->CaptureRow(builder.GetRow());
            builder.Reset();

            for (size_t j = 0; j < offset; ++j) {
                builder.AddValue(makeValue(prefix[j], j));
            }

            if (upperBoundRefined) {
                for (size_t j = offset; j < upperBoundSize; ++j) {
                    builder.AddValue(makeValue(keyRange.second[j], j));
                }
            } else {
                builder.AddValue(makeValue(upper.Value, offset));

                if (upper.Included) {
                    builder.AddValue(MakeUnversionedSentinelValue(EValueType::Max));
                }
            }

            range.second = rowBuffer->CaptureRow(builder.GetRow());
            builder.Reset();
            result->push_back(range);
        }

        // Reserve one more slot in the prefix; it is overwritten for every child below.
        prefix.emplace_back();

        // Queue the children; each gets its own copy of the extended prefix.
        for (const auto& next : nextValues) {
            auto value = std::get<0>(next);
            auto trie = std::get<1>(next);
            bool refineLowerNext = std::get<2>(next);
            bool refineUpperNext = std::get<3>(next);
            prefix.back() = value;
            states.push_back(TState{trie, prefix, refineLowerNext, refineUpperNext});
        }
    }
}
// A node of the key trie: a set of constraints on one key column plus
// subtries constraining the columns that follow it.
//
// A null TKeyTriePtr stands for the universal (unconstrained) trie,
// see Universal().
struct TKeyTrie
    : public TRefCounted
{
    // Index of the key column this node constrains.
    size_t Offset = 0;

    // Exact values of the column, each with the subtrie for the following columns.
    // UniteKeyTrie and IntersectKeyTrie keep this list sorted by value;
    // IntersectKeyTrie binary-searches it.
    std::vector<std::pair<TValue, TKeyTriePtr>> Next; // TODO: rename to Following
    // Intervals of admitted values as a flat list of (lower, upper) bound pairs;
    // the size is therefore always even.
    std::vector<TBound> Bounds;

    // Creates a node for column |offset| with no values and no intervals.
    TKeyTrie(size_t offset)
        : Offset(offset)
    { }

    // Copies the payload only; the TRefCounted base is not mentioned in the
    // initializer list and is therefore default-constructed for the new object.
    TKeyTrie(const TKeyTrie& other)
        : Offset(other.Offset)
        , Next(other.Next)
        , Bounds(other.Bounds)
    { }

    TKeyTrie(TKeyTrie&&) = default;

    TKeyTrie& operator=(const TKeyTrie&) = default;
    TKeyTrie& operator=(TKeyTrie&&) = default;

    // The trie that admits nothing: a node without values and intervals.
    static TKeyTriePtr Empty()
    {
        return New<TKeyTrie>(0);
    }

    // The trie that admits everything; represented by a null pointer.
    static TKeyTriePtr Universal()
    {
        return nullptr;
    }

    friend TKeyTriePtr UniteKeyTrie(TKeyTriePtr lhs, TKeyTriePtr rhs);
    friend TKeyTriePtr UniteKeyTrie(const std::vector<TKeyTriePtr>& tries);
    friend TKeyTriePtr IntersectKeyTrie(TKeyTriePtr lhs, TKeyTriePtr rhs);
};
// Lexer used by the parser. It tokenizes the query text and transparently
// expands placeholder tokens: when the query lexer yields a placeholder
// literal, the text registered for that placeholder name is tokenized by a
// nested lexer and its tokens are returned in place of the placeholder.
class TLexer
{
public:
    // |source| is the query text, |strayToken| is the first token handed to the
    // parser, |placeholderValues| maps placeholder names to their replacement text.
    TLexer(
        const TString& source,
        TParser::token_type strayToken,
        THashMap<TString, TString> placeholderValues);

    // Returns the next token, taken from the active placeholder if there is one
    // and from the query otherwise. Throws if a placeholder is unknown, expands
    // to no tokens, or itself contains a placeholder.
    TParser::token_type GetNextToken(
        TParser::semantic_type* yyval,
        TParser::location_type* yyloc);

private:
    // Lexer over the text of the placeholder being expanded together with
    // the location of the placeholder token in the query.
    struct TPlaceholderLexerData
    {
        TBaseLexer Lexer;
        TParser::location_type Location;
    };

    TBaseLexer QueryLexer_;
    // Engaged only while a placeholder is being expanded.
    std::optional<TPlaceholderLexerData> Placeholder_;

    THashMap<TString, TString> PlaceholderValues_;

    // Returns the next token of the active placeholder, reported at the location
    // of the placeholder itself. Returns nullopt and drops the placeholder once
    // its text is exhausted.
    std::optional<TParser::token_type> GetNextTokenFromPlaceholder(
        TParser::semantic_type* yyval,
        TParser::location_type* yyloc);

    // Starts expanding the placeholder whose name is stored in |yyval|.
    void SetPlaceholder(
        TParser::semantic_type* yyval,
        TParser::location_type* yyloc);
};
# Main scanner of the query language.
#
# Every rule stores the token kind in |type| (and the semantic value, if any,
# in |value|) and leaves the scanner with fbreak so that exactly one token is
# produced per call. |ts| / |te| delimit the matched text and |s| is the
# beginning of the source, so |ts - s| is the offset of the token in the source.
main := |*

    kw_select => { type = TToken::KwSelect; fbreak; };
    kw_from => { type = TToken::KwFrom; fbreak; };
    kw_where => { type = TToken::KwWhere; fbreak; };
    kw_having => { type = TToken::KwHaving; fbreak; };
    kw_offset => { type = TToken::KwOffset; fbreak; };
    kw_limit => { type = TToken::KwLimit; fbreak; };
    kw_join => { type = TToken::KwJoin; fbreak; };
    kw_using => { type = TToken::KwUsing; fbreak; };
    kw_group_by => { type = TToken::KwGroupBy; fbreak; };
    kw_with_totals => { type = TToken::KwWithTotals; fbreak; };
    kw_order_by => { type = TToken::KwOrderBy; fbreak; };
    kw_asc => { type = TToken::KwAsc; fbreak; };
    kw_desc => { type = TToken::KwDesc; fbreak; };
    kw_left => { type = TToken::KwLeft; fbreak; };
    kw_as => { type = TToken::KwAs; fbreak; };
    kw_on => { type = TToken::KwOn; fbreak; };
    kw_and => { type = TToken::KwAnd; fbreak; };
    kw_or => { type = TToken::KwOr; fbreak; };
    kw_is => { type = TToken::KwIs; fbreak; };
    kw_not => { type = TToken::KwNot; fbreak; };
    kw_null => { type = TToken::KwNull; fbreak; };
    kw_between => { type = TToken::KwBetween; fbreak; };
    kw_in => { type = TToken::KwIn; fbreak; };
    kw_transform => { type = TToken::KwTransform; fbreak; };
    kw_false => { type = TToken::KwFalse; fbreak; };
    kw_true => { type = TToken::KwTrue; fbreak; };
    # YSON spellings map to the same tokens as the plain boolean keywords.
    kw_yson_false => { type = TToken::KwFalse; fbreak; };
    kw_yson_true => { type = TToken::KwTrue; fbreak; };

    identifier => {
        type = TToken::Identifier;
        value->build(TString(ts, te - ts));
        fbreak;
    };
    int64_literal => {
        type = TToken::Int64Literal;
        // NOTE(review): the digits are parsed as ui64 although the token is an
        // Int64 literal; presumably the parser narrows it later -- confirm.
        value->build(FromString<ui64>(ts, te - ts));
        fbreak;
    };
    uint64_literal => {
        type = TToken::Uint64Literal;
        // Drop the trailing 'u' suffix.
        value->build(FromString<ui64>(ts, te - ts - 1));
        fbreak;
    };
    double_literal => {
        type = TToken::DoubleLiteral;
        value->build(FromString<double>(ts, te - ts));
        fbreak;
    };
    string_literal => {
        type = TToken::StringLiteral;
        // Strip the surrounding quotes and resolve C-style escapes.
        value->build(UnescapeC(ts + 1, te - ts - 2));
        fbreak;
    };
    placeholder_literal => {
        type = TToken::PlaceholderLiteral;
        // Strip the surrounding braces; the value is the placeholder name.
        value->build(TString(ts + 1, te - ts - 2));
        fbreak;
    };

    backtick_quoted_identifier => {
        type = TToken::Identifier;
        // Strip the surrounding backticks and resolve C-style escapes.
        value->build(UnescapeC(ts + 1, te - ts - 2));
        fbreak;
    };

    '[' => {
        // Re-scan the bracket in the nested-bracket machine, which tracks depth.
        fhold;
        fgoto square_bracket_quoted_identifier;
    };
    ']' => {
        // Report the offset from the beginning of the source (|s|), like every
        // other location in this lexer. The previous |ts - p| measured the
        // distance to the current scan pointer, not a position in the query.
        THROW_ERROR_EXCEPTION("Unexpected symbol \"]\" at position %v", ts - s);
    };

    '<=' => { type = TToken::OpLessOrEqual; fbreak; };
    '>=' => { type = TToken::OpGreaterOrEqual; fbreak; };
    '!=' => { type = TToken::OpNotEqualCStyle; fbreak; };
    '<>' => { type = TToken::OpNotEqualSql92; fbreak; };
    '<<' => { type = TToken::OpLeftShift; fbreak; };
    '>>' => { type = TToken::OpRightShift; fbreak; };
    '||' => { type = TToken::OpConcatenate; fbreak; };

    # Single-character tokens: the token kind is the character code itself.
    # NB: "+-/" inside the class is a range from '+' to '/'; it adds no
    # characters beyond '+', ',', '-', '.', '/', all of which are meant to match.
    [()*,<=>+-/%.|&~#] => {
        type = static_cast<TToken>(fc);
        fbreak;
    };

    end => { type = TToken::End; fbreak; };

    # Advance location pointers when skipping whitespace.
    wss => { location->first = te - s; };
*|;
+ return TToken::Failure; + } else { + return type; + } +} + +TLexer::TLexer( + const TString& source, + TParser::token_type strayToken, + THashMap<TString, TString> placeholderValues) + : QueryLexer_(source, strayToken) + , PlaceholderValues_(std::move(placeholderValues)) +{ } + +std::optional<TParser::token_type> TLexer::GetNextTokenFromPlaceholder( + TParser::semantic_type* value, + TParser::location_type* location) +{ + const auto token = Placeholder_->Lexer.GetNextToken(value, location); + if (token == TToken::PlaceholderLiteral) { + THROW_ERROR_EXCEPTION("Unexpected placeholder inside of another placeholder"); + } + + if (token == TToken::End) { + Placeholder_ = std::nullopt; + return std::nullopt; + } + + *location = Placeholder_->Location; + return token; +} + +void TLexer::SetPlaceholder( + TParser::semantic_type* value, + TParser::location_type* location) +{ + const TString* placeholderValue = nullptr; + { + auto finally = Finally([&] () { + value->destroy<TString>(); + }); + + const auto& placeholderName = value->as<TString>(); + + const auto it = PlaceholderValues_.find(placeholderName); + if (it == PlaceholderValues_.end()) { + THROW_ERROR_EXCEPTION("Placeholder was not found") + << TErrorAttribute("name", placeholderName); + } + + placeholderValue = &it->second; + } + + Placeholder_ = {TBaseLexer{*placeholderValue, TToken::StrayWillParseExpression}, *location}; + + const auto token = Placeholder_->Lexer.GetNextToken(value, location); + if (token != TToken::StrayWillParseExpression) { + THROW_ERROR_EXCEPTION("First placeholder token has to be stray"); + } +} + +TParser::token_type TLexer::GetNextToken( + TParser::semantic_type* value, + TParser::location_type* location) +{ + if (Placeholder_) { + const auto tokenFromPlaceholder = GetNextTokenFromPlaceholder(value, location); + if (tokenFromPlaceholder) { + return tokenFromPlaceholder.value(); + } + } + + auto tokenFromQuery = QueryLexer_.GetNextToken(value, location); + + if (tokenFromQuery == 
TToken::PlaceholderLiteral) { + SetPlaceholder(value, location); + + const auto tokenFromPlaceholder = GetNextTokenFromPlaceholder(value, location); + if (!tokenFromPlaceholder) { + THROW_ERROR_EXCEPTION("Placeholder should not be empty"); + } else if (tokenFromPlaceholder == TToken::PlaceholderLiteral) { + THROW_ERROR_EXCEPTION("Unexpected placeholder inside of another placeholder"); + } + + return tokenFromPlaceholder.value(); + } + + return tokenFromQuery; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NAst +} // namespace NQueryClient +} // namespace NYT + diff --git a/yt/yt/library/query/base/parser.ypp b/yt/yt/library/query/base/parser.ypp new file mode 100644 index 0000000000..83173ce191 --- /dev/null +++ b/yt/yt/library/query/base/parser.ypp @@ -0,0 +1,770 @@ +%skeleton "lalr1.cc" +%require "3.0" +%language "C++" + +%define api.namespace {NYT::NQueryClient::NAst} +%define api.prefix {yt_ql_yy} +%define api.value.type variant +%define api.location.type {TSourceLocation} +%define parser_class_name {TParser} +%define parse.error verbose + +%defines +%locations + +%parse-param {TLexer& lexer} +%parse-param {TAstHead* head} +%parse-param {const TString& source} + +%code requires { + #pragma GCC diagnostic ignored "-Wimplicit-fallthrough" + + #include <yt/yt/library/query/base/ast.h> + + namespace NYT { namespace NQueryClient { namespace NAst { + using namespace NTableClient; + + class TLexer; + class TParser; + } } } +} + +%code { + #include <yt/yt/library/query/base/lexer.h> + + #define yt_ql_yylex lexer.GetNextToken + + #ifndef YYLLOC_DEFAULT + #define YYLLOC_DEFAULT(Current, Rhs, N) \ + do { \ + if (N) { \ + (Current).first = YYRHSLOC(Rhs, 1).first; \ + (Current).second = YYRHSLOC (Rhs, N).second; \ + } else { \ + (Current).first = (Current).second = YYRHSLOC(Rhs, 0).second; \ + } \ + } while (false) + #endif +} + +// Special stray tokens to control parser flow. 
+ +// NB: Enumerate stray tokens in decreasing order, e. g. 999, 998, and so on +// so that actual tokens won't change their identifiers. +// NB: And keep one-character tokens consistent with their ASCII codes +// to simplify lexing. + +%token End 0 "end of stream" +%token Failure 256 "lexer failure" + +%token StrayWillParseQuery 999 +%token StrayWillParseJobQuery 998 +%token StrayWillParseExpression 997 + +// Language tokens. + +%token KwSelect "keyword `SELECT`" +%token KwFrom "keyword `FROM`" +%token KwWhere "keyword `WHERE`" +%token KwHaving "keyword `HAVING`" +%token KwOffset "keyword `OFFSET`" +%token KwLimit "keyword `LIMIT`" +%token KwJoin "keyword `JOIN`" +%token KwUsing "keyword `USING`" +%token KwGroupBy "keyword `GROUP BY`" +%token KwWithTotals "keyword `WITH TOTALS`" +%token KwOrderBy "keyword `ORDER BY`" +%token KwAsc "keyword `ASC`" +%token KwDesc "keyword `DESC`" +%token KwLeft "keyword `LEFT`" +%token KwAs "keyword `AS`" +%token KwOn "keyword `ON`" + +%token KwAnd "keyword `AND`" +%token KwOr "keyword `OR`" +%token KwIs "keyword `IS`" +%token KwNot "keyword `NOT`" +%token KwNull "keyword `NULL`" +%token KwBetween "keyword `BETWEEN`" +%token KwIn "keyword `IN`" +%token KwTransform "keyword `TRANSFORM`" + +// Boolean literals; the alias is what verbose syntax errors print for the token. +%token KwFalse "keyword `FALSE`" +%token KwTrue "keyword `TRUE`" + +%token <TString> Identifier "identifier" + +%token <i64> Int64Literal "int64 literal" +%token <ui64> Uint64Literal "uint64 literal" +%token <double> DoubleLiteral "double literal" +%token <TString> StringLiteral "string literal" +%token <TString> PlaceholderLiteral "placeholder literal" + + + +%token OpTilde 126 "`~`" +%token OpNumberSign 35 "`#`" +%token OpVerticalBar 124 "`|`" +%token OpAmpersand 38 "`&`" +%token OpModulo 37 "`%`" +%token OpLeftShift "`<<`" +%token OpRightShift "`>>`" + +%token LeftParenthesis 40 "`(`" +%token RightParenthesis 41 "`)`" + +%token Asterisk 42 "`*`" +%token OpPlus 43 "`+`" +%token Comma 44 "`,`" +%token OpMinus 45 "`-`" +%token Dot 46 "`.`" +%token 
OpDivide 47 "`/`" +%token OpConcatenate 48 "`||`" + + +%token OpLess 60 "`<`" +%token OpLessOrEqual "`<=`" +%token OpEqual 61 "`=`" +%token OpNotEqualCStyle "`!=`" +%token OpNotEqualSql92 "`<>`" +%token OpGreater 62 "`>`" +%token OpGreaterOrEqual "`>=`" + +%type <ETotalsMode> group-by-clause-tail + +%type <TTableDescriptor> table-descriptor + +%type <bool> is-desc +%type <bool> is-left + +%type <TReferenceExpressionPtr> qualified-identifier +%type <TIdentifierList> identifier-list + +%type <TOrderExpressionList> order-expr-list +%type <TExpressionList> expression +%type <TExpressionList> or-op-expr +%type <TExpressionList> and-op-expr +%type <TExpressionList> not-op-expr +%type <TExpressionList> is-null-op-expr +%type <TExpressionList> equal-op-expr +%type <TExpressionList> relational-op-expr +%type <TExpressionList> bitor-op-expr +%type <TExpressionList> bitand-op-expr +%type <TExpressionList> shift-op-expr +%type <TExpressionList> multiplicative-op-expr +%type <TExpressionList> additive-op-expr +%type <TExpressionList> unary-expr +%type <TExpressionList> atomic-expr +%type <TExpressionList> comma-expr +%type <TNullableExpressionList> transform-default-expr +%type <TNullableExpressionList> join-predicate + +%type <std::optional<TLiteralValue>> literal-value +%type <std::optional<TLiteralValue>> const-value +%type <TLiteralValueList> const-list +%type <TLiteralValueList> const-tuple +%type <TLiteralValueTupleList> const-tuple-list +%type <TLiteralValueRangeList> const-range-list + +%type <EUnaryOp> unary-op + +%type <EBinaryOp> relational-op +%type <EBinaryOp> multiplicative-op +%type <EBinaryOp> additive-op + +%start head + +%% + +head + : StrayWillParseQuery parse-query + | StrayWillParseJobQuery parse-job-query + | StrayWillParseExpression parse-expression +; + +parse-query + : select-clause from-clause where-clause group-by-clause order-by-clause offset-clause limit-clause +; + +parse-job-query + : select-clause where-clause +; + +parse-expression + : 
expression[expr] + { + if ($expr.size() != 1) { + THROW_ERROR_EXCEPTION("Expected scalar expression, got %Qv", GetSource(@$, source)); + } + std::get<TExpressionPtr>(head->Ast) = $expr.front(); + } +; + +select-clause + : optional-select-keyword comma-expr[projections] + { + std::get<TQuery>(head->Ast).SelectExprs = $projections; + } + | optional-select-keyword Asterisk + { } +; + +optional-select-keyword + : KwSelect + | +; + +table-descriptor + : Identifier[path] Identifier[alias] + { + $$ = TTableDescriptor($path, $alias); + } + | Identifier[path] KwAs Identifier[alias] + { + $$ = TTableDescriptor($path, $alias); + } + | Identifier[path] + { + $$ = TTableDescriptor($path); + } +; + +from-clause + : KwFrom table-descriptor[table] join-clause + { + std::get<TQuery>(head->Ast).Table = $table; + } +; + +join-predicate + : KwAnd and-op-expr[predicate] + { + $$ = $predicate; + } + | { } +; + +join-clause + : join-clause is-left[isLeft] KwJoin table-descriptor[table] KwUsing identifier-list[fields] join-predicate[predicate] + { + std::get<TQuery>(head->Ast).Joins.emplace_back($isLeft, $table, $fields, $predicate); + } + | join-clause is-left[isLeft] KwJoin table-descriptor[table] KwOn bitor-op-expr[lhs] OpEqual bitor-op-expr[rhs] join-predicate[predicate] + { + std::get<TQuery>(head->Ast).Joins.emplace_back($isLeft, $table, $lhs, $rhs, $predicate); + } + | +; + +is-left + : KwLeft + { + $$ = true; + } + | + { + $$ = false; + } +; + +where-clause + : KwWhere or-op-expr[predicate] + { + std::get<TQuery>(head->Ast).WherePredicate = $predicate; + } + | +; + +group-by-clause + : KwGroupBy comma-expr[exprs] group-by-clause-tail[totalsMode] + { + std::get<TQuery>(head->Ast).GroupExprs = std::make_pair($exprs, $totalsMode); + } + | +; + +group-by-clause-tail + : KwWithTotals + { + $$ = ETotalsMode::BeforeHaving; + } + | having-clause + { + $$ = ETotalsMode::None; + } + | having-clause KwWithTotals + { + $$ = ETotalsMode::AfterHaving; + } + | KwWithTotals having-clause + { + $$ 
= ETotalsMode::BeforeHaving; + } + | + { + $$ = ETotalsMode::None; + } +; + +having-clause + : KwHaving or-op-expr[predicate] + { + std::get<TQuery>(head->Ast).HavingPredicate = $predicate; + } +; + +order-by-clause + : KwOrderBy order-expr-list[exprs] + { + std::get<TQuery>(head->Ast).OrderExpressions = $exprs; + } + | +; + +order-expr-list + : order-expr-list[list] Comma expression[expr] is-desc[isDesc] + { + $$.swap($list); + $$.emplace_back($expr, $isDesc); + } + | expression[expr] is-desc[isDesc] + { + $$.emplace_back($expr, $isDesc); + } +; + +is-desc + : KwDesc + { + $$ = true; + } + | KwAsc + { + $$ = false; + } + | + { + $$ = false; + } +; + +offset-clause + : KwOffset Int64Literal[offset] + { + std::get<TQuery>(head->Ast).Offset = $offset; + } + | +; + +limit-clause + : KwLimit Int64Literal[limit] + { + std::get<TQuery>(head->Ast).Limit = $limit; + } + | +; + +identifier-list + : identifier-list[list] Comma qualified-identifier[value] + { + $$.swap($list); + $$.push_back($value); + } + | qualified-identifier[value] + { + $$.push_back($value); + } +; + +expression + : or-op-expr + { $$ = $1; } + | or-op-expr[expr] KwAs Identifier[name] + { + if ($expr.size() != 1) { + THROW_ERROR_EXCEPTION("Aliased expression %Qv must be scalar", GetSource(@$, source)); + } + auto inserted = head->AliasMap.emplace($name, $expr.front()).second; + if (!inserted) { + THROW_ERROR_EXCEPTION("Alias %Qv has been already used", $name); + } + $$ = MakeExpression<TAliasExpression>(head, @$, $expr.front(), $name); + } +; + +or-op-expr + : or-op-expr[lhs] KwOr and-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::Or, $lhs, $rhs); + } + | and-op-expr + { $$ = $1; } +; + +and-op-expr + + : and-op-expr[lhs] KwAnd not-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::And, $lhs, $rhs); + } + | not-op-expr + { $$ = $1; } +; + +not-op-expr + : KwNot is-null-op-expr[expr] + { + $$ = MakeExpression<TUnaryOpExpression>(head, @$, 
EUnaryOp::Not, $expr); + } + | is-null-op-expr + { $$ = $1; } +; + +is-null-op-expr + : equal-op-expr[expr] KwIs KwNull + { + $$ = MakeExpression<TFunctionExpression>(head, @$, "is_null", $expr); + } + | equal-op-expr[expr] KwIs KwNot KwNull + { + $$ = MakeExpression<TUnaryOpExpression>(head, @$, EUnaryOp::Not, + MakeExpression<TFunctionExpression>(head, @$, "is_null", $expr)); + } + | equal-op-expr + { $$ = $1; } +; + +equal-op-expr + : equal-op-expr[lhs] OpEqual relational-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::Equal, $lhs, $rhs); + } + + | equal-op-expr[lhs] OpNotEqual relational-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::NotEqual, $lhs, $rhs); + } + | relational-op-expr + { $$ = $1; } +; + +OpNotEqual + : OpNotEqualCStyle + | OpNotEqualSql92 +; + +relational-op-expr + : relational-op-expr[lhs] relational-op[opcode] bitor-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, $opcode, $lhs, $rhs); + } + | unary-expr[expr] KwBetween const-tuple[lower] KwAnd const-tuple[upper] + { + TExpressionList lowerExpr; + for (const auto& value : $lower) { + lowerExpr.push_back(head->New<TLiteralExpression>(@$, value)); + } + + TExpressionList upperExpr; + for (const auto& value : $upper) { + upperExpr.push_back(head->New<TLiteralExpression>(@$, value)); + } + + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::And, + MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::GreaterOrEqual, $expr, lowerExpr), + MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::LessOrEqual, $expr, upperExpr)); + } + | unary-expr[expr] KwBetween LeftParenthesis const-range-list[ranges] RightParenthesis + { + $$ = MakeExpression<TBetweenExpression>(head, @$, $expr, $ranges); + } + | unary-expr[expr] KwIn LeftParenthesis const-tuple-list[args] RightParenthesis + { + $$ = MakeExpression<TInExpression>(head, @$, $expr, $args); + } + | bitor-op-expr + { $$ = $1; } +; + +relational-op + : 
OpLess + { $$ = EBinaryOp::Less; } + | OpLessOrEqual + { $$ = EBinaryOp::LessOrEqual; } + | OpGreater + { $$ = EBinaryOp::Greater; } + | OpGreaterOrEqual + { $$ = EBinaryOp::GreaterOrEqual; } +; + +bitor-op-expr + : bitor-op-expr[lhs] OpVerticalBar bitand-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::BitOr, $lhs, $rhs); + } + | bitand-op-expr + { $$ = $1; } +; + +bitand-op-expr + : bitand-op-expr[lhs] OpAmpersand shift-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::BitAnd, $lhs, $rhs); + } + | shift-op-expr + { $$ = $1; } +; + +shift-op-expr + : shift-op-expr[lhs] OpLeftShift additive-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::LeftShift, $lhs, $rhs); + } + | shift-op-expr[lhs] OpRightShift additive-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, EBinaryOp::RightShift, $lhs, $rhs); + } + | additive-op-expr + { $$ = $1; } +; + +additive-op-expr + : additive-op-expr[lhs] additive-op[opcode] multiplicative-op-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, $opcode, $lhs, $rhs); + } + | multiplicative-op-expr + { $$ = $1; } +; + +additive-op + : OpPlus + { $$ = EBinaryOp::Plus; } + | OpMinus + { $$ = EBinaryOp::Minus; } + | OpConcatenate + { $$ = EBinaryOp::Concatenate; } +; + +multiplicative-op-expr + : multiplicative-op-expr[lhs] multiplicative-op[opcode] unary-expr[rhs] + { + $$ = MakeExpression<TBinaryOpExpression>(head, @$, $opcode, $lhs, $rhs); + } + | unary-expr + { $$ = $1; } +; + +multiplicative-op + : Asterisk + { $$ = EBinaryOp::Multiply; } + | OpDivide + { $$ = EBinaryOp::Divide; } + | OpModulo + { $$ = EBinaryOp::Modulo; } +; + +comma-expr + : comma-expr[lhs] Comma expression[rhs] + { + $$ = $lhs; + $$.insert($$.end(), $rhs.begin(), $rhs.end()); + } + | expression + { $$ = $1; } +; + +unary-expr + : unary-op[opcode] unary-expr[rhs] + { + $$ = MakeExpression<TUnaryOpExpression>(head, @$, $opcode, $rhs); + } + | 
atomic-expr + { $$ = $1; } +; + +unary-op + : OpPlus + { $$ = EUnaryOp::Plus; } + | OpMinus + { $$ = EUnaryOp::Minus; } + | OpTilde + { $$ = EUnaryOp::BitNot; } +; + +qualified-identifier + : Identifier[name] + { + $$ = head->New<TReferenceExpression>(@$, $name); + } + | Identifier[table] Dot Identifier[name] + { + $$ = head->New<TReferenceExpression>(@$, $name, $table); + } +; + +atomic-expr + : qualified-identifier[identifier] + { + $$ = TExpressionList(1, $identifier); + } + | Identifier[name] LeftParenthesis RightParenthesis + { + $$ = MakeExpression<TFunctionExpression>(head, @$, $name, TExpressionList()); + } + | Identifier[name] LeftParenthesis comma-expr[args] RightParenthesis + { + $$ = MakeExpression<TFunctionExpression>(head, @$, $name, $args); + } + | KwTransform LeftParenthesis expression[expr] Comma LeftParenthesis const-tuple-list[from] RightParenthesis Comma LeftParenthesis const-tuple-list[to] RightParenthesis transform-default-expr[default] RightParenthesis + { + $$ = MakeExpression<TTransformExpression>(head, @$, $expr, $from, $to, $default); + } + | LeftParenthesis comma-expr[expr] RightParenthesis + { + $$ = $expr; + } + | literal-value[value] + { + $$ = MakeExpression<TLiteralExpression>(head, @$, *$value); + } +; + +transform-default-expr + : Comma expression[expr] + { + $$ = $expr; + } + | { } +; + +literal-value + : Int64Literal + { $$ = $1; } + | Uint64Literal + { $$ = $1; } + | DoubleLiteral + { $$ = $1; } + | StringLiteral + { $$ = $1; } + | KwFalse + { $$ = false; } + | KwTrue + { $$ = true; } + | KwNull + { $$ = TNullLiteralValue(); } + | OpNumberSign + { $$ = TNullLiteralValue(); } +; + +const-value + : unary-op[op] const-value[value] + { + switch ($op) { + case EUnaryOp::Minus: { + if (const auto* data = std::get_if<i64>(&*$value)) { + $$ = -*data; + } else if (const auto* data = std::get_if<ui64>(&*$value)) { + $$ = -*data; + } else if (const auto* data = std::get_if<double>(&*$value)) { + $$ = -*data; + } else { + 
THROW_ERROR_EXCEPTION("Negation of unsupported type"); + } + break; + } + case EUnaryOp::Plus: + $$ = $value; + break; + case EUnaryOp::BitNot: { + if (const auto* data = std::get_if<i64>(&*$value)) { + $$ = ~*data; + } else if (const auto* data = std::get_if<ui64>(&*$value)) { + $$ = ~*data; + } else { + THROW_ERROR_EXCEPTION("Bitwise negation of unsupported type"); + } + break; + } + default: + YT_ABORT(); + } + + } + | literal-value[value] + { $$ = $value; } +; + +const-list + : const-list[as] Comma const-value[a] + { + $$.swap($as); + $$.push_back(*$a); + } + | const-value[a] + { + $$.push_back(*$a); + } +; + +const-tuple + : const-value[a] + { + $$.push_back(*$a); + } + | LeftParenthesis const-list[a] RightParenthesis + { + $$ = $a; + } +; + +const-tuple-list + : const-tuple-list[as] Comma const-tuple[a] + { + $$.swap($as); + $$.push_back($a); + } + | const-tuple[a] + { + $$.push_back($a); + } +; + +const-range-list + : const-range-list[as] Comma const-tuple[a] KwAnd const-tuple[b] + { + $$.swap($as); + $$.emplace_back($a, $b); + } + | const-tuple[a] KwAnd const-tuple[b] + { + $$.emplace_back($a, $b); + } +; + +%% + +namespace NYT { +namespace NQueryClient { +namespace NAst { + +//////////////////////////////////////////////////////////////////////////////// + +void TParser::error(const location_type& location, const std::string& message) +{ + auto leftContextStart = std::max<size_t>(location.first, 16) - 16; + auto rightContextEnd = std::min<size_t>(location.second + 16, source.size()); + + THROW_ERROR_EXCEPTION("Error while parsing query: %v", message) + << TErrorAttribute("position", Format("%v-%v", location.first, location.second)) + << TErrorAttribute("query", Format("%v >>>>> %v <<<<< %v", + source.substr(leftContextStart, location.first - leftContextStart), + source.substr(location.first, location.second - location.first), + source.substr(location.second, rightContextEnd - location.second))); +} + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace NAst +} // namespace NQueryClient +} // namespace NYT diff --git a/yt/yt/library/query/base/private.h b/yt/yt/library/query/base/private.h new file mode 100644 index 0000000000..bb87daccf4 --- /dev/null +++ b/yt/yt/library/query/base/private.h @@ -0,0 +1,16 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/logging/log.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger QueryClientLogger("QueryClient"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + diff --git a/yt/yt/library/query/base/public.cpp b/yt/yt/library/query/base/public.cpp new file mode 100644 index 0000000000..13eb986fcb --- /dev/null +++ b/yt/yt/library/query/base/public.cpp @@ -0,0 +1,12 @@ +#include "public.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +const NYPath::TYPath QueryPoolsPath("//sys/ql_pools"); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + diff --git a/yt/yt/library/query/base/public.h b/yt/yt/library/query/base/public.h new file mode 100644 index 0000000000..c658b2c58c --- /dev/null +++ b/yt/yt/library/query/base/public.h @@ -0,0 +1,149 @@ +#pragma once + +#include <yt/yt/client/query_client/public.h> + +#include <yt/yt/client/transaction_client/public.h> + +#include <yt/yt/client/table_client/public.h> + +#include <yt/yt/core/ypath/public.h> + +namespace NYT::NQueryClient { + +using NTransactionClient::TTimestamp; + +using NTableClient::TRowRange; + +using TReadSessionId = TGuid; + +struct TDataSplit; + +//////////////////////////////////////////////////////////////////////////////// + +namespace NProto { + +class TColumnDescriptor; +class 
TExpression; +class TGroupClause; +class TProjectClause; +class TJoinClause; +class TQuery; +class TQueryOptions; +class TDataSource; + +} // namespace NProto + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(TExpression) +using TConstExpressionPtr = TIntrusivePtr<const TExpression>; + +DECLARE_REFCOUNTED_STRUCT(TFunctionExpression) +using TConstFunctionExpressionPtr = TIntrusivePtr<const TFunctionExpression>; + +DECLARE_REFCOUNTED_STRUCT(TAggregateFunctionExpression) +using TConstAggregateFunctionExpressionPtr = TIntrusivePtr<const TAggregateFunctionExpression>; + +DECLARE_REFCOUNTED_STRUCT(TJoinClause) +using TConstJoinClausePtr = TIntrusivePtr<const TJoinClause>; + +DECLARE_REFCOUNTED_STRUCT(TGroupClause) +using TConstGroupClausePtr = TIntrusivePtr<const TGroupClause>; + +DECLARE_REFCOUNTED_STRUCT(TOrderClause) +using TConstOrderClausePtr = TIntrusivePtr<const TOrderClause>; + +DECLARE_REFCOUNTED_STRUCT(TProjectClause) +using TConstProjectClausePtr = TIntrusivePtr<const TProjectClause>; + +DECLARE_REFCOUNTED_STRUCT(TBaseQuery) +using TConstBaseQueryPtr = TIntrusivePtr<const TBaseQuery>; + +DECLARE_REFCOUNTED_STRUCT(TFrontQuery) +using TConstFrontQueryPtr = TIntrusivePtr<const TFrontQuery>; + +DECLARE_REFCOUNTED_STRUCT(TQuery) +using TConstQueryPtr = TIntrusivePtr<const TQuery>; + +struct IPrepareCallbacks; + +struct TQueryStatistics; + +struct TQueryOptions; + +DECLARE_REFCOUNTED_STRUCT(IAggregateFunctionDescriptor) + +DECLARE_REFCOUNTED_STRUCT(ICallingConvention) + +DECLARE_REFCOUNTED_STRUCT(IExecutor) + +DECLARE_REFCOUNTED_STRUCT(IEvaluator) + +DECLARE_REFCOUNTED_CLASS(TExecutorConfig) + +DECLARE_REFCOUNTED_CLASS(TColumnEvaluator) + +DECLARE_REFCOUNTED_STRUCT(IColumnEvaluatorCache) + +DECLARE_REFCOUNTED_CLASS(TColumnEvaluatorCacheConfig) +DECLARE_REFCOUNTED_CLASS(TColumnEvaluatorCacheDynamicConfig) + +DECLARE_REFCOUNTED_STRUCT(TExternalCGInfo) +using TConstExternalCGInfoPtr = TIntrusivePtr<const 
TExternalCGInfo>; + +DECLARE_REFCOUNTED_STRUCT(TTypeInferrerMap) +using TConstTypeInferrerMapPtr = TIntrusivePtr<const TTypeInferrerMap>; + +const TConstTypeInferrerMapPtr GetBuiltinTypeInferrers(); + +DECLARE_REFCOUNTED_STRUCT(IFunctionRegistry) +DECLARE_REFCOUNTED_CLASS(ITypeInferrer) + +DECLARE_REFCOUNTED_CLASS(TFunctionImplCache) + +using NTableClient::ISchemafulUnversionedReader; +using NTableClient::ISchemafulUnversionedReaderPtr; +using NTableClient::ISchemalessUnversionedReader; +using NTableClient::ISchemalessUnversionedReaderPtr; +using NTableClient::IUnversionedRowsetWriter; +using NTableClient::IUnversionedRowsetWriterPtr; +using NTableClient::EValueType; +using NTableClient::TTableSchema; +using NTableClient::TTableSchemaPtr; +using NTableClient::TColumnSchema; +using NTableClient::TKeyColumns; +using NTableClient::TColumnFilter; +using NTableClient::TRowRange; + +using NTransactionClient::TTimestamp; +using NTransactionClient::NullTimestamp; + +using NTableClient::TRowBuffer; +using NTableClient::TRowBufferPtr; + +using TSchemaColumns = std::vector<NTableClient::TColumnSchema>; + +using TRow = NTableClient::TUnversionedRow; +using TMutableRow = NTableClient::TMutableUnversionedRow; +using TRowHeader = NTableClient::TUnversionedRowHeader; +using TRowBuilder = NTableClient::TUnversionedRowBuilder; +using TOwningRow = NTableClient::TUnversionedOwningRow; +using TOwningRowBuilder = NTableClient::TUnversionedOwningRowBuilder; +using TValue = NTableClient::TUnversionedValue; +using TValueData = NTableClient::TUnversionedValueData; +using TOwningValue = NTableClient::TUnversionedOwningValue; +using TLegacyOwningKey = NTableClient::TLegacyOwningKey; + +using TKeyRange = std::pair<TLegacyOwningKey, TLegacyOwningKey>; +using TMutableRowRange = std::pair<TMutableRow, TMutableRow>; +using TRowRanges = std::vector<TRowRange>; +using TMutableRowRanges = std::vector<TMutableRowRange>; + 
+//////////////////////////////////////////////////////////////////////////////// + +extern const NYPath::TYPath QueryPoolsPath; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + diff --git a/yt/yt/library/query/base/query.cpp b/yt/yt/library/query/base/query.cpp new file mode 100644 index 0000000000..03603854a8 --- /dev/null +++ b/yt/yt/library/query/base/query.cpp @@ -0,0 +1,980 @@ +#include "query.h" +#include "private.h" + +#include <yt/yt/library/query/proto/query.pb.h> + +#include <yt/yt/client/table_client/row_base.h> +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/wire_protocol.h> + +#include <yt/yt/core/ytree/serialize.h> +#include <yt/yt/core/ytree/convert.h> + +#include <library/cpp/yt/misc/cast.h> + +#include <limits> + +namespace NYT::NQueryClient { + +using namespace NTableClient; +using namespace NObjectClient; + +using NYT::ToProto; +using NYT::FromProto; + +//////////////////////////////////////////////////////////////////////////////// + +//! Computes key index for a given column name. 
+int ColumnNameToKeyPartIndex(const TKeyColumns& keyColumns, const TString& columnName) +{ + for (int index = 0; index < std::ssize(keyColumns); ++index) { + if (keyColumns[index] == columnName) { + return index; + } + } + return -1; +} + +TLogicalTypePtr ToQLType(const NTableClient::TLogicalTypePtr& columnType) +{ + if (IsV1Type(columnType)) { + const auto wireType = GetWireType(columnType); + return MakeLogicalType(GetLogicalType(wireType), false); + } else { + return columnType; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +struct TExpressionPrinter + : TAbstractExpressionPrinter<TExpressionPrinter, TConstExpressionPtr> +{ + using TBase = TAbstractExpressionPrinter<TExpressionPrinter, TConstExpressionPtr>; + TExpressionPrinter(TStringBuilderBase* builder, bool omitValues) + : TBase(builder, omitValues) + { } +}; + +TString InferName(TConstExpressionPtr expr, bool omitValues) +{ + if (!expr) { + return TString(); + } + TStringBuilder builder; + TExpressionPrinter expressionPrinter(&builder, omitValues); + expressionPrinter.Visit(expr); + return builder.Flush(); +} + +TString InferName(TConstBaseQueryPtr query, TInferNameOptions options) +{ + auto namedItemFormatter = [&] (TStringBuilderBase* builder, const TNamedItem& item) { + builder->AppendString(InferName(item.Expression, options.OmitValues)); + if (!options.OmitAliases) { + builder->AppendFormat(" AS %v", item.Name); + } + }; + + auto orderItemFormatter = [&] (TStringBuilderBase* builder, const TOrderItem& item) { + builder->AppendFormat("%v %v", + InferName(item.Expression, options.OmitValues), + item.Descending ? 
"DESC" : "ASC"); + }; + + std::vector<TString> clauses; + TString str; + + if (query->ProjectClause) { + str = JoinToString(query->ProjectClause->Projections, namedItemFormatter); + } else { + str = "*"; + } + + clauses.emplace_back("SELECT " + str); + + if (auto derivedQuery = dynamic_cast<const TQuery*>(query.Get())) { + for (const auto& joinClause : derivedQuery->JoinClauses) { + std::vector<TString> selfJoinEquation; + for (const auto& equation : joinClause->SelfEquations) { + selfJoinEquation.push_back(InferName(equation.Expression, options.OmitValues)); + } + std::vector<TString> foreignJoinEquation; + for (const auto& equation : joinClause->ForeignEquations) { + foreignJoinEquation.push_back(InferName(equation, options.OmitValues)); + } + + clauses.push_back(Format( + "%v JOIN[common prefix: %v, foreign prefix: %v] ON (%v) = (%v)", + joinClause->IsLeft ? "LEFT" : "INNER", + joinClause->CommonKeyPrefix, + joinClause->ForeignKeyPrefix, + JoinToString(selfJoinEquation), + JoinToString(foreignJoinEquation))); + + if (joinClause->Predicate && !options.OmitJoinPredicate) { + clauses.push_back("AND " + InferName(joinClause->Predicate, options.OmitValues)); + } + } + + if (derivedQuery->WhereClause) { + clauses.push_back(TString("WHERE ") + InferName(derivedQuery->WhereClause, options.OmitValues)); + } + } + + if (query->GroupClause) { + clauses.push_back(Format("GROUP BY[common prefix: %v, disjoint: %v, aggregates: %v] %v", + query->GroupClause->CommonPrefixWithPrimaryKey, + query->UseDisjointGroupBy, + MakeFormattableView(query->GroupClause->AggregateItems, [] (auto* builder, const auto& item) { + builder->AppendString(item.AggregateFunction); + }), + JoinToString(query->GroupClause->GroupItems, namedItemFormatter))); + if (query->GroupClause->TotalsMode == ETotalsMode::BeforeHaving) { + clauses.push_back("WITH TOTALS"); + } + } + + if (query->HavingClause) { + clauses.push_back(TString("HAVING ") + InferName(query->HavingClause, options.OmitValues)); + if 
(query->GroupClause->TotalsMode == ETotalsMode::AfterHaving) { + clauses.push_back("WITH TOTALS"); + } + } + + if (query->OrderClause) { + clauses.push_back(TString("ORDER BY ") + JoinToString(query->OrderClause->OrderItems, orderItemFormatter)); + } + + + + if (query->Limit < std::numeric_limits<i64>::max()) { + clauses.push_back(TString("OFFSET ") + (options.OmitValues ? "?" : ToString(query->Offset))); + clauses.push_back(TString("LIMIT ") + (options.OmitValues ? "?" : ToString(query->Limit))); + } + + return JoinToString(clauses, TStringBuf(" ")); +} + +//////////////////////////////////////////////////////////////////////////////// + +bool Compare( + TConstExpressionPtr lhs, + const TTableSchema& lhsSchema, + TConstExpressionPtr rhs, + const TTableSchema& rhsSchema, + size_t maxIndex) +{ +#define CHECK(condition) \ + do { \ + if (!(condition)) { \ + return false; \ + } \ + } while (false) + + CHECK(*lhs->LogicalType == *rhs->LogicalType); + + if (auto literalLhs = lhs->As<TLiteralExpression>()) { + auto literalRhs = rhs->As<TLiteralExpression>(); + CHECK(literalRhs); + CHECK(literalLhs->Value == literalRhs->Value); + } else if (auto referenceLhs = lhs->As<TReferenceExpression>()) { + auto referenceRhs = rhs->As<TReferenceExpression>(); + CHECK(referenceRhs); + auto lhsIndex = lhsSchema.GetColumnIndexOrThrow(referenceLhs->ColumnName); + auto rhsIndex = rhsSchema.GetColumnIndexOrThrow(referenceRhs->ColumnName); + CHECK(lhsIndex == rhsIndex); + CHECK(static_cast<size_t>(lhsIndex) < maxIndex); + } else if (auto functionLhs = lhs->As<TFunctionExpression>()) { + auto functionRhs = rhs->As<TFunctionExpression>(); + CHECK(functionRhs); + CHECK(functionLhs->FunctionName == functionRhs->FunctionName); + CHECK(functionLhs->Arguments.size() == functionRhs->Arguments.size()); + + for (size_t index = 0; index < functionLhs->Arguments.size(); ++index) { + CHECK(Compare(functionLhs->Arguments[index], lhsSchema, functionRhs->Arguments[index], rhsSchema, maxIndex)); + } + } 
else if (auto unaryLhs = lhs->As<TUnaryOpExpression>()) { + auto unaryRhs = rhs->As<TUnaryOpExpression>(); + CHECK(unaryRhs); + CHECK(unaryLhs->Opcode == unaryRhs->Opcode); + CHECK(Compare(unaryLhs->Operand, lhsSchema, unaryRhs->Operand, rhsSchema, maxIndex)); + } else if (auto binaryLhs = lhs->As<TBinaryOpExpression>()) { + auto binaryRhs = rhs->As<TBinaryOpExpression>(); + CHECK(binaryRhs); + CHECK(binaryLhs->Opcode == binaryRhs->Opcode); + CHECK(Compare(binaryLhs->Lhs, lhsSchema, binaryRhs->Lhs, rhsSchema, maxIndex)); + CHECK(Compare(binaryLhs->Rhs, lhsSchema, binaryRhs->Rhs, rhsSchema, maxIndex)); + } else if (auto inLhs = lhs->As<TInExpression>()) { + auto inRhs = rhs->As<TInExpression>(); + CHECK(inRhs); + CHECK(inLhs->Arguments.size() == inRhs->Arguments.size()); + for (size_t index = 0; index < inLhs->Arguments.size(); ++index) { + CHECK(Compare(inLhs->Arguments[index], lhsSchema, inRhs->Arguments[index], rhsSchema, maxIndex)); + } + + CHECK(inLhs->Values.Size() == inRhs->Values.Size()); + for (size_t index = 0; index < inLhs->Values.Size(); ++index) { + CHECK(inLhs->Values[index] == inRhs->Values[index]); + } + } else if (auto betweenLhs = lhs->As<TBetweenExpression>()) { + auto betweenRhs = rhs->As<TBetweenExpression>(); + CHECK(betweenRhs); + CHECK(betweenLhs->Arguments.size() == betweenRhs->Arguments.size()); + for (size_t index = 0; index < betweenLhs->Arguments.size(); ++index) { + CHECK(Compare(betweenLhs->Arguments[index], lhsSchema, betweenRhs->Arguments[index], rhsSchema, maxIndex)); + } + + CHECK(betweenLhs->Ranges.Size() == betweenRhs->Ranges.Size()); + for (size_t index = 0; index < betweenLhs->Ranges.Size(); ++index) { + CHECK(betweenLhs->Ranges[index] == betweenRhs->Ranges[index]); + } + } else if (auto transformLhs = lhs->As<TTransformExpression>()) { + auto transformRhs = rhs->As<TTransformExpression>(); + CHECK(transformRhs); + CHECK(transformLhs->Arguments.size() == transformRhs->Arguments.size()); + for (size_t index = 0; index < 
transformLhs->Arguments.size(); ++index) { + CHECK(Compare(transformLhs->Arguments[index], lhsSchema, transformRhs->Arguments[index], rhsSchema, maxIndex)); + } + + CHECK(transformLhs->Values.Size() == transformRhs->Values.Size()); + for (size_t index = 0; index < transformLhs->Values.Size(); ++index) { + CHECK(transformLhs->Values[index] == transformRhs->Values[index]); + } + CHECK(Compare(transformLhs->DefaultExpression, lhsSchema, transformRhs->DefaultExpression, rhsSchema, maxIndex)); + } else { + YT_ABORT(); + } +#undef CHECK + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ThrowTypeMismatchError( + EValueType lhsType, + EValueType rhsType, + TStringBuf source, + TStringBuf lhsSource, + TStringBuf rhsSource) +{ + THROW_ERROR_EXCEPTION("Type mismatch in expression %Qv", source) + << TErrorAttribute("lhs_source", lhsSource) + << TErrorAttribute("rhs_source", rhsSource) + << TErrorAttribute("lhs_type", lhsType) + << TErrorAttribute("rhs_type", rhsType); +} + +struct TExtraColumnsChecker + : public TVisitor<TExtraColumnsChecker> +{ + using TBase = TVisitor<TExtraColumnsChecker>; + + const THashSet<TString>& Names; + bool HasExtraColumns = false; + + explicit TExtraColumnsChecker(const THashSet<TString>& names) + : Names(names) + { } + + void OnReference(const TReferenceExpression* referenceExpr) + { + HasExtraColumns |= Names.count(referenceExpr->ColumnName) == 0; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +std::vector<size_t> GetJoinGroups( + const std::vector<TConstJoinClausePtr>& joinClauses, + TTableSchemaPtr schema) +{ + THashSet<TString> names; + for (const auto& column : schema->Columns()) { + names.insert(column.Name()); + } + + std::vector<size_t> joinGroups; + + size_t counter = 0; + for (const auto& joinClause : joinClauses) { + TExtraColumnsChecker extraColumnsChecker(names); + + for (const auto& equation : joinClause->SelfEquations) { + if 
(!equation.Evaluated) { + extraColumnsChecker.Visit(equation.Expression); + } + } + + if (extraColumnsChecker.HasExtraColumns) { + YT_VERIFY(counter > 0); + joinGroups.push_back(counter); + counter = 0; + names.clear(); + for (const auto& column : schema->Columns()) { + names.insert(column.Name()); + } + } + + ++counter; + schema = joinClause->GetTableSchema(*schema); + } + + if (counter > 0) { + joinGroups.push_back(counter); + counter = 0; + } + + return joinGroups; +} + +NLogging::TLogger MakeQueryLogger(TGuid queryId) +{ + return QueryClientLogger.WithTag("FragmentId: %v", queryId); +} + +NLogging::TLogger MakeQueryLogger(TConstBaseQueryPtr query) +{ + return MakeQueryLogger(query->Id); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TExpression* serialized, const TConstExpressionPtr& original) +{ + if (!original) { + serialized->set_kind(static_cast<int>(EExpressionKind::None)); + return; + } + + // N.B. backward compatibility old `type` proto field could contain only + // Int64,Uint64,String,Boolean,Null,Any types. 
+ const auto wireType = NTableClient::GetPhysicalType( + NTableClient::CastToV1Type(original->LogicalType).first); + + serialized->set_type(static_cast<int>(wireType)); + + if (!IsV1Type(original->LogicalType) || + *original->LogicalType != *MakeLogicalType(GetLogicalType(wireType), false)) + { + ToProto(serialized->mutable_logical_type(), original->LogicalType); + } + + if (auto literalExpr = original->As<TLiteralExpression>()) { + serialized->set_kind(static_cast<int>(EExpressionKind::Literal)); + auto* proto = serialized->MutableExtension(NProto::TLiteralExpression::literal_expression); + auto value = TValue(literalExpr->Value); + auto data = value.Data; + + switch (value.Type) { + case EValueType::Int64: { + proto->set_int64_value(data.Int64); + break; + } + + case EValueType::Uint64: { + proto->set_uint64_value(data.Uint64); + break; + } + + case EValueType::Double: { + proto->set_double_value(data.Double); + break; + } + + case EValueType::String: { + proto->set_string_value(data.String, value.Length); + break; + } + + case EValueType::Boolean: { + proto->set_boolean_value(data.Boolean); + break; + } + + case EValueType::Null: { + break; + } + + default: + YT_ABORT(); + } + + } else if (auto referenceExpr = original->As<TReferenceExpression>()) { + serialized->set_kind(static_cast<int>(EExpressionKind::Reference)); + auto* proto = serialized->MutableExtension(NProto::TReferenceExpression::reference_expression); + proto->set_column_name(referenceExpr->ColumnName); + } else if (auto functionExpr = original->As<TFunctionExpression>()) { + serialized->set_kind(static_cast<int>(EExpressionKind::Function)); + auto* proto = serialized->MutableExtension(NProto::TFunctionExpression::function_expression); + proto->set_function_name(functionExpr->FunctionName); + ToProto(proto->mutable_arguments(), functionExpr->Arguments); + } else if (auto unaryOpExpr = original->As<TUnaryOpExpression>()) { + serialized->set_kind(static_cast<int>(EExpressionKind::UnaryOp)); + auto* 
proto = serialized->MutableExtension(NProto::TUnaryOpExpression::unary_op_expression); + proto->set_opcode(static_cast<int>(unaryOpExpr->Opcode)); + ToProto(proto->mutable_operand(), unaryOpExpr->Operand); + } else if (auto binaryOpExpr = original->As<TBinaryOpExpression>()) { + serialized->set_kind(static_cast<int>(EExpressionKind::BinaryOp)); + auto* proto = serialized->MutableExtension(NProto::TBinaryOpExpression::binary_op_expression); + proto->set_opcode(static_cast<int>(binaryOpExpr->Opcode)); + ToProto(proto->mutable_lhs(), binaryOpExpr->Lhs); + ToProto(proto->mutable_rhs(), binaryOpExpr->Rhs); + } else if (auto inExpr = original->As<TInExpression>()) { + serialized->set_kind(static_cast<int>(EExpressionKind::In)); + auto* proto = serialized->MutableExtension(NProto::TInExpression::in_expression); + ToProto(proto->mutable_arguments(), inExpr->Arguments); + + auto writer = CreateWireProtocolWriter(); + writer->WriteUnversionedRowset(inExpr->Values); + ToProto(proto->mutable_values(), MergeRefsToString(writer->Finish())); + } else if (auto betweenExpr = original->As<TBetweenExpression>()) { + serialized->set_kind(static_cast<int>(EExpressionKind::Between)); + auto* proto = serialized->MutableExtension(NProto::TBetweenExpression::between_expression); + ToProto(proto->mutable_arguments(), betweenExpr->Arguments); + + auto rangesWriter = CreateWireProtocolWriter(); + for (const auto& range : betweenExpr->Ranges) { + rangesWriter->WriteUnversionedRow(range.first); + rangesWriter->WriteUnversionedRow(range.second); + } + ToProto(proto->mutable_ranges(), MergeRefsToString(rangesWriter->Finish())); + } else if (auto transformExpr = original->As<TTransformExpression>()) { + serialized->set_kind(static_cast<int>(EExpressionKind::Transform)); + auto* proto = serialized->MutableExtension(NProto::TTransformExpression::transform_expression); + ToProto(proto->mutable_arguments(), transformExpr->Arguments); + + auto writer = CreateWireProtocolWriter(); + 
writer->WriteUnversionedRowset(transformExpr->Values); + ToProto(proto->mutable_values(), MergeRefsToString(writer->Finish())); + if (transformExpr->DefaultExpression) { + ToProto(proto->mutable_default_expression(), transformExpr->DefaultExpression); + } + } +} + +void FromProto(TConstExpressionPtr* original, const NProto::TExpression& serialized) +{ + TLogicalTypePtr type; + if (serialized.has_logical_type()) { + FromProto(&type, serialized.logical_type()); + } else { + auto wireType = CheckedEnumCast<EValueType>(serialized.type()); + type = MakeLogicalType(GetLogicalType(wireType), false); + } + + auto kind = CheckedEnumCast<EExpressionKind>(serialized.kind()); + switch (kind) { + case EExpressionKind::None: { + *original = nullptr; + return; + } + + case EExpressionKind::Literal: { + auto result = New<TLiteralExpression>(GetWireType(type)); + const auto& ext = serialized.GetExtension(NProto::TLiteralExpression::literal_expression); + + if (ext.has_int64_value()) { + result->Value = MakeUnversionedInt64Value(ext.int64_value()); + } else if (ext.has_uint64_value()) { + result->Value = MakeUnversionedUint64Value(ext.uint64_value()); + } else if (ext.has_double_value()) { + result->Value = MakeUnversionedDoubleValue(ext.double_value()); + } else if (ext.has_string_value()) { + result->Value = MakeUnversionedStringValue(ext.string_value()); + } else if (ext.has_boolean_value()) { + result->Value = MakeUnversionedBooleanValue(ext.boolean_value()); + } else { + result->Value = MakeUnversionedSentinelValue(EValueType::Null); + } + + *original = result; + return; + } + + case EExpressionKind::Reference: { + auto result = New<TReferenceExpression>(type); + const auto& data = serialized.GetExtension(NProto::TReferenceExpression::reference_expression); + result->ColumnName = data.column_name(); + *original = result; + return; + } + + case EExpressionKind::Function: { + auto result = New<TFunctionExpression>(GetWireType(type)); + const auto& ext = 
serialized.GetExtension(NProto::TFunctionExpression::function_expression); + result->FunctionName = ext.function_name(); + FromProto(&result->Arguments, ext.arguments()); + *original = result; + return; + } + + case EExpressionKind::UnaryOp: { + auto result = New<TUnaryOpExpression>(GetWireType(type)); + const auto& ext = serialized.GetExtension(NProto::TUnaryOpExpression::unary_op_expression); + result->Opcode = EUnaryOp(ext.opcode()); + FromProto(&result->Operand, ext.operand()); + *original = result; + return; + } + + case EExpressionKind::BinaryOp: { + auto result = New<TBinaryOpExpression>(GetWireType(type)); + const auto& ext = serialized.GetExtension(NProto::TBinaryOpExpression::binary_op_expression); + result->Opcode = EBinaryOp(ext.opcode()); + FromProto(&result->Lhs, ext.lhs()); + FromProto(&result->Rhs, ext.rhs()); + *original = result; + return; + } + + case EExpressionKind::In: { + auto result = New<TInExpression>(GetWireType(type)); + const auto& ext = serialized.GetExtension(NProto::TInExpression::in_expression); + FromProto(&result->Arguments, ext.arguments()); + auto reader = CreateWireProtocolReader( + TSharedRef::FromString(ext.values()), + New<TRowBuffer>(TExpressionRowsetTag())); + result->Values = reader->ReadUnversionedRowset(true); + *original = result; + return; + } + + case EExpressionKind::Between: { + auto result = New<TBetweenExpression>(GetWireType(type)); + const auto& ext = serialized.GetExtension(NProto::TBetweenExpression::between_expression); + FromProto(&result->Arguments, ext.arguments()); + + TRowRanges ranges; + auto rowBuffer = New<TRowBuffer>(TExpressionRowsetTag()); + auto rangesReader = CreateWireProtocolReader( + TSharedRef::FromString<TExpressionRowsetTag>(ext.ranges()), + rowBuffer); + while (!rangesReader->IsFinished()) { + auto lowerBound = rangesReader->ReadUnversionedRow(true); + auto upperBound = rangesReader->ReadUnversionedRow(true); + ranges.emplace_back(lowerBound, upperBound); + } + result->Ranges = 
MakeSharedRange(std::move(ranges), std::move(rowBuffer)); + *original = result; + return; + } + + case EExpressionKind::Transform: { + auto result = New<TTransformExpression>(GetWireType(type)); + const auto& ext = serialized.GetExtension(NProto::TTransformExpression::transform_expression); + FromProto(&result->Arguments, ext.arguments()); + auto reader = CreateWireProtocolReader( + TSharedRef::FromString(ext.values()), + New<TRowBuffer>(TExpressionRowsetTag())); + result->Values = reader->ReadUnversionedRowset(true); + if (ext.has_default_expression()) { + FromProto(&result->DefaultExpression, ext.default_expression()); + } + *original = result; + return; + } + } + YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TNamedItem* serialized, const TNamedItem& original) +{ + ToProto(serialized->mutable_expression(), original.Expression); + ToProto(serialized->mutable_name(), original.Name); +} + +void FromProto(TNamedItem* original, const NProto::TNamedItem& serialized) +{ + *original = TNamedItem( + FromProto<TConstExpressionPtr>(serialized.expression()), + serialized.name()); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TAggregateItem* serialized, const TAggregateItem& original) +{ + ToProto(serialized->mutable_expression(), original.Arguments.front()); + serialized->set_aggregate_function_name(original.AggregateFunction); + serialized->set_state_type(static_cast<int>(original.StateType)); + serialized->set_result_type(static_cast<int>(original.ResultType)); + ToProto(serialized->mutable_name(), original.Name); + ToProto(serialized->mutable_arguments(), original.Arguments); +} + +void FromProto(TAggregateItem* original, const NProto::TAggregateItem& serialized) +{ + original->AggregateFunction = serialized.aggregate_function_name(); + original->Name = serialized.name(); + original->StateType = 
static_cast<EValueType>(serialized.state_type()); + original->ResultType = static_cast<EValueType>(serialized.result_type()); + // COMPAT(sabdenovch) + if (serialized.arguments_size() > 0) { + original->Arguments = FromProto<std::vector<TConstExpressionPtr>>(serialized.arguments()); + } else { + original->Arguments = {FromProto<TConstExpressionPtr>(serialized.expression())}; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TSelfEquation* proto, const TSelfEquation& original) +{ + ToProto(proto->mutable_expression(), original.Expression); + proto->set_evaluated(original.Evaluated); +} + +void FromProto(TSelfEquation* original, const NProto::TSelfEquation& serialized) +{ + FromProto(&original->Expression, serialized.expression()); + FromProto(&original->Evaluated, serialized.evaluated()); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TColumnDescriptor* proto, const TColumnDescriptor& original) +{ + proto->set_name(original.Name); + proto->set_index(original.Index); +} + +void FromProto(TColumnDescriptor* original, const NProto::TColumnDescriptor& serialized) +{ + FromProto(&original->Name, serialized.name()); + FromProto(&original->Index, serialized.index()); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TJoinClause* proto, const TConstJoinClausePtr& original) +{ + ToProto(proto->mutable_original_schema(), original->Schema.Original); + ToProto(proto->mutable_schema_mapping(), original->Schema.Mapping); + ToProto(proto->mutable_self_joined_columns(), original->SelfJoinedColumns); + ToProto(proto->mutable_foreign_joined_columns(), original->ForeignJoinedColumns); + + ToProto(proto->mutable_foreign_equations(), original->ForeignEquations); + ToProto(proto->mutable_self_equations(), original->SelfEquations); + + ToProto(proto->mutable_foreign_object_id(), 
original->ForeignObjectId); + ToProto(proto->mutable_foreign_cell_id(), original->ForeignCellId); + + proto->set_is_left(original->IsLeft); + + // COMPAT(lukyan) + bool canUseSourceRanges = original->ForeignKeyPrefix == original->ForeignEquations.size(); + proto->set_can_use_source_ranges(canUseSourceRanges); + proto->set_common_key_prefix(canUseSourceRanges ? original->CommonKeyPrefix : 0); + proto->set_common_key_prefix_new(original->CommonKeyPrefix); + proto->set_foreign_key_prefix(original->ForeignKeyPrefix); + + if (original->Predicate) { + ToProto(proto->mutable_predicate(), original->Predicate); + } +} + +void FromProto(TConstJoinClausePtr* original, const NProto::TJoinClause& serialized) +{ + auto result = New<TJoinClause>(); + FromProto(&result->Schema.Original, serialized.original_schema()); + FromProto(&result->Schema.Mapping, serialized.schema_mapping()); + FromProto(&result->SelfJoinedColumns, serialized.self_joined_columns()); + FromProto(&result->ForeignJoinedColumns, serialized.foreign_joined_columns()); + FromProto(&result->ForeignEquations, serialized.foreign_equations()); + FromProto(&result->SelfEquations, serialized.self_equations()); + FromProto(&result->ForeignObjectId, serialized.foreign_object_id()); + FromProto(&result->ForeignCellId, serialized.foreign_cell_id()); + FromProto(&result->IsLeft, serialized.is_left()); + FromProto(&result->CommonKeyPrefix, serialized.common_key_prefix()); + + if (serialized.has_common_key_prefix_new()) { + FromProto(&result->CommonKeyPrefix, serialized.common_key_prefix_new()); + } + + // COMPAT(lukyan) + if (serialized.can_use_source_ranges()) { + result->ForeignKeyPrefix = result->ForeignEquations.size(); + } else { + FromProto(&result->ForeignKeyPrefix, serialized.foreign_key_prefix()); + } + + if (serialized.has_predicate()) { + FromProto(&result->Predicate, serialized.predicate()); + } + + *original = result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void 
ToProto(NProto::TGroupClause* proto, const TConstGroupClausePtr& original) +{ + ToProto(proto->mutable_group_items(), original->GroupItems); + ToProto(proto->mutable_aggregate_items(), original->AggregateItems); + proto->set_totals_mode(static_cast<int>(original->TotalsMode)); + proto->set_common_prefix_with_primary_key(static_cast<int>(original->CommonPrefixWithPrimaryKey)); +} + +void FromProto(TConstGroupClausePtr* original, const NProto::TGroupClause& serialized) +{ + auto result = New<TGroupClause>(); + FromProto(&result->GroupItems, serialized.group_items()); + FromProto(&result->AggregateItems, serialized.aggregate_items()); + result->TotalsMode = ETotalsMode(serialized.totals_mode()); + result->CommonPrefixWithPrimaryKey = serialized.common_prefix_with_primary_key(); + + *original = result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TProjectClause* proto, const TConstProjectClausePtr& original) +{ + ToProto(proto->mutable_projections(), original->Projections); +} + +void FromProto(TConstProjectClausePtr* original, const NProto::TProjectClause& serialized) +{ + auto result = New<TProjectClause>(); + result->Projections.reserve(serialized.projections_size()); + for (int i = 0; i < serialized.projections_size(); ++i) { + result->AddProjection(FromProto<TNamedItem>(serialized.projections(i))); + } + *original = result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TOrderItem* serialized, const TOrderItem& original) +{ + ToProto(serialized->mutable_expression(), original.Expression); + serialized->set_descending(original.Descending); +} + +void FromProto(TOrderItem* original, const NProto::TOrderItem& serialized) +{ + FromProto(&original->Expression, serialized.expression()); + FromProto(&original->Descending, serialized.descending()); +} + +//////////////////////////////////////////////////////////////////////////////// + +void 
ToProto(NProto::TOrderClause* proto, const TConstOrderClausePtr& original) +{ + ToProto(proto->mutable_order_items(), original->OrderItems); +} + +void FromProto(TConstOrderClausePtr* original, const NProto::TOrderClause& serialized) +{ + auto result = New<TOrderClause>(); + FromProto(&result->OrderItems, serialized.order_items()); + *original = result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TQuery* serialized, const TConstQueryPtr& original) +{ + ToProto(serialized->mutable_id(), original->Id); + + serialized->set_offset(original->Offset); + serialized->set_limit(original->Limit); + serialized->set_use_disjoint_group_by(original->UseDisjointGroupBy); + serialized->set_infer_ranges(original->InferRanges); + serialized->set_is_final(original->IsFinal); + + ToProto(serialized->mutable_original_schema(), original->Schema.Original); + ToProto(serialized->mutable_schema_mapping(), original->Schema.Mapping); + + ToProto(serialized->mutable_join_clauses(), original->JoinClauses); + + if (original->WhereClause) { + ToProto(serialized->mutable_where_clause(), original->WhereClause); + } + + if (original->GroupClause) { + ToProto(serialized->mutable_group_clause(), original->GroupClause); + } + + if (original->HavingClause) { + ToProto(serialized->mutable_having_clause(), original->HavingClause); + } + + if (original->OrderClause) { + ToProto(serialized->mutable_order_clause(), original->OrderClause); + } + + if (original->ProjectClause) { + ToProto(serialized->mutable_project_clause(), original->ProjectClause); + } +} + +void FromProto(TConstQueryPtr* original, const NProto::TQuery& serialized) +{ + auto result = New<TQuery>(FromProto<TGuid>(serialized.id())); + + result->Offset = serialized.offset(); + result->Limit = serialized.limit(); + result->UseDisjointGroupBy = serialized.use_disjoint_group_by(); + result->InferRanges = serialized.infer_ranges(); + FromProto(&result->IsFinal, 
serialized.is_final()); + + FromProto(&result->Schema.Original, serialized.original_schema()); + FromProto(&result->Schema.Mapping, serialized.schema_mapping()); + + FromProto(&result->JoinClauses, serialized.join_clauses()); + + if (serialized.has_where_clause()) { + FromProto(&result->WhereClause, serialized.where_clause()); + } + + if (serialized.has_group_clause()) { + FromProto(&result->GroupClause, serialized.group_clause()); + } + + if (serialized.has_having_clause()) { + FromProto(&result->HavingClause, serialized.having_clause()); + } + + if (serialized.has_order_clause()) { + FromProto(&result->OrderClause, serialized.order_clause()); + } + + if (serialized.has_project_clause()) { + FromProto(&result->ProjectClause, serialized.project_clause()); + } + + *original = result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TQueryOptions* serialized, const TQueryOptions& original) +{ + serialized->set_timestamp(original.TimestampRange.Timestamp); + serialized->set_retention_timestamp(original.TimestampRange.RetentionTimestamp); + serialized->set_verbose_logging(original.VerboseLogging); + serialized->set_new_range_inference(original.NewRangeInference); + serialized->set_max_subqueries(original.MaxSubqueries); + serialized->set_enable_code_cache(original.EnableCodeCache); + ToProto(serialized->mutable_workload_descriptor(), original.WorkloadDescriptor); + serialized->set_allow_full_scan(original.AllowFullScan); + ToProto(serialized->mutable_read_session_id(), original.ReadSessionId); + serialized->set_deadline(ToProto<ui64>(original.Deadline)); + serialized->set_memory_limit_per_node(original.MemoryLimitPerNode); + if (original.ExecutionPool) { + serialized->set_execution_pool(*original.ExecutionPool); + } + serialized->set_suppress_access_tracking(original.SuppressAccessTracking); + serialized->set_range_expansion_limit(original.RangeExpansionLimit); +} + +void FromProto(TQueryOptions* original, 
const NProto::TQueryOptions& serialized) +{ + original->TimestampRange.Timestamp = serialized.timestamp(); + original->TimestampRange.RetentionTimestamp = serialized.retention_timestamp(); + original->VerboseLogging = serialized.verbose_logging(); + original->NewRangeInference = serialized.new_range_inference(); + original->MaxSubqueries = serialized.max_subqueries(); + original->EnableCodeCache = serialized.enable_code_cache(); + original->WorkloadDescriptor = serialized.has_workload_descriptor() + ? FromProto<TWorkloadDescriptor>(serialized.workload_descriptor()) + : TWorkloadDescriptor(); + original->AllowFullScan = serialized.allow_full_scan(); + original->ReadSessionId = serialized.has_read_session_id() + ? FromProto<TReadSessionId>(serialized.read_session_id()) + : TReadSessionId::Create(); + + if (serialized.has_memory_limit_per_node()) { + original->MemoryLimitPerNode = serialized.memory_limit_per_node(); + } + + if (serialized.has_execution_pool()) { + original->ExecutionPool = serialized.execution_pool(); + } + + original->Deadline = serialized.has_deadline() + ? 
FromProto<TInstant>(serialized.deadline()) + : TInstant::Max(); + + if (serialized.has_suppress_access_tracking()) { + original->SuppressAccessTracking = serialized.suppress_access_tracking(); + } + + if (serialized.has_range_expansion_limit()) { + original->RangeExpansionLimit = serialized.range_expansion_limit(); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToProto(NProto::TDataSource* serialized, const TDataSource& original) +{ + ToProto(serialized->mutable_object_id(), original.ObjectId); + ToProto(serialized->mutable_cell_id(), original.CellId); + serialized->set_mount_revision(original.MountRevision); + + auto rangesWriter = CreateWireProtocolWriter(); + for (const auto& range : original.Ranges) { + rangesWriter->WriteUnversionedRow(range.first); + rangesWriter->WriteUnversionedRow(range.second); + } + ToProto(serialized->mutable_ranges(), MergeRefsToString(rangesWriter->Finish())); + + if (original.Keys) { + std::vector<TColumnSchema> columns; + for (auto type : original.Schema) { + columns.emplace_back("", type); + } + + TTableSchema schema(columns); + auto keysWriter = CreateWireProtocolWriter(); + keysWriter->WriteTableSchema(schema); + keysWriter->WriteSchemafulRowset(original.Keys); + ToProto(serialized->mutable_keys(), MergeRefsToString(keysWriter->Finish())); + } + serialized->set_lookup_supported(original.LookupSupported); + serialized->set_key_width(original.KeyWidth); +} + +void FromProto(TDataSource* original, const NProto::TDataSource& serialized) +{ + FromProto(&original->ObjectId, serialized.object_id()); + FromProto(&original->CellId, serialized.cell_id()); + original->MountRevision = serialized.mount_revision(); + + struct TDataSourceBufferTag + { }; + + TRowRanges ranges; + auto rowBuffer = New<TRowBuffer>(TDataSourceBufferTag()); + auto rangesReader = CreateWireProtocolReader( + TSharedRef::FromString<TDataSourceBufferTag>(serialized.ranges()), + rowBuffer); + while 
(!rangesReader->IsFinished()) { + auto lowerBound = rangesReader->ReadUnversionedRow(true); + auto upperBound = rangesReader->ReadUnversionedRow(true); + ranges.emplace_back(lowerBound, upperBound); + } + original->Ranges = MakeSharedRange(std::move(ranges), rowBuffer); + + if (serialized.has_keys()) { + auto keysReader = CreateWireProtocolReader( + TSharedRef::FromString<TDataSourceBufferTag>(serialized.keys()), + rowBuffer); + + auto schema = keysReader->ReadTableSchema(); + auto schemaData = keysReader->GetSchemaData(schema, NTableClient::TColumnFilter()); + original->Keys = keysReader->ReadSchemafulRowset(schemaData, true); + } + original->LookupSupported = serialized.lookup_supported(); + original->KeyWidth = serialized.key_width(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/query.h b/yt/yt/library/query/base/query.h new file mode 100644 index 0000000000..66444d0e3d --- /dev/null +++ b/yt/yt/library/query/base/query.h @@ -0,0 +1,1231 @@ +#pragma once + +#include "public.h" +#include "query_common.h" + +#include <yt/yt/client/table_client/logical_type.h> +#include <yt/yt/client/table_client/row_buffer.h> +#include <yt/yt/client/table_client/schema.h> + +#include <yt/yt/core/misc/guid.h> +#include <yt/yt/core/misc/property.h> +#include <yt/yt/core/misc/range.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +//! Computes key index for a given column name. +int ColumnNameToKeyPartIndex(const TKeyColumns& keyColumns, const TString& columnName); + +//! Derives type of reference expression based on table column type. +//! +//! For historical reasons reference expressions used to have `wire type` of column i.e. +//! if column had `Int16` type its reference would have `Int64` type. +//! 
`DeriveReferenceType` keeps this behaviour for V1 types, but for V3 types actual type is returned. +NTableClient::TLogicalTypePtr ToQLType(const NTableClient::TLogicalTypePtr& columnType); + +struct TColumnDescriptor +{ + // Renamed column. + // TODO: Do not keep name but restore name from table alias and column name from original schema. + TString Name; + // Index in schema. + int Index; +}; + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(EExpressionKind, + ((None) (0)) + ((Literal) (1)) + ((Reference) (2)) + ((Function) (3)) + ((UnaryOp) (4)) + ((BinaryOp) (5)) + ((In) (6)) + ((Transform) (7)) + ((Between) (8)) +); + +struct TExpression + : public TRefCounted +{ + NTableClient::TLogicalTypePtr LogicalType; + + explicit TExpression(NTableClient::TLogicalTypePtr type) + : LogicalType(std::move(type)) + { } + + explicit TExpression(EValueType type) + : LogicalType(MakeLogicalType(GetLogicalType(type), false)) + { } + + EValueType GetWireType() const + { + return NTableClient::GetWireType(LogicalType); + } + + template <class TDerived> + const TDerived* As() const + { + return dynamic_cast<const TDerived*>(this); + } + + template <class TDerived> + TDerived* As() + { + return dynamic_cast<TDerived*>(this); + } +}; + +DEFINE_REFCOUNTED_TYPE(TExpression) + +struct TLiteralExpression + : public TExpression +{ + TOwningValue Value; + + explicit TLiteralExpression(EValueType type) + : TExpression(type) + { } + + TLiteralExpression(EValueType type, TOwningValue value) + : TExpression(type) + , Value(value) + { } +}; + +struct TReferenceExpression + : public TExpression +{ + TString ColumnName; + + explicit TReferenceExpression(const NTableClient::TLogicalTypePtr& type) + : TExpression(ToQLType(type)) + { } + + TReferenceExpression(const NTableClient::TLogicalTypePtr& type, TStringBuf columnName) + : TExpression(ToQLType(type)) + , ColumnName(columnName) + { } +}; + +struct TFunctionExpression + : public TExpression +{ + 
TString FunctionName; + std::vector<TConstExpressionPtr> Arguments; + + explicit TFunctionExpression(EValueType type) + : TExpression(type) + { } + + TFunctionExpression( + EValueType type, + const TString& functionName, + const std::vector<TConstExpressionPtr>& arguments) + : TExpression(type) + , FunctionName(functionName) + , Arguments(arguments) + { } +}; + +DEFINE_REFCOUNTED_TYPE(TFunctionExpression) + +struct TAggregateFunctionExpression + : public TReferenceExpression +{ + std::vector<TConstExpressionPtr> Arguments; + EValueType StateType; + EValueType ResultType; + TString FunctionName; + + TAggregateFunctionExpression( + const NTableClient::TLogicalTypePtr& type, + const TString& exprName, + const std::vector<TConstExpressionPtr>& arguments, + EValueType stateType, + EValueType resultType, + const TString& functionName) + : TReferenceExpression(type, exprName) + , Arguments(arguments) + , StateType(stateType) + , ResultType(resultType) + , FunctionName(functionName) + { } +}; + +DEFINE_REFCOUNTED_TYPE(TAggregateFunctionExpression) + +struct TUnaryOpExpression + : public TExpression +{ + EUnaryOp Opcode; + TConstExpressionPtr Operand; + + explicit TUnaryOpExpression(EValueType type) + : TExpression(type) + { } + + TUnaryOpExpression( + EValueType type, + EUnaryOp opcode, + TConstExpressionPtr operand) + : TExpression(type) + , Opcode(opcode) + , Operand(operand) + { } +}; + +struct TBinaryOpExpression + : public TExpression +{ + EBinaryOp Opcode; + TConstExpressionPtr Lhs; + TConstExpressionPtr Rhs; + + explicit TBinaryOpExpression(EValueType type) + : TExpression(type) + { } + + TBinaryOpExpression( + EValueType type, + EBinaryOp opcode, + TConstExpressionPtr lhs, + TConstExpressionPtr rhs) + : TExpression(type) + , Opcode(opcode) + , Lhs(lhs) + , Rhs(rhs) + { } +}; + +struct TExpressionRowsetTag +{ }; + +struct TInExpression + : public TExpression +{ + std::vector<TConstExpressionPtr> Arguments; + TSharedRange<TRow> Values; + + explicit 
TInExpression(EValueType type) + : TExpression(type) + { + YT_VERIFY(type == EValueType::Boolean); + } + + TInExpression( + std::vector<TConstExpressionPtr> arguments, + TSharedRange<TRow> values) + : TExpression(EValueType::Boolean) + , Arguments(std::move(arguments)) + , Values(std::move(values)) + { } +}; + +struct TBetweenExpression + : public TExpression +{ + std::vector<TConstExpressionPtr> Arguments; + TSharedRange<TRowRange> Ranges; + + explicit TBetweenExpression(EValueType type) + : TExpression(type) + { + YT_VERIFY(type == EValueType::Boolean); + } + + TBetweenExpression( + std::vector<TConstExpressionPtr> arguments, + TSharedRange<TRowRange> ranges) + : TExpression(EValueType::Boolean) + , Arguments(std::move(arguments)) + , Ranges(std::move(ranges)) + { } +}; + +struct TTransformExpression + : public TExpression +{ + std::vector<TConstExpressionPtr> Arguments; + TSharedRange<TRow> Values; + TConstExpressionPtr DefaultExpression; + + explicit TTransformExpression(EValueType type) + : TExpression(type) + { } + + TTransformExpression( + EValueType type, + std::vector<TConstExpressionPtr> arguments, + TSharedRange<TRow> values, + TConstExpressionPtr defaultExpression) + : TExpression(type) + , Arguments(std::move(arguments)) + , Values(std::move(values)) + , DefaultExpression(std::move(defaultExpression)) + { } +}; + +void ThrowTypeMismatchError( + EValueType lhsType, + EValueType rhsType, + TStringBuf source, + TStringBuf lhsSource, + TStringBuf rhsSource); + +//////////////////////////////////////////////////////////////////////////////// + +struct TNamedItem +{ + TConstExpressionPtr Expression; + TString Name; + + TNamedItem() = default; + + TNamedItem( + TConstExpressionPtr expression, + const TString& name) + : Expression(expression) + , Name(name) + { } +}; + +using TNamedItemList = std::vector<TNamedItem>; + +struct TAggregateItem +{ + std::vector<TConstExpressionPtr> Arguments; + TString Name; + TString AggregateFunction; + EValueType StateType; + 
EValueType ResultType; + + TAggregateItem() = default; + + TAggregateItem( + std::vector<TConstExpressionPtr> arguments, + const TString& aggregateFunction, + const TString& name, + EValueType stateType, + EValueType resultType) + : Arguments(std::move(arguments)) + , Name(name) + , AggregateFunction(aggregateFunction) + , StateType(stateType) + , ResultType(resultType) + { } +}; + +using TAggregateItemList = std::vector<TAggregateItem>; + +//////////////////////////////////////////////////////////////////////////////// + +struct TMappedSchema +{ + TTableSchemaPtr Original; + std::vector<TColumnDescriptor> Mapping; + + std::vector<TColumnDescriptor> GetOrderedSchemaMapping() const + { + auto orderedSchemaMapping = Mapping; + std::sort(orderedSchemaMapping.begin(), orderedSchemaMapping.end(), + [] (const TColumnDescriptor& lhs, const TColumnDescriptor& rhs) { + return lhs.Index < rhs.Index; + }); + return orderedSchemaMapping; + } + + TKeyColumns GetKeyColumns() const + { + TKeyColumns result(Original->GetKeyColumnCount()); + for (const auto& item : Mapping) { + if (item.Index < Original->GetKeyColumnCount()) { + result[item.Index] = item.Name; + } + } + return result; + } + + TTableSchemaPtr GetRenamedSchema() const + { + TSchemaColumns result; + for (const auto& item : GetOrderedSchemaMapping()) { + result.emplace_back(item.Name, Original->Columns()[item.Index].LogicalType()); + } + return New<TTableSchema>(std::move(result)); + } +}; + +struct TSelfEquation +{ + TConstExpressionPtr Expression; + bool Evaluated; +}; + +struct TJoinClause + : public TRefCounted +{ + TMappedSchema Schema; + std::vector<TString> SelfJoinedColumns; + std::vector<TString> ForeignJoinedColumns; + + TConstExpressionPtr Predicate; + + std::vector<TConstExpressionPtr> ForeignEquations; + std::vector<TSelfEquation> SelfEquations; + + size_t CommonKeyPrefix = 0; + size_t ForeignKeyPrefix = 0; + + bool IsLeft = false; + + //! See #TDataSource::ObjectId. 
+ NObjectClient::TObjectId ForeignObjectId; + //! See #TDataSource::CellId. + NObjectClient::TCellId ForeignCellId; + + TTableSchemaPtr GetRenamedSchema() const + { + return Schema.GetRenamedSchema(); + } + + TKeyColumns GetKeyColumns() const + { + return Schema.GetKeyColumns(); + } + + TTableSchemaPtr GetTableSchema(const TTableSchema& source) const + { + TSchemaColumns result; + + auto selfColumnNames = SelfJoinedColumns; + std::sort(selfColumnNames.begin(), selfColumnNames.end()); + for (const auto& column : source.Columns()) { + if (std::binary_search(selfColumnNames.begin(), selfColumnNames.end(), column.Name())) { + result.push_back(column); + } + } + + auto foreignColumnNames = ForeignJoinedColumns; + std::sort(foreignColumnNames.begin(), foreignColumnNames.end()); + auto renamedSchema = Schema.GetRenamedSchema(); + for (const auto& column : renamedSchema->Columns()) { + if (std::binary_search(foreignColumnNames.begin(), foreignColumnNames.end(), column.Name())) { + result.push_back(column); + } + } + + return New<TTableSchema>(std::move(result)); + } +}; + +DEFINE_REFCOUNTED_TYPE(TJoinClause) + +struct TGroupClause + : public TRefCounted +{ + TNamedItemList GroupItems; + TAggregateItemList AggregateItems; + ETotalsMode TotalsMode; + size_t CommonPrefixWithPrimaryKey = 0; + + void AddGroupItem(const TNamedItem& namedItem) + { + GroupItems.push_back(namedItem); + } + + void AddGroupItem(TConstExpressionPtr expression, TString name) + { + AddGroupItem(TNamedItem(expression, name)); + } + + TTableSchemaPtr GetTableSchema(bool isFinal) const + { + TSchemaColumns result; + + for (const auto& item : GroupItems) { + result.emplace_back(item.Name, item.Expression->LogicalType); + } + + for (const auto& item : AggregateItems) { + result.emplace_back(item.Name, isFinal ? 
item.ResultType : item.StateType); + } + + return New<TTableSchema>(std::move(result)); + } +}; + +DEFINE_REFCOUNTED_TYPE(TGroupClause) + +struct TOrderItem +{ + TConstExpressionPtr Expression; + bool Descending; +}; + +struct TOrderClause + : public TRefCounted +{ + std::vector<TOrderItem> OrderItems; +}; + +DEFINE_REFCOUNTED_TYPE(TOrderClause) + +struct TProjectClause + : public TRefCounted +{ + TNamedItemList Projections; + + void AddProjection(const TNamedItem& namedItem) + { + Projections.push_back(namedItem); + } + + void AddProjection(TConstExpressionPtr expression, TString name) + { + AddProjection(TNamedItem(expression, name)); + } + + TTableSchemaPtr GetTableSchema() const + { + TSchemaColumns result; + + for (const auto& item : Projections) { + result.emplace_back(item.Name, item.Expression->LogicalType); + } + + return New<TTableSchema>(std::move(result)); + } +}; + +DEFINE_REFCOUNTED_TYPE(TProjectClause) + +// Front Query is not Coordinatable +// IsMerge is always true for front Query and false for Bottom Query + +struct TBaseQuery + : public TRefCounted +{ + TGuid Id; + + // Merge and Final + bool IsFinal = true; + + TConstGroupClausePtr GroupClause; + TConstExpressionPtr HavingClause; + TConstOrderClausePtr OrderClause; + + TConstProjectClausePtr ProjectClause; + + i64 Offset = 0; + + // TODO: Update protocol and fix it + // If Limit == std::numeric_limits<i64>::max() - 1, then do ordered read with prefetch + i64 Limit = std::numeric_limits<i64>::max(); + bool UseDisjointGroupBy = false; + bool InferRanges = true; + + explicit TBaseQuery(TGuid id = TGuid::Create()) + : Id(id) + { } + + TBaseQuery(const TBaseQuery& other) + : Id(TGuid::Create()) + , IsFinal(other.IsFinal) + , GroupClause(other.GroupClause) + , HavingClause(other.HavingClause) + , OrderClause(other.OrderClause) + , ProjectClause(other.ProjectClause) + , Offset(other.Offset) + , Limit(other.Limit) + , UseDisjointGroupBy(other.UseDisjointGroupBy) + , InferRanges(other.InferRanges) + { } 
+ + bool IsOrdered() const + { + if (Limit < std::numeric_limits<i64>::max()) { + return !OrderClause; + } else { + YT_VERIFY(!OrderClause); + return false; + } + } + + virtual TTableSchemaPtr GetReadSchema() const = 0; + virtual TTableSchemaPtr GetTableSchema() const = 0; +}; + +DEFINE_REFCOUNTED_TYPE(TBaseQuery) + +struct TQuery + : public TBaseQuery +{ + TMappedSchema Schema; + + // Bottom + std::vector<TConstJoinClausePtr> JoinClauses; + TConstExpressionPtr WhereClause; + + explicit TQuery(TGuid id = TGuid::Create()) + : TBaseQuery(id) + { } + + TQuery(const TQuery& other) = default; + + TKeyColumns GetKeyColumns() const + { + return Schema.GetKeyColumns(); + } + + TTableSchemaPtr GetReadSchema() const override + { + TSchemaColumns result; + + for (const auto& item : Schema.GetOrderedSchemaMapping()) { + result.emplace_back( + Schema.Original->Columns()[item.Index].Name(), + Schema.Original->Columns()[item.Index].LogicalType()); + } + + return New<TTableSchema>(std::move(result)); + } + + TTableSchemaPtr GetRenamedSchema() const + { + return Schema.GetRenamedSchema(); + } + + TTableSchemaPtr GetTableSchema() const override + { + if (ProjectClause) { + return ProjectClause->GetTableSchema(); + } + + if (GroupClause) { + return GroupClause->GetTableSchema(IsFinal); + } + + auto result = GetRenamedSchema(); + + for (const auto& joinClause : JoinClauses) { + result = joinClause->GetTableSchema(*result); + } + + return result; + } +}; + +DEFINE_REFCOUNTED_TYPE(TQuery) + +struct TFrontQuery + : public TBaseQuery +{ + explicit TFrontQuery(TGuid id = TGuid::Create()) + : TBaseQuery(id) + { } + + TFrontQuery(const TFrontQuery& other) = default; + + TTableSchemaPtr Schema; + + TTableSchemaPtr GetReadSchema() const override + { + return Schema; + } + + TTableSchemaPtr GetRenamedSchema() const + { + return Schema; + } + + TTableSchemaPtr GetTableSchema() const override + { + if (ProjectClause) { + return ProjectClause->GetTableSchema(); + } + + if (GroupClause) { + return 
GroupClause->GetTableSchema(IsFinal); + } + + return Schema; + } +}; + +DEFINE_REFCOUNTED_TYPE(TFrontQuery) + +template <class TResult, class TDerived, class TNode, class... TArgs> +struct TAbstractVisitor +{ + TDerived* Derived() + { + return static_cast<TDerived*>(this); + } + + TResult Visit(TNode node, TArgs... args) + { + auto expr = Derived()->GetExpression(node); + + if (auto literalExpr = expr->template As<TLiteralExpression>()) { + return Derived()->OnLiteral(literalExpr, args...); + } else if (auto referenceExpr = expr->template As<TReferenceExpression>()) { + return Derived()->OnReference(referenceExpr, args...); + } else if (auto unaryOp = expr->template As<TUnaryOpExpression>()) { + return Derived()->OnUnary(unaryOp, args...); + } else if (auto binaryOp = expr->template As<TBinaryOpExpression>()) { + return Derived()->OnBinary(binaryOp, args...); + } else if (auto functionExpr = expr->template As<TFunctionExpression>()) { + return Derived()->OnFunction(functionExpr, args...); + } else if (auto inExpr = expr->template As<TInExpression>()) { + return Derived()->OnIn(inExpr, args...); + } else if (auto betweenExpr = expr->template As<TBetweenExpression>()) { + return Derived()->OnBetween(betweenExpr, args...); + } else if (auto transformExpr = expr->template As<TTransformExpression>()) { + return Derived()->OnTransform(transformExpr, args...); + } + YT_ABORT(); + } + +}; + +template <class TResult, class TDerived> +struct TBaseVisitor + : TAbstractVisitor<TResult, TDerived, TConstExpressionPtr> +{ + const TExpression* GetExpression(const TConstExpressionPtr& expr) + { + return &*expr; + } + +}; + +template <class TDerived> +struct TVisitor + : public TBaseVisitor<void, TDerived> +{ + using TBase = TBaseVisitor<void, TDerived>; + using TBase::Derived; + using TBase::Visit; + + void OnLiteral(const TLiteralExpression* /*literalExpr*/) + { } + + void OnReference(const TReferenceExpression* /*referenceExpr*/) + { } + + void OnUnary(const TUnaryOpExpression* 
unaryExpr) + { + Visit(unaryExpr->Operand); + } + + void OnBinary(const TBinaryOpExpression* binaryExpr) + { + Visit(binaryExpr->Lhs); + Visit(binaryExpr->Rhs); + } + + void OnFunction(const TFunctionExpression* functionExpr) + { + for (auto argument : functionExpr->Arguments) { + Visit(argument); + } + } + + void OnIn(const TInExpression* inExpr) + { + for (auto argument : inExpr->Arguments) { + Visit(argument); + } + } + + void OnBetween(const TBetweenExpression* betweenExpr) + { + for (auto argument : betweenExpr->Arguments) { + Visit(argument); + } + } + + void OnTransform(const TTransformExpression* transformExpr) + { + for (auto argument : transformExpr->Arguments) { + Visit(argument); + } + } + +}; + +template <class TDerived> +struct TRewriter + : public TBaseVisitor<TConstExpressionPtr, TDerived> +{ + using TBase = TBaseVisitor<TConstExpressionPtr, TDerived>; + using TBase::Derived; + using TBase::Visit; + + TConstExpressionPtr OnLiteral(const TLiteralExpression* literalExpr) + { + return literalExpr; + } + + TConstExpressionPtr OnReference(const TReferenceExpression* referenceExpr) + { + return referenceExpr; + } + + TConstExpressionPtr OnUnary(const TUnaryOpExpression* unaryExpr) + { + auto newOperand = Visit(unaryExpr->Operand); + + if (newOperand == unaryExpr->Operand) { + return unaryExpr; + } + + return New<TUnaryOpExpression>( + unaryExpr->GetWireType(), + unaryExpr->Opcode, + newOperand); + } + + TConstExpressionPtr OnBinary(const TBinaryOpExpression* binaryExpr) + { + auto newLhs = Visit(binaryExpr->Lhs); + auto newRhs = Visit(binaryExpr->Rhs); + + if (newLhs == binaryExpr->Lhs && newRhs == binaryExpr->Rhs) { + return binaryExpr; + } + + return New<TBinaryOpExpression>( + binaryExpr->GetWireType(), + binaryExpr->Opcode, + newLhs, + newRhs); + } + + TConstExpressionPtr OnFunction(const TFunctionExpression* functionExpr) + { + std::vector<TConstExpressionPtr> newArguments; + bool allEqual = true; + for (auto argument : functionExpr->Arguments) { + 
auto newArgument = Visit(argument); + allEqual = allEqual && newArgument == argument; + newArguments.push_back(newArgument); + } + + if (allEqual) { + return functionExpr; + } + + return New<TFunctionExpression>( + functionExpr->GetWireType(), + functionExpr->FunctionName, + std::move(newArguments)); + } + + TConstExpressionPtr OnIn(const TInExpression* inExpr) + { + std::vector<TConstExpressionPtr> newArguments; + bool allEqual = true; + for (auto argument : inExpr->Arguments) { + auto newArgument = Visit(argument); + allEqual = allEqual && newArgument == argument; + newArguments.push_back(newArgument); + } + + if (allEqual) { + return inExpr; + } + + return New<TInExpression>( + std::move(newArguments), + inExpr->Values); + } + + TConstExpressionPtr OnBetween(const TBetweenExpression* betweenExpr) + { + std::vector<TConstExpressionPtr> newArguments; + bool allEqual = true; + for (auto argument : betweenExpr->Arguments) { + auto newArgument = Visit(argument); + allEqual = allEqual && newArgument == argument; + newArguments.push_back(newArgument); + } + + if (allEqual) { + return betweenExpr; + } + + return New<TBetweenExpression>( + std::move(newArguments), + betweenExpr->Ranges); + } + + TConstExpressionPtr OnTransform(const TTransformExpression* transformExpr) + { + std::vector<TConstExpressionPtr> newArguments; + bool allEqual = true; + for (auto argument : transformExpr->Arguments) { + auto newArgument = Visit(argument); + allEqual = allEqual && newArgument == argument; + newArguments.push_back(newArgument); + } + + TConstExpressionPtr newDefaultExpression; + if (const auto& defaultExpression = transformExpr->DefaultExpression) { + newDefaultExpression = Visit(defaultExpression); + allEqual = allEqual && newDefaultExpression == defaultExpression; + } + + if (allEqual) { + return transformExpr; + } + + return New<TTransformExpression>( + transformExpr->GetWireType(), + std::move(newArguments), + transformExpr->Values, + newDefaultExpression); + } + +}; + 
+template <class TDerived, class TNode, class... TArgs> +struct TAbstractExpressionPrinter + : TAbstractVisitor<void, TDerived, TNode, TArgs...> +{ + using TBase = TAbstractVisitor<void, TDerived, TNode, TArgs...>; + using TBase::Derived; + using TBase::Visit; + + TStringBuilderBase* Builder; + bool OmitValues; + + TAbstractExpressionPrinter(TStringBuilderBase* builder, bool omitValues) + : Builder(builder) + , OmitValues(omitValues) + { } + + static int GetOpPriority(EBinaryOp op) + { + switch (op) { + case EBinaryOp::Multiply: + case EBinaryOp::Divide: + case EBinaryOp::Modulo: + return 0; + + case EBinaryOp::Plus: + case EBinaryOp::Minus: + case EBinaryOp::Concatenate: + return 1; + + case EBinaryOp::LeftShift: + case EBinaryOp::RightShift: + return 2; + + case EBinaryOp::BitAnd: + return 3; + + case EBinaryOp::BitOr: + return 4; + + case EBinaryOp::Equal: + case EBinaryOp::NotEqual: + case EBinaryOp::Less: + case EBinaryOp::LessOrEqual: + case EBinaryOp::Greater: + case EBinaryOp::GreaterOrEqual: + return 5; + + case EBinaryOp::And: + return 6; + + case EBinaryOp::Or: + return 7; + + default: + YT_ABORT(); + } + } + + static bool CanOmitParenthesis(TConstExpressionPtr expr) + { + return + expr->As<TLiteralExpression>() || + expr->As<TReferenceExpression>() || + expr->As<TFunctionExpression>() || + expr->As<TUnaryOpExpression>() || + expr->As<TTransformExpression>(); + } + + const TExpression* GetExpression(const TConstExpressionPtr& expr) + { + return &*expr; + } + + void OnOperand(const TUnaryOpExpression* unaryExpr, TArgs... args) + { + Visit(unaryExpr->Operand, args...); + } + + void OnLhs(const TBinaryOpExpression* binaryExpr, TArgs... args) + { + Visit(binaryExpr->Lhs, args...); + } + + void OnRhs(const TBinaryOpExpression* binaryExpr, TArgs... args) + { + Visit(binaryExpr->Rhs, args...); + } + + void OnDefaultExpression(const TTransformExpression* transformExpr, TArgs... 
args) + { + if (const auto& defaultExpression = transformExpr->DefaultExpression) { + Builder->AppendString(", "); + Visit(defaultExpression, args...); + } + } + + template <class T> + void OnArguments(const T* expr, TArgs... args) + { + bool needComma = false; + for (const auto& argument : expr->Arguments) { + if (needComma) { + Builder->AppendString(", "); + } + Visit(argument, args...); + needComma = true; + } + } + + void OnLiteral(const TLiteralExpression* literalExpr, TArgs... /*args*/) + { + if (OmitValues) { + Builder->AppendChar('?'); + } else { + Builder->AppendString(ToString(static_cast<TValue>(literalExpr->Value))); + } + } + + void OnReference(const TReferenceExpression* referenceExpr, TArgs... /*args*/) + { + Builder->AppendString(referenceExpr->ColumnName); + } + + void OnUnary(const TUnaryOpExpression* unaryExpr, TArgs... args) + { + Builder->AppendString(GetUnaryOpcodeLexeme(unaryExpr->Opcode)); + Builder->AppendChar(' '); + + auto needParenthesis = !CanOmitParenthesis(unaryExpr->Operand); + if (needParenthesis) { + Builder->AppendChar('('); + } + Derived()->OnOperand(unaryExpr, args...); + if (needParenthesis) { + Builder->AppendChar(')'); + } + } + + void OnBinary(const TBinaryOpExpression* binaryExpr, TArgs... 
args) + { + auto needParenthesisLhs = !CanOmitParenthesis(binaryExpr->Lhs); + if (needParenthesisLhs) { + if (const auto* lhs = binaryExpr->Lhs->As<TBinaryOpExpression>()) { + if (GetOpPriority(lhs->Opcode) <= GetOpPriority(binaryExpr->Opcode)) { + needParenthesisLhs = false; + } + } + } + + if (needParenthesisLhs) { + Builder->AppendChar('('); + } + Derived()->OnLhs(binaryExpr, args...); + if (needParenthesisLhs) { + Builder->AppendChar(')'); + } + + Builder->AppendChar(' '); + Builder->AppendString(GetBinaryOpcodeLexeme(binaryExpr->Opcode)); + Builder->AppendChar(' '); + + auto needParenthesisRhs = !CanOmitParenthesis(binaryExpr->Rhs); + if (needParenthesisRhs) { + if (const auto* rhs = binaryExpr->Rhs->As<TBinaryOpExpression>()) { + if (GetOpPriority(rhs->Opcode) <= GetOpPriority(binaryExpr->Opcode)) { + needParenthesisRhs = false; + } + } + } + + if (needParenthesisRhs) { + Builder->AppendChar('('); + } + Derived()->OnRhs(binaryExpr, args...); + if (needParenthesisRhs) { + Builder->AppendChar(')'); + } + } + + void OnFunction(const TFunctionExpression* functionExpr, TArgs... args) + { + Builder->AppendString(functionExpr->FunctionName); + Builder->AppendChar('('); + Derived()->OnArguments(functionExpr, args...); + Builder->AppendChar(')'); + } + + void OnIn(const TInExpression* inExpr, TArgs... args) + { + auto needParenthesis = inExpr->Arguments.size() > 1; + if (needParenthesis) { + Builder->AppendChar('('); + } + Derived()->OnArguments(inExpr, args...); + if (needParenthesis) { + Builder->AppendChar(')'); + } + + Builder->AppendString(" IN ("); + + if (OmitValues) { + Builder->AppendString("??"); + } else { + JoinToString( + Builder, + inExpr->Values.begin(), + inExpr->Values.end(), + [&] (TStringBuilderBase* builder, const TRow& row) { + builder->AppendString(ToString(row)); + }); + } + Builder->AppendChar(')'); + } + + void OnBetween(const TBetweenExpression* betweenExpr, TArgs... 
args) + { + auto needParenthesis = betweenExpr->Arguments.size() > 1; + if (needParenthesis) { + Builder->AppendChar('('); + } + Derived()->OnArguments(betweenExpr, args...); + if (needParenthesis) { + Builder->AppendChar(')'); + } + + Builder->AppendString(" BETWEEN ("); + + if (OmitValues) { + Builder->AppendString("??"); + } else { + JoinToString( + Builder, + betweenExpr->Ranges.begin(), + betweenExpr->Ranges.end(), + [&] (TStringBuilderBase* builder, const TRowRange& range) { + builder->AppendString(ToString(range.first)); + builder->AppendString(" AND "); + builder->AppendString(ToString(range.second)); + }); + } + Builder->AppendChar(')'); + } + + void OnTransform(const TTransformExpression* transformExpr, TArgs... args) + { + Builder->AppendString("TRANSFORM("); + size_t argumentCount = transformExpr->Arguments.size(); + auto needParenthesis = argumentCount > 1; + if (needParenthesis) { + Builder->AppendChar('('); + } + Derived()->OnArguments(transformExpr, args...); + if (needParenthesis) { + Builder->AppendChar(')'); + } + + Builder->AppendString(", ("); + if (OmitValues) { + Builder->AppendString("??"); + } else { + JoinToString( + Builder, + transformExpr->Values.begin(), + transformExpr->Values.end(), + [&] (TStringBuilderBase* builder, const TRow& row) { + builder->AppendChar('['); + JoinToString( + builder, + row.Begin(), + row.Begin() + argumentCount, + [] (TStringBuilderBase* builder, const TValue& value) { + builder->AppendString(ToString(value)); + }); + builder->AppendChar(']'); + }); + } + Builder->AppendString("), ("); + + if (OmitValues) { + Builder->AppendString("??"); + } else { + JoinToString( + Builder, + transformExpr->Values.begin(), + transformExpr->Values.end(), + [&] (TStringBuilderBase* builder, const TRow& row) { + builder->AppendString(ToString(row[argumentCount])); + }); + } + + Builder->AppendChar(')'); + + Derived()->OnDefaultExpression(transformExpr, args...); + + Builder->AppendChar(')'); + } + +}; + +void 
ToProto(NProto::TQuery* serialized, const TConstQueryPtr& original); +void FromProto(TConstQueryPtr* original, const NProto::TQuery& serialized); + +void ToProto(NProto::TQueryOptions* serialized, const TQueryOptions& original); +void FromProto(TQueryOptions* original, const NProto::TQueryOptions& serialized); + +void ToProto(NProto::TDataSource* serialized, const TDataSource& original); +void FromProto(TDataSource* original, const NProto::TDataSource& serialized); + +struct TInferNameOptions +{ + bool OmitValues = false; + bool OmitAliases = false; + bool OmitJoinPredicate = false; + bool OmitOffsetAndLimit = false; +}; + +TString InferName(TConstExpressionPtr expr, bool omitValues = false); +TString InferName(TConstBaseQueryPtr query, TInferNameOptions options = {}); + +bool Compare( + TConstExpressionPtr lhs, + const TTableSchema& lhsSchema, + TConstExpressionPtr rhs, + const TTableSchema& rhsSchema, + size_t maxIndex = std::numeric_limits<size_t>::max()); + +std::vector<size_t> GetJoinGroups( + const std::vector<TConstJoinClausePtr>& joinClauses, + TTableSchemaPtr schema); + +NLogging::TLogger MakeQueryLogger(TGuid queryId); +NLogging::TLogger MakeQueryLogger(TConstBaseQueryPtr query); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/query_common.cpp b/yt/yt/library/query/base/query_common.cpp new file mode 100644 index 0000000000..8a47dc6282 --- /dev/null +++ b/yt/yt/library/query/base/query_common.cpp @@ -0,0 +1,186 @@ +#include "query_common.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +const char* GetUnaryOpcodeLexeme(EUnaryOp opcode) +{ + switch (opcode) { + case EUnaryOp::Plus: return "+"; + case EUnaryOp::Minus: return "-"; + case EUnaryOp::Not: return "NOT"; + case EUnaryOp::BitNot:return "~"; + default: YT_ABORT(); + } +} + +const char* GetBinaryOpcodeLexeme(EBinaryOp 
opcode) +{ + switch (opcode) { + case EBinaryOp::Plus: return "+"; + case EBinaryOp::Minus: return "-"; + case EBinaryOp::Multiply: return "*"; + case EBinaryOp::Divide: return "/"; + case EBinaryOp::Modulo: return "%"; + case EBinaryOp::LeftShift: return "<<"; + case EBinaryOp::RightShift: return ">>"; + case EBinaryOp::BitAnd: return "&"; + case EBinaryOp::BitOr: return "|"; + case EBinaryOp::And: return "AND"; + case EBinaryOp::Or: return "OR"; + case EBinaryOp::Equal: return "="; + case EBinaryOp::NotEqual: return "!="; + case EBinaryOp::Less: return "<"; + case EBinaryOp::LessOrEqual: return "<="; + case EBinaryOp::Greater: return ">"; + case EBinaryOp::GreaterOrEqual: return ">="; + case EBinaryOp::Concatenate: return "||"; + default: YT_ABORT(); + } +} + +EBinaryOp GetReversedBinaryOpcode(EBinaryOp opcode) +{ + switch (opcode) { + case EBinaryOp::Less: return EBinaryOp::Greater; + case EBinaryOp::LessOrEqual: return EBinaryOp::GreaterOrEqual; + case EBinaryOp::Greater: return EBinaryOp::Less; + case EBinaryOp::GreaterOrEqual: return EBinaryOp::LessOrEqual; + default: return opcode; + } +} + +EBinaryOp GetInversedBinaryOpcode(EBinaryOp opcode) +{ + switch (opcode) { + case EBinaryOp::Equal: return EBinaryOp::NotEqual; + case EBinaryOp::NotEqual: return EBinaryOp::Equal; + case EBinaryOp::Less: return EBinaryOp::GreaterOrEqual; + case EBinaryOp::LessOrEqual: return EBinaryOp::Greater; + case EBinaryOp::Greater: return EBinaryOp::LessOrEqual; + case EBinaryOp::GreaterOrEqual: return EBinaryOp::Less; + default: YT_ABORT(); + } +} + +bool IsArithmeticalBinaryOp(EBinaryOp opcode) +{ + switch (opcode) { + case EBinaryOp::Plus: + case EBinaryOp::Minus: + case EBinaryOp::Multiply: + case EBinaryOp::Divide: + return true; + default: + return false; + } +} + +bool IsIntegralBinaryOp(EBinaryOp opcode) +{ + switch (opcode) { + case EBinaryOp::Modulo: + case EBinaryOp::LeftShift: + case EBinaryOp::RightShift: + case EBinaryOp::BitOr: + case EBinaryOp::BitAnd: + return 
true; + default: + return false; + } +} + +bool IsLogicalBinaryOp(EBinaryOp opcode) +{ + switch (opcode) { + case EBinaryOp::And: + case EBinaryOp::Or: + return true; + default: + return false; + } +} + +bool IsRelationalBinaryOp(EBinaryOp opcode) +{ + switch (opcode) { + case EBinaryOp::Equal: + case EBinaryOp::NotEqual: + case EBinaryOp::Less: + case EBinaryOp::LessOrEqual: + case EBinaryOp::Greater: + case EBinaryOp::GreaterOrEqual: + return true; + default: + return false; + } +} + +bool IsStringBinaryOp(EBinaryOp opcode) +{ + switch (opcode) { + case EBinaryOp::Concatenate: + return true; + default: + return false; + } +} + +TValue CastValueWithCheck(TValue value, EValueType targetType) +{ + if (value.Type == targetType || value.Type == EValueType::Null) { + return value; + } + + if (value.Type == EValueType::Int64) { + if (targetType == EValueType::Double) { + auto int64Value = value.Data.Int64; + if (i64(double(int64Value)) != int64Value) { + THROW_ERROR_EXCEPTION("Failed to cast %v to double: inaccurate conversion", int64Value); + } + value.Data.Double = int64Value; + } else { + YT_VERIFY(targetType == EValueType::Uint64); + } + } else if (value.Type == EValueType::Uint64) { + if (targetType == EValueType::Int64) { + if (value.Data.Uint64 > std::numeric_limits<i64>::max()) { + THROW_ERROR_EXCEPTION( + "Failed to cast %vu to int64: value is greater than maximum", value.Data.Uint64); + } + } else if (targetType == EValueType::Double) { + auto uint64Value = value.Data.Uint64; + if (ui64(double(uint64Value)) != uint64Value) { + THROW_ERROR_EXCEPTION("Failed to cast %vu to double: inaccurate conversion", uint64Value); + } + value.Data.Double = uint64Value; + } else { + YT_ABORT(); + } + } else if (value.Type == EValueType::Double) { + auto doubleValue = value.Data.Double; + if (targetType == EValueType::Uint64) { + if (double(ui64(doubleValue)) != doubleValue) { + THROW_ERROR_EXCEPTION("Failed to cast %v to uint64: inaccurate conversion", doubleValue); + } + 
value.Data.Uint64 = doubleValue; + } else if (targetType == EValueType::Int64) { + if (double(i64(doubleValue)) != doubleValue) { + THROW_ERROR_EXCEPTION("Failed to cast %v to int64: inaccurate conversion", doubleValue); + } + value.Data.Int64 = doubleValue; + } else { + YT_ABORT(); + } + } else { + YT_ABORT(); + } + + value.Type = targetType; + return value; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/query_common.h b/yt/yt/library/query/base/query_common.h new file mode 100644 index 0000000000..da4edcbe66 --- /dev/null +++ b/yt/yt/library/query/base/query_common.h @@ -0,0 +1,169 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/hydra/public.h> + +#include <yt/yt/client/table_client/public.h> + +#include <yt/yt/client/transaction_client/helpers.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +#include <yt/yt/client/misc/workload.h> + +namespace NYT::NQueryClient { + +using NTransactionClient::TReadTimestampRange; + +//////////////////////////////////////////////////////////////////////////////// + +struct TDataSplit +{ + TGuid ObjectId; + TGuid CellId; + + TTableSchemaPtr TableSchema; + + TLegacyOwningKey LowerBound = NTableClient::MinKey(); + TLegacyOwningKey UpperBound = NTableClient::MaxKey(); +}; + +//////////////////////////////////////////////////////////////////////////////// + +using TSourceLocation = std::pair<int, int>; +static const TSourceLocation NullSourceLocation(0, 0); + +DEFINE_ENUM(EUnaryOp, + // Arithmetical operations. + (Plus) + (Minus) + // Integral operations. + (BitNot) + // Logical operations. + (Not) +); + +DEFINE_ENUM(EBinaryOp, + // Arithmetical operations. + (Plus) + (Minus) + (Multiply) + (Divide) + // Integral operations. + (Modulo) + (LeftShift) + (RightShift) + (BitOr) + (BitAnd) + // Logical operations. + (And) + (Or) + // Relational operations. 
+ (Equal) + (NotEqual) + (Less) + (LessOrEqual) + (Greater) + (GreaterOrEqual) + // String operations. + (Concatenate) +); + +DEFINE_ENUM(ETotalsMode, + (None) + (BeforeHaving) + (AfterHaving) +); + +DEFINE_ENUM(EAggregateFunction, + (Sum) + (Min) + (Max) +); + +const char* GetUnaryOpcodeLexeme(EUnaryOp opcode); +const char* GetBinaryOpcodeLexeme(EBinaryOp opcode); + +//! Reverse binary opcode for comparison operations (for swapping arguments). +EBinaryOp GetReversedBinaryOpcode(EBinaryOp opcode); + +//! Inverse binary opcode for comparison operations (for inverting the operation). +EBinaryOp GetInversedBinaryOpcode(EBinaryOp opcode); + +//! Classifies binary opcode according to classification above. +bool IsArithmeticalBinaryOp(EBinaryOp opcode); + +//! Classifies binary opcode according to classification above. +bool IsIntegralBinaryOp(EBinaryOp opcode); + +//! Classifies binary opcode according to classification above. +bool IsLogicalBinaryOp(EBinaryOp opcode); + +//! Classifies binary opcode according to classification above. +bool IsRelationalBinaryOp(EBinaryOp opcode); + +//! Classifies binary opcode according to classification above. +bool IsStringBinaryOp(EBinaryOp opcode); + +//! Cast numeric values. +TValue CastValueWithCheck(TValue value, EValueType targetType); + +//////////////////////////////////////////////////////////////////////////////// + +// TODO(lukyan): Use opaque data descriptor instead of ObjectId, CellId and MountRevision. +struct TDataSource +{ + // Could be: + // * a table id; + // * a tablet id. + NObjectClient::TObjectId ObjectId; + // If #ObjectId is a tablet id then this is the id of the cell hosting this tablet. + // COMPAT(babenko): legacy clients may omit this field. + NObjectClient::TCellId CellId; + + NHydra::TRevision MountRevision; + + std::vector<NTableClient::TLogicalTypePtr> Schema; + + TSharedRange<TRowRange> Ranges; + TSharedRange<TRow> Keys; + + //! 
If |true|, these ranges could be reclassified into a set of discrete lookup keys. + bool LookupSupported = true; + + size_t KeyWidth = 0; +}; + +struct TQueryBaseOptions +{ + i64 InputRowLimit = std::numeric_limits<i64>::max(); + i64 OutputRowLimit = std::numeric_limits<i64>::max(); + + bool EnableCodeCache = true; + bool UseCanonicalNullRelations = false; + TReadSessionId ReadSessionId; + size_t MemoryLimitPerNode = std::numeric_limits<size_t>::max(); +}; + +struct TQueryOptions + : public TQueryBaseOptions +{ + TReadTimestampRange TimestampRange{ + .Timestamp = NTransactionClient::SyncLastCommittedTimestamp, + .RetentionTimestamp = NTransactionClient::NullTimestamp, + }; + bool VerboseLogging = false; + int MaxSubqueries = std::numeric_limits<int>::max(); + ui64 RangeExpansionLimit = 0; + TWorkloadDescriptor WorkloadDescriptor; + bool AllowFullScan = true; + TInstant Deadline = TInstant::Max(); + bool SuppressAccessTracking = false; + std::optional<TString> ExecutionPool; + // COMPAT(lukyan) + bool NewRangeInference = true; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/query_helpers.cpp b/yt/yt/library/query/base/query_helpers.cpp new file mode 100644 index 0000000000..dec6e2ab4a --- /dev/null +++ b/yt/yt/library/query/base/query_helpers.cpp @@ -0,0 +1,617 @@ +#include "query_helpers.h" +#include "query.h" + +#include <yt/yt/client/table_client/name_table.h> +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +namespace NYT::NQueryClient { + +using namespace NTableClient; + +using ::ToString; + +//////////////////////////////////////////////////////////////////////////////// + +bool IsTrue(TConstExpressionPtr expr) +{ + if (auto literalExpr = expr->As<TLiteralExpression>()) { + TValue value = literalExpr->Value; + if (value.Type == EValueType::Boolean) { + return value.Data.Boolean; + } + } + 
return false; +} + +TConstExpressionPtr MakeAndExpression(TConstExpressionPtr lhs, TConstExpressionPtr rhs) +{ + if (auto literalExpr = lhs->As<TLiteralExpression>()) { + TValue value = literalExpr->Value; + if (value.Type == EValueType::Boolean) { + return value.Data.Boolean ? rhs : lhs; + } + } + + if (auto literalExpr = rhs->As<TLiteralExpression>()) { + TValue value = literalExpr->Value; + if (value.Type == EValueType::Boolean) { + return value.Data.Boolean ? lhs : rhs; + } + } + + return New<TBinaryOpExpression>( + EValueType::Boolean, + EBinaryOp::And, + lhs, + rhs); +} + +TConstExpressionPtr MakeOrExpression(TConstExpressionPtr lhs, TConstExpressionPtr rhs) +{ + if (auto literalExpr = lhs->As<TLiteralExpression>()) { + TValue value = literalExpr->Value; + if (value.Type == EValueType::Boolean) { + return value.Data.Boolean ? lhs : rhs; + } + } + + if (auto literalExpr = rhs->As<TLiteralExpression>()) { + TValue value = literalExpr->Value; + if (value.Type == EValueType::Boolean) { + return value.Data.Boolean ? 
rhs : lhs; + } + } + + return New<TBinaryOpExpression>( + EValueType::Boolean, + EBinaryOp::Or, + lhs, + rhs); +} + +namespace { + +int CompareRow(TRow lhs, TRow rhs, const std::vector<size_t>& mapping) +{ + for (auto index : mapping) { + int result = CompareRowValuesCheckingNan(lhs.Begin()[index], rhs.Begin()[index]); + + if (result != 0) { + return result; + } + } + return 0; +} + +void SortRows( + std::vector<TRow>::iterator begin, + std::vector<TRow>::iterator end, + const std::vector<size_t>& mapping) +{ + std::sort(begin, end, [&] (TRow lhs, TRow rhs) { + return CompareRow(lhs, rhs, mapping) < 0; + }); +} + +void SortRows( + std::vector<std::pair<TRow, size_t>>::iterator begin, + std::vector<std::pair<TRow, size_t>>::iterator end, + const std::vector<size_t>& mapping) +{ + std::sort(begin, end, [&] (const std::pair<TRow, size_t>& lhs, const std::pair<TRow, size_t>& rhs) { + return CompareRow(lhs.first, rhs.first, mapping) < 0; + }); +} + +} // namespace + +TConstExpressionPtr EliminateInExpression( + TRange<TRow> lookupKeys, + const TInExpression* inExpr, + const TKeyColumns& keyColumns, + size_t keyPrefixSize, + const std::vector<std::pair<TBound, TBound>>* bounds) +{ + static auto trueLiteral = New<TLiteralExpression>( + EValueType::Boolean, + MakeUnversionedBooleanValue(true)); + static auto falseLiteral = New<TLiteralExpression>( + EValueType::Boolean, + MakeUnversionedBooleanValue(false)); + + std::vector<size_t> valueMapping; + std::vector<size_t> keyMapping; + std::optional<size_t> rangeArgIndex; + + bool allArgsAreKey = true; + for (size_t argumentIndex = 0; argumentIndex < inExpr->Arguments.size(); ++argumentIndex) { + const auto& argument = inExpr->Arguments[argumentIndex]; + auto referenceExpr = argument->As<TReferenceExpression>(); + int keyIndex = referenceExpr + ? 
ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName) + : -1; + if (keyIndex == -1 || keyIndex >= static_cast<int>(keyPrefixSize)) { + allArgsAreKey = false; + } else { + valueMapping.push_back(argumentIndex); + keyMapping.push_back(keyIndex); + } + + if (bounds && keyIndex == static_cast<int>(keyPrefixSize)) { + rangeArgIndex = argumentIndex; + } + } + + auto compareKeyAndValue = [&] (TRow lhs, TRow rhs) { + for (int index = 0; index < std::ssize(valueMapping); ++index) { + int result = CompareRowValues(lhs.Begin()[keyMapping[index]], rhs.Begin()[valueMapping[index]]); + + if (result != 0) { + return result; + } + } + return 0; + }; + + std::vector<TRow> sortedValues(inExpr->Values.Begin(), inExpr->Values.End()); + if (!allArgsAreKey) { + SortRows(sortedValues.begin(), sortedValues.end(), valueMapping); + } + + std::vector<std::pair<TRow, size_t>> sortedKeys(lookupKeys.Size()); + for (size_t index = 0; index < lookupKeys.Size(); ++index) { + sortedKeys[index] = std::make_pair(lookupKeys[index], index); + } + SortRows(sortedKeys.begin(), sortedKeys.end(), keyMapping); + + std::vector<TRow> filteredValues; + bool hasExtraLookupKeys = false; + size_t keyIndex = 0; + size_t tupleIndex = 0; + while (keyIndex < sortedKeys.size() && tupleIndex < sortedValues.size()) { + auto currentKey = sortedKeys[keyIndex]; + auto currentValue = sortedValues[tupleIndex]; + + int result = compareKeyAndValue(currentKey.first, currentValue); + if (result == 0) { + auto keyIndexBegin = keyIndex; + do { + ++keyIndex; + } while (keyIndex < sortedKeys.size() + && CompareRow(currentKey.first, sortedKeys[keyIndex].first, keyMapping) == 0); + + // from keyIndexBegin to keyIndex + std::vector<TBound> unitedBounds; + if (bounds) { + std::vector<std::vector<TBound>> allBounds; + for (size_t index = keyIndexBegin; index < keyIndex; ++index) { + auto lowerAndUpper = (*bounds)[sortedKeys[index].second]; + + allBounds.push_back(std::vector<TBound>{lowerAndUpper.first, lowerAndUpper.second}); 
+ } + + UniteBounds(&allBounds); + + YT_VERIFY(!allBounds.empty()); + unitedBounds = std::move(allBounds.front()); + } + + do { + if (!rangeArgIndex || Covers(unitedBounds, sortedValues[tupleIndex][*rangeArgIndex])) { + filteredValues.push_back(sortedValues[tupleIndex]); + } + ++tupleIndex; + } while (tupleIndex < sortedValues.size() && + CompareRow(currentValue, sortedValues[tupleIndex], valueMapping) == 0); + } else if (result < 0) { + hasExtraLookupKeys = true; + ++keyIndex; + } else { + ++tupleIndex; + } + } + + if (keyIndex != sortedKeys.size()) { + hasExtraLookupKeys = true; + } + + if (!hasExtraLookupKeys && allArgsAreKey) { + return trueLiteral; + } else { + if (filteredValues.empty()) { + return falseLiteral; + } else { + std::sort(filteredValues.begin(), filteredValues.end()); + return New<TInExpression>( + inExpr->Arguments, + MakeSharedRange(std::move(filteredValues), inExpr->Values)); + } + } +} + +TConstExpressionPtr EliminatePredicate( + TRange<TRowRange> keyRanges, + TConstExpressionPtr expr, + const TKeyColumns& keyColumns) +{ + auto trueLiteral = New<TLiteralExpression>( + EValueType::Boolean, + MakeUnversionedBooleanValue(true)); + auto falseLiteral = New<TLiteralExpression>( + EValueType::Boolean, + MakeUnversionedBooleanValue(false)); + + int minCommonPrefixSize = std::numeric_limits<int>::max(); + for (const auto& keyRange : keyRanges) { + int commonPrefixSize = 0; + while (commonPrefixSize < static_cast<int>(keyRange.first.GetCount()) + && commonPrefixSize + 1 < static_cast<int>(keyRange.second.GetCount()) + && keyRange.first[commonPrefixSize] == keyRange.second[commonPrefixSize]) + { + commonPrefixSize++; + } + minCommonPrefixSize = std::min(minCommonPrefixSize, commonPrefixSize); + } + + auto getBounds = [] (const TRowRange& keyRange, size_t keyPartIndex) -> std::pair<TBound, TBound> { + auto lower = keyPartIndex < keyRange.first.GetCount() + ? 
TBound(keyRange.first[keyPartIndex], true) + : TBound(MakeUnversionedSentinelValue(EValueType::Min), false); + + YT_VERIFY(keyPartIndex < keyRange.second.GetCount()); + auto upper = TBound(keyRange.second[keyPartIndex], keyPartIndex + 1 < keyRange.second.GetCount()); + + return std::make_pair(lower, upper); + }; + + // Is it a good idea? Heavy, not always useful calculation. + std::vector<std::vector<TBound>> unitedBoundsByColumn(minCommonPrefixSize + 1); + for (int keyPartIndex = 0; keyPartIndex <= minCommonPrefixSize; ++keyPartIndex) { + std::vector<std::vector<TBound>> allBounds; + for (const auto& keyRange : keyRanges) { + auto bounds = getBounds(keyRange, keyPartIndex); + allBounds.push_back(std::vector<TBound>{bounds.first, bounds.second}); + } + + UniteBounds(&allBounds); + YT_VERIFY(!allBounds.empty()); + unitedBoundsByColumn[keyPartIndex] = std::move(allBounds.front()); + } + + std::function<TConstExpressionPtr(TConstExpressionPtr expr)> refinePredicate = + [&] (TConstExpressionPtr expr)->TConstExpressionPtr + { + if (auto binaryOpExpr = expr->As<TBinaryOpExpression>()) { + auto opcode = binaryOpExpr->Opcode; + auto lhsExpr = binaryOpExpr->Lhs; + auto rhsExpr = binaryOpExpr->Rhs; + + if (opcode == EBinaryOp::And) { + return MakeAndExpression( // eliminate constants + refinePredicate(lhsExpr), + refinePredicate(rhsExpr)); + } else if (opcode == EBinaryOp::Or) { + return MakeOrExpression( + refinePredicate(lhsExpr), + refinePredicate(rhsExpr)); + } else { + if (rhsExpr->As<TReferenceExpression>()) { + // Ensure that references are on the left. 
+ std::swap(lhsExpr, rhsExpr); + opcode = GetReversedBinaryOpcode(opcode); + } + + auto referenceExpr = lhsExpr->As<TReferenceExpression>(); + auto constantExpr = rhsExpr->As<TLiteralExpression>(); + + if (referenceExpr && constantExpr) { + int keyPartIndex = ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName); + if (keyPartIndex >= 0 && keyPartIndex <= minCommonPrefixSize) { + auto value = TValue(constantExpr->Value); + + std::vector<TBound> bounds; + + switch (opcode) { + case EBinaryOp::Equal: + bounds.emplace_back(value, true); + bounds.emplace_back(value, true); + break; + + case EBinaryOp::NotEqual: + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Min), true); + bounds.emplace_back(value, false); + bounds.emplace_back(value, false); + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Max), true); + break; + + case EBinaryOp::Less: + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Min), true); + bounds.emplace_back(value, false); + break; + + case EBinaryOp::LessOrEqual: + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Min), true); + bounds.emplace_back(value, true); + break; + + case EBinaryOp::Greater: + bounds.emplace_back(value, false); + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Max), true); + break; + + case EBinaryOp::GreaterOrEqual: + bounds.emplace_back(value, true); + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Max), true); + break; + + default: + break; + } + + if (!bounds.empty()) { + auto resultBounds = IntersectBounds(bounds, unitedBoundsByColumn[keyPartIndex]); + + if (resultBounds.empty()) { + return falseLiteral; + } else if (resultBounds == unitedBoundsByColumn[keyPartIndex]) { + return trueLiteral; + } + } + } + } + } + } else if (auto inExpr = expr->As<TInExpression>()) { + std::vector<TRow> lookupKeys; + std::vector<std::pair<TBound, TBound>> bounds; + for (const auto& keyRange : keyRanges) { + lookupKeys.push_back(keyRange.first); + 
bounds.push_back(getBounds(keyRange, minCommonPrefixSize)); + } + + return EliminateInExpression(MakeRange(lookupKeys), inExpr, keyColumns, minCommonPrefixSize, &bounds); + } + + return expr; + }; + + return refinePredicate(expr); +} + +TConstExpressionPtr EliminatePredicate( + TRange<TRow> lookupKeys, + TConstExpressionPtr expr, + const TKeyColumns& keyColumns) +{ + std::function<TConstExpressionPtr(TConstExpressionPtr expr)> refinePredicate = + [&] (TConstExpressionPtr expr)->TConstExpressionPtr + { + if (auto binaryOpExpr = expr->As<TBinaryOpExpression>()) { + auto opcode = binaryOpExpr->Opcode; + auto lhsExpr = binaryOpExpr->Lhs; + auto rhsExpr = binaryOpExpr->Rhs; + + // Eliminate constants. + if (opcode == EBinaryOp::And) { + return MakeAndExpression( + refinePredicate(lhsExpr), + refinePredicate(rhsExpr)); + } else if (opcode == EBinaryOp::Or) { + return MakeOrExpression( + refinePredicate(lhsExpr), + refinePredicate(rhsExpr)); + } + } else if (auto inExpr = expr->As<TInExpression>()) { + return EliminateInExpression(lookupKeys, inExpr, keyColumns, keyColumns.size(), nullptr); + } + + return expr; + }; + + return refinePredicate(expr); +} + +TKeyRange Unite(const TKeyRange& first, const TKeyRange& second) +{ + const auto& lower = ChooseMinKey(first.first, second.first); + const auto& upper = ChooseMaxKey(first.second, second.second); + return std::make_pair(lower, upper); +} + +TRowRange Unite(const TRowRange& first, const TRowRange& second) +{ + const auto& lower = std::min(first.first, second.first); + const auto& upper = std::max(first.second, second.second); + return std::make_pair(lower, upper); +} + +TKeyRange Intersect(const TKeyRange& first, const TKeyRange& second) +{ + const auto* leftmost = &first; + const auto* rightmost = &second; + + if (leftmost->first > rightmost->first) { + std::swap(leftmost, rightmost); + } + + if (rightmost->first > leftmost->second) { + // Empty intersection. 
+ return std::make_pair(rightmost->first, rightmost->first); + } + + if (rightmost->second > leftmost->second) { + return std::make_pair(rightmost->first, leftmost->second); + } else { + return std::make_pair(rightmost->first, rightmost->second); + } +} + +TRowRange Intersect(const TRowRange& first, const TRowRange& second) +{ + const auto* leftmost = &first; + const auto* rightmost = &second; + + if (leftmost->first > rightmost->first) { + std::swap(leftmost, rightmost); + } + + if (rightmost->first > leftmost->second) { + // Empty intersection. + return std::make_pair(rightmost->first, rightmost->first); + } + + if (rightmost->second > leftmost->second) { + return std::make_pair(rightmost->first, leftmost->second); + } else { + return std::make_pair(rightmost->first, rightmost->second); + } +} + +bool IsEmpty(const TKeyRange& keyRange) +{ + return keyRange.first >= keyRange.second; +} + +bool IsEmpty(const TRowRange& keyRange) +{ + return keyRange.first >= keyRange.second; +} + +bool AreAllReferencesInSchema(TConstExpressionPtr expr, const TTableSchema& tableSchema) +{ + if (auto referenceExpr = expr->As<TReferenceExpression>()) { + return tableSchema.FindColumn(referenceExpr->ColumnName); + } else if (expr->As<TLiteralExpression>()) { + return true; + } else if (auto binaryOpExpr = expr->As<TBinaryOpExpression>()) { + return AreAllReferencesInSchema(binaryOpExpr->Lhs, tableSchema) && AreAllReferencesInSchema(binaryOpExpr->Rhs, tableSchema); + } else if (auto unaryOpExpr = expr->As<TUnaryOpExpression>()) { + return AreAllReferencesInSchema(unaryOpExpr->Operand, tableSchema); + } else if (auto functionExpr = expr->As<TFunctionExpression>()) { + bool result = true; + for (const auto& argument : functionExpr->Arguments) { + result = result && AreAllReferencesInSchema(argument, tableSchema); + } + return result; + } else if (auto inExpr = expr->As<TInExpression>()) { + bool result = true; + for (const auto& argument : inExpr->Arguments) { + result = result && 
AreAllReferencesInSchema(argument, tableSchema); + } + return result; + } + + return false; +} + +TConstExpressionPtr ExtractPredicateForColumnSubset( + TConstExpressionPtr expr, + const TTableSchema& tableSchema) +{ + if (!expr) { + return nullptr; + } + + if (AreAllReferencesInSchema(expr, tableSchema)) { + return expr; + } else if (auto binaryOpExpr = expr->As<TBinaryOpExpression>()) { + auto opcode = binaryOpExpr->Opcode; + if (opcode == EBinaryOp::And) { + return MakeAndExpression( + ExtractPredicateForColumnSubset(binaryOpExpr->Lhs, tableSchema), + ExtractPredicateForColumnSubset(binaryOpExpr->Rhs, tableSchema)); + } else if (opcode == EBinaryOp::Or) { + return MakeOrExpression( + ExtractPredicateForColumnSubset(binaryOpExpr->Lhs, tableSchema), + ExtractPredicateForColumnSubset(binaryOpExpr->Rhs, tableSchema)); + } + } + + return New<TLiteralExpression>( + EValueType::Boolean, + MakeUnversionedBooleanValue(true)); +} + +void CollectOperands(std::vector<TConstExpressionPtr>* operands, TConstExpressionPtr expr) +{ + if (!expr) { + return; + } + + if (auto binaryOpExpr = expr->As<TBinaryOpExpression>()) { + auto opcode = binaryOpExpr->Opcode; + if (opcode == EBinaryOp::And) { + CollectOperands(operands, binaryOpExpr->Lhs); + CollectOperands(operands, binaryOpExpr->Rhs); + } else { + operands->push_back(expr); + } + } else { + operands->push_back(expr); + } +} + +std::pair<TConstExpressionPtr, TConstExpressionPtr> SplitPredicateByColumnSubset( + TConstExpressionPtr root, + const TTableSchema& tableSchema) +{ + // collect AND operands + std::vector<TConstExpressionPtr> operands; + + CollectOperands(&operands, root); + TConstExpressionPtr projected = New<TLiteralExpression>( + EValueType::Boolean, + MakeUnversionedBooleanValue(true)); + TConstExpressionPtr remaining = New<TLiteralExpression>( + EValueType::Boolean, + MakeUnversionedBooleanValue(true)); + + for (auto expr : operands) { + auto& target = AreAllReferencesInSchema(expr, tableSchema) ? 
projected : remaining; + target = MakeAndExpression(target, expr); + } + + return std::make_pair(projected, remaining); +} + +// Wrapper around CompareRowValues that checks that its arguments are not nan. +int CompareRowValuesCheckingNan(const TUnversionedValue& lhs, const TUnversionedValue& rhs) +{ + if (lhs.Type == rhs.Type && lhs.Type == EValueType::Double && + (std::isnan(lhs.Data.Double) || std::isnan(rhs.Data.Double))) + { + THROW_ERROR_EXCEPTION(NTableClient::EErrorCode::InvalidDoubleValue, "NaN value is not comparable"); + } + return CompareRowValues(lhs, rhs); +} + +ui64 GetEvaluatedColumnModulo(const TConstExpressionPtr& expr) +{ + ui64 moduloExpansion = 1; + auto binaryExpr = expr->As<TBinaryOpExpression>(); + + if (binaryExpr && binaryExpr->Opcode == EBinaryOp::Modulo) { + if (auto literalExpr = binaryExpr->Rhs->As<TLiteralExpression>()) { + TUnversionedValue value = literalExpr->Value; + switch (value.Type) { + case EValueType::Int64: + moduloExpansion *= value.Data.Int64 * 2; + break; + + case EValueType::Uint64: + moduloExpansion *= value.Data.Uint64 + 1; + break; + + default: + break; + } + } + } + + return moduloExpansion; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + diff --git a/yt/yt/library/query/base/query_helpers.h b/yt/yt/library/query/base/query_helpers.h new file mode 100644 index 0000000000..5e3d3186d5 --- /dev/null +++ b/yt/yt/library/query/base/query_helpers.h @@ -0,0 +1,82 @@ +#pragma once + +#include "public.h" +#include "key_trie.h" + +#include <yt/yt/client/table_client/row_buffer.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +#include <yt/yt/core/misc/range.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +//! Returns a minimal key range that cover both inputs. 
+TKeyRange Unite(const TKeyRange& first, const TKeyRange& second); +TRowRange Unite(const TRowRange& first, const TRowRange& second); + +//! Returns a maximal key range covered by both inputs. +TKeyRange Intersect(const TKeyRange& first, const TKeyRange& second); +TRowRange Intersect(const TRowRange& first, const TRowRange& second); + +//! Checks whether key range is empty. +bool IsEmpty(const TKeyRange& keyRange); +bool IsEmpty(const TRowRange& keyRange); + +bool IsTrue(TConstExpressionPtr expr); +TConstExpressionPtr MakeAndExpression(TConstExpressionPtr lhs, TConstExpressionPtr rhs); +TConstExpressionPtr MakeOrExpression(TConstExpressionPtr lhs, TConstExpressionPtr rhs); + +TConstExpressionPtr EliminatePredicate( + TRange<TRowRange> keyRanges, + TConstExpressionPtr expr, + const TKeyColumns& keyColumns); + +TConstExpressionPtr EliminatePredicate( + TRange<TRow> lookupKeys, + TConstExpressionPtr expr, + const TKeyColumns& keyColumns); + +TConstExpressionPtr ExtractPredicateForColumnSubset( + TConstExpressionPtr expr, + const TTableSchema& tableSchema); + +std::pair<TConstExpressionPtr, TConstExpressionPtr> SplitPredicateByColumnSubset( + TConstExpressionPtr root, + const TTableSchema& tableSchema); + +// Wrapper around CompareRowValues that checks that its arguments are not nan. 
+int CompareRowValuesCheckingNan(const NTableClient::TUnversionedValue& lhs, const NTableClient::TUnversionedValue& rhs); + +ui64 GetEvaluatedColumnModulo(const TConstExpressionPtr& expr); + +template <class TIter> +TIter MergeOverlappingRanges(TIter begin, TIter end) +{ + if (begin == end) { + return end; + } + + auto it = begin; + auto dest = it; + ++it; + + for (; it != end; ++it) { + if (dest->second < it->first) { + if (++dest != it) { + *dest = std::move(*it); + } + } else if (dest->second < it->second) { + dest->second = std::move(it->second); + } + } + + ++dest; + return dest; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + diff --git a/yt/yt/library/query/base/query_preparer.cpp b/yt/yt/library/query/base/query_preparer.cpp new file mode 100644 index 0000000000..814568cf0d --- /dev/null +++ b/yt/yt/library/query/base/query_preparer.cpp @@ -0,0 +1,3117 @@ +#include "query_preparer.h" +#include "private.h" +#include "callbacks.h" +#include "functions.h" +#include "lexer.h" +#include "parser.h" +#include "query_helpers.h" + +#include <yt/yt_proto/yt/client/chunk_client/proto/chunk_spec.pb.h> + +#include <yt/yt/client/tablet_client/public.h> + +#include <yt/yt/core/ytree/convert.h> + +#include <yt/yt/core/misc/collection_helpers.h> +#include <yt/yt/core/misc/finally.h> + +#include <library/cpp/yt/misc/variant.h> + +#include <unordered_set> + +namespace NYT::NQueryClient { + +using namespace NConcurrency; +using namespace NTableClient; +using namespace NYson; + +//////////////////////////////////////////////////////////////////////////////// + +static constexpr size_t MaxExpressionDepth = 50; + +#ifdef _asan_enabled_ +static const int MinimumStackFreeSpace = 128_KB; +#else +static const int MinimumStackFreeSpace = 16_KB; +#endif + +struct TQueryPreparerBufferTag +{ }; + +constexpr ssize_t MaxQueryLimit = 10000000; + 
+//////////////////////////////////////////////////////////////////////////////// + +namespace { + +void CheckStackDepth() +{ + if (!CheckFreeStackSpace(MinimumStackFreeSpace)) { + THROW_ERROR_EXCEPTION( + NTabletClient::EErrorCode::QueryExpressionDepthLimitExceeded, + "Expression depth causes stack overflow"); + } +} + +void ExtractFunctionNames( + const NAst::TNullableExpressionList& exprs, + std::vector<TString>* functions); + +void ExtractFunctionNames( + const NAst::TExpressionPtr& expr, + std::vector<TString>* functions) +{ + if (auto functionExpr = expr->As<NAst::TFunctionExpression>()) { + functions->push_back(to_lower(functionExpr->FunctionName)); + ExtractFunctionNames(functionExpr->Arguments, functions); + } else if (auto unaryExpr = expr->As<NAst::TUnaryOpExpression>()) { + ExtractFunctionNames(unaryExpr->Operand, functions); + } else if (auto binaryExpr = expr->As<NAst::TBinaryOpExpression>()) { + ExtractFunctionNames(binaryExpr->Lhs, functions); + ExtractFunctionNames(binaryExpr->Rhs, functions); + } else if (auto inExpr = expr->As<NAst::TInExpression>()) { + ExtractFunctionNames(inExpr->Expr, functions); + } else if (auto betweenExpr = expr->As<NAst::TBetweenExpression>()) { + ExtractFunctionNames(betweenExpr->Expr, functions); + } else if (auto transformExpr = expr->As<NAst::TTransformExpression>()) { + ExtractFunctionNames(transformExpr->Expr, functions); + ExtractFunctionNames(transformExpr->DefaultExpr, functions); + } else if (expr->As<NAst::TLiteralExpression>()) { + } else if (expr->As<NAst::TReferenceExpression>()) { + } else if (expr->As<NAst::TAliasExpression>()) { + } else { + YT_ABORT(); + } +} + +void ExtractFunctionNames( + const NAst::TNullableExpressionList& exprs, + std::vector<TString>* functions) +{ + if (!exprs) { + return; + } + + CheckStackDepth(); + + for (const auto& expr : *exprs) { + ExtractFunctionNames(expr, functions); + } +} + +std::vector<TString> ExtractFunctionNames( + const NAst::TQuery& query, + const 
NAst::TAliasMap& aliasMap) +{ + std::vector<TString> functions; + + ExtractFunctionNames(query.WherePredicate, &functions); + ExtractFunctionNames(query.HavingPredicate, &functions); + ExtractFunctionNames(query.SelectExprs, &functions); + + if (query.GroupExprs) { + for (const auto& expr : query.GroupExprs->first) { + ExtractFunctionNames(expr, &functions); + } + } + + for (const auto& join : query.Joins) { + ExtractFunctionNames(join.Lhs, &functions); + ExtractFunctionNames(join.Rhs, &functions); + } + + for (const auto& orderExpression : query.OrderExpressions) { + for (const auto& expr : orderExpression.first) { + ExtractFunctionNames(expr, &functions); + } + } + + for (const auto& aliasedExpression : aliasMap) { + ExtractFunctionNames(aliasedExpression.second, &functions); + } + + std::sort(functions.begin(), functions.end()); + functions.erase( + std::unique(functions.begin(), functions.end()), + functions.end()); + + return functions; +} + +//////////////////////////////////////////////////////////////////////////////// + +TTypeSet ComparableTypes({ + EValueType::Boolean, + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::String}); + +//////////////////////////////////////////////////////////////////////////////// + +EValueType GetType(const NAst::TLiteralValue& literalValue) +{ + return Visit(literalValue, + [] (const NAst::TNullLiteralValue&) { + return EValueType::Null; + }, + [] (i64) { + return EValueType::Int64; + }, + [] (ui64) { + return EValueType::Uint64; + }, + [] (double) { + return EValueType::Double; + }, + [] (bool) { + return EValueType::Boolean; + }, + [] (const TString&) { + return EValueType::String; + }); +} + +TTypeSet GetTypes(const NAst::TLiteralValue& literalValue) +{ + return Visit(literalValue, + [] (const NAst::TNullLiteralValue&) { + return TTypeSet({ + EValueType::Null, + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Boolean, + EValueType::String, + EValueType::Any + }); 
+ }, + [] (i64) { + return TTypeSet({ + EValueType::Int64, + EValueType::Uint64, + EValueType::Double + }); + }, + [] (ui64) { + return TTypeSet({ + EValueType::Uint64, + EValueType::Double + }); + }, + [] (double) { + return TTypeSet({ + EValueType::Double + }); + }, + [] (bool) { + return TTypeSet({ + EValueType::Boolean + }); + }, + [] (const TString&) { + return TTypeSet({ + EValueType::String + }); + }); +} + +TValue GetValue(const NAst::TLiteralValue& literalValue) +{ + return Visit(literalValue, + [] (const NAst::TNullLiteralValue&) { + return MakeUnversionedSentinelValue(EValueType::Null); + }, + [] (i64 value) { + return MakeUnversionedInt64Value(value); + }, + [] (ui64 value) { + return MakeUnversionedUint64Value(value); + }, + [] (double value) { + return MakeUnversionedDoubleValue(value); + }, + [] (bool value) { + return MakeUnversionedBooleanValue(value); + }, + [] (const TString& value) { + return MakeUnversionedStringValue(TStringBuf(value.c_str(), value.length())); + }); +} + +void BuildRow( + TUnversionedRowBuilder* rowBuilder, + const NAst::TLiteralValueTuple& tuple, + const std::vector<EValueType>& argTypes, + TStringBuf source) +{ + for (int i = 0; i < std::ssize(tuple); ++i) { + auto valueType = GetType(tuple[i]); + auto value = GetValue(tuple[i]); + + if (valueType == EValueType::Null) { + value = MakeUnversionedSentinelValue(EValueType::Null); + } else if (valueType != argTypes[i]) { + if (IsArithmeticType(valueType) && IsArithmeticType(argTypes[i])) { + value = CastValueWithCheck(value, argTypes[i]); + } else { + THROW_ERROR_EXCEPTION("Types mismatch in tuple") + << TErrorAttribute("source", source) + << TErrorAttribute("actual_type", valueType) + << TErrorAttribute("expected_type", argTypes[i]); + } + } + rowBuilder->AddValue(value); + } +} + +TSharedRange<TRow> LiteralTupleListToRows( + const NAst::TLiteralValueTupleList& literalTuples, + const std::vector<EValueType>& argTypes, + TStringBuf source) +{ + auto rowBuffer = 
New<TRowBuffer>(TQueryPreparerBufferTag()); + TUnversionedRowBuilder rowBuilder; + std::vector<TRow> rows; + for (const auto& tuple : literalTuples) { + if (tuple.size() != argTypes.size()) { + THROW_ERROR_EXCEPTION("Arguments size mismatch in tuple") + << TErrorAttribute("source", source); + } + + BuildRow(&rowBuilder, tuple, argTypes, source); + + rows.push_back(rowBuffer->CaptureRow(rowBuilder.GetRow())); + rowBuilder.Reset(); + } + + std::sort(rows.begin(), rows.end()); + return MakeSharedRange(std::move(rows), std::move(rowBuffer)); +} + +TSharedRange<TRowRange> LiteralRangesListToRows( + const NAst::TLiteralValueRangeList& literalRanges, + const std::vector<EValueType>& argTypes, + TStringBuf source) +{ + auto rowBuffer = New<TRowBuffer>(TQueryPreparerBufferTag()); + TUnversionedRowBuilder rowBuilder; + std::vector<TRowRange> ranges; + for (const auto& range : literalRanges) { + if (range.first.size() > argTypes.size()) { + THROW_ERROR_EXCEPTION("Arguments size mismatch in tuple") + << TErrorAttribute("source", source); + } + + if (range.second.size() > argTypes.size()) { + THROW_ERROR_EXCEPTION("Arguments size mismatch in tuple") + << TErrorAttribute("source", source); + } + + BuildRow(&rowBuilder, range.first, argTypes, source); + auto lower = rowBuffer->CaptureRow(rowBuilder.GetRow()); + rowBuilder.Reset(); + + BuildRow(&rowBuilder, range.second, argTypes, source); + auto upper = rowBuffer->CaptureRow(rowBuilder.GetRow()); + rowBuilder.Reset(); + + if (CompareRows(lower, upper, std::min(lower.GetCount(), upper.GetCount())) > 0) { + THROW_ERROR_EXCEPTION("Lower bound is greater than upper") + << TErrorAttribute("lower", lower) + << TErrorAttribute("upper", upper); + } + + ranges.emplace_back(lower, upper); + } + + std::sort(ranges.begin(), ranges.end()); + + for (int index = 1; index < std::ssize(ranges); ++index) { + TRow previousUpper = ranges[index - 1].second; + TRow currentLower = ranges[index].first; + + if (CompareRows( + previousUpper, + 
currentLower, + std::min(previousUpper.GetCount(), currentLower.GetCount())) >= 0) + { + THROW_ERROR_EXCEPTION("Ranges are not disjoint") + << TErrorAttribute("first", ranges[index - 1]) + << TErrorAttribute("second", ranges[index]); + } + } + + return MakeSharedRange(std::move(ranges), std::move(rowBuffer)); +} + +std::optional<TUnversionedValue> FoldConstants( + EUnaryOp opcode, + const TConstExpressionPtr& operand) +{ + if (auto literalExpr = operand->As<TLiteralExpression>()) { + if (opcode == EUnaryOp::Plus) { + return static_cast<TUnversionedValue>(literalExpr->Value); + } else if (opcode == EUnaryOp::Minus) { + TUnversionedValue value = literalExpr->Value; + switch (value.Type) { + case EValueType::Int64: + value.Data.Int64 = -value.Data.Int64; + break; + case EValueType::Uint64: + value.Data.Uint64 = -value.Data.Uint64; + break; + case EValueType::Double: + value.Data.Double = -value.Data.Double; + break; + case EValueType::Null: + break; + default: + YT_ABORT(); + } + return value; + } else if (opcode == EUnaryOp::BitNot) { + TUnversionedValue value = literalExpr->Value; + switch (value.Type) { + case EValueType::Int64: + value.Data.Int64 = ~value.Data.Int64; + break; + case EValueType::Uint64: + value.Data.Uint64 = ~value.Data.Uint64; + break; + case EValueType::Null: + break; + default: + YT_ABORT(); + } + return value; + } + } + return std::nullopt; +} + +std::optional<TUnversionedValue> FoldConstants( + EBinaryOp opcode, + const TConstExpressionPtr& lhsExpr, + const TConstExpressionPtr& rhsExpr) +{ + auto lhsLiteral = lhsExpr->As<TLiteralExpression>(); + auto rhsLiteral = rhsExpr->As<TLiteralExpression>(); + if (lhsLiteral && rhsLiteral) { + auto lhs = static_cast<TUnversionedValue>(lhsLiteral->Value); + auto rhs = static_cast<TUnversionedValue>(rhsLiteral->Value); + + auto checkType = [&] () { + if (lhs.Type != rhs.Type) { + if (IsArithmeticType(lhs.Type) && IsArithmeticType(rhs.Type)) { + auto targetType = std::max(lhs.Type, rhs.Type); + lhs = 
CastValueWithCheck(lhs, targetType); + rhs = CastValueWithCheck(rhs, targetType); + } else { + ThrowTypeMismatchError(lhs.Type, rhs.Type, "", InferName(lhsExpr), InferName(rhsExpr)); + } + } + }; + + auto checkTypeIfNotNull = [&] () { + if (lhs.Type != EValueType::Null && rhs.Type != EValueType::Null) { + checkType(); + } + }; + + #define CHECK_TYPE() \ + if (lhs.Type == EValueType::Null) { \ + return MakeUnversionedSentinelValue(EValueType::Null); \ + } \ + if (rhs.Type == EValueType::Null) { \ + return MakeUnversionedSentinelValue(EValueType::Null); \ + } \ + checkType(); + + auto evaluateLogicalOp = [&] (bool parameter) { + YT_VERIFY(lhs.Type == EValueType::Null || lhs.Type == EValueType::Boolean); + YT_VERIFY(rhs.Type == EValueType::Null || rhs.Type == EValueType::Boolean); + + if (lhs.Type == EValueType::Null) { + if (rhs.Type != EValueType::Null && rhs.Data.Boolean == parameter) { + return rhs; + } else { + return lhs; + } + } else if (lhs.Data.Boolean == parameter) { + return lhs; + } else { + return rhs; + } + }; + + switch (opcode) { + case EBinaryOp::Plus: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Int64: + lhs.Data.Int64 += rhs.Data.Int64; + return lhs; + case EValueType::Uint64: + lhs.Data.Uint64 += rhs.Data.Uint64; + return lhs; + case EValueType::Double: + lhs.Data.Double += rhs.Data.Double; + return lhs; + default: + break; + } + break; + case EBinaryOp::Minus: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Int64: + lhs.Data.Int64 -= rhs.Data.Int64; + return lhs; + case EValueType::Uint64: + lhs.Data.Uint64 -= rhs.Data.Uint64; + return lhs; + case EValueType::Double: + lhs.Data.Double -= rhs.Data.Double; + return lhs; + default: + break; + } + break; + case EBinaryOp::Concatenate: + break; + case EBinaryOp::Multiply: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Int64: + lhs.Data.Int64 *= rhs.Data.Int64; + return lhs; + case EValueType::Uint64: + lhs.Data.Uint64 *= rhs.Data.Uint64; + return lhs; + case 
EValueType::Double: + lhs.Data.Double *= rhs.Data.Double; + return lhs; + default: + break; + } + break; + case EBinaryOp::Divide: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Int64: + if (rhs.Data.Int64 == 0) { + THROW_ERROR_EXCEPTION("Division by zero"); + } + + if (lhs.Data.Int64 == std::numeric_limits<i64>::min() && rhs.Data.Int64 == -1) { + THROW_ERROR_EXCEPTION("Division of INT_MIN by -1"); + } + + lhs.Data.Int64 /= rhs.Data.Int64; + return lhs; + case EValueType::Uint64: + if (rhs.Data.Uint64 == 0) { + THROW_ERROR_EXCEPTION("Division by zero"); + } + lhs.Data.Uint64 /= rhs.Data.Uint64; + return lhs; + case EValueType::Double: + lhs.Data.Double /= rhs.Data.Double; + return lhs; + default: + break; + } + break; + case EBinaryOp::Modulo: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Int64: + if (rhs.Data.Int64 == 0) { + THROW_ERROR_EXCEPTION("Division by zero"); + } + + if (lhs.Data.Int64 == std::numeric_limits<i64>::min() && rhs.Data.Int64 == -1) { + THROW_ERROR_EXCEPTION("Division of INT_MIN by -1"); + } + + lhs.Data.Int64 %= rhs.Data.Int64; + return lhs; + case EValueType::Uint64: + if (rhs.Data.Uint64 == 0) { + THROW_ERROR_EXCEPTION("Division by zero"); + } + lhs.Data.Uint64 %= rhs.Data.Uint64; + return lhs; + default: + break; + } + break; + case EBinaryOp::LeftShift: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Int64: + lhs.Data.Int64 <<= rhs.Data.Int64; + return lhs; + case EValueType::Uint64: + lhs.Data.Uint64 <<= rhs.Data.Uint64; + return lhs; + default: + break; + } + break; + case EBinaryOp::RightShift: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Int64: + lhs.Data.Int64 >>= rhs.Data.Int64; + return lhs; + case EValueType::Uint64: + lhs.Data.Uint64 >>= rhs.Data.Uint64; + return lhs; + default: + break; + } + break; + case EBinaryOp::BitOr: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Uint64: + lhs.Data.Uint64 = lhs.Data.Uint64 | rhs.Data.Uint64; + return lhs; + case EValueType::Int64: + 
lhs.Data.Int64 = lhs.Data.Int64 | rhs.Data.Int64; + return lhs; + default: + break; + } + break; + case EBinaryOp::BitAnd: + CHECK_TYPE(); + switch (lhs.Type) { + case EValueType::Uint64: + lhs.Data.Uint64 = lhs.Data.Uint64 & rhs.Data.Uint64; + return lhs; + case EValueType::Int64: + lhs.Data.Int64 = lhs.Data.Int64 & rhs.Data.Int64; + return lhs; + default: + break; + } + break; + case EBinaryOp::And: + return evaluateLogicalOp(false); + break; + case EBinaryOp::Or: + return evaluateLogicalOp(true); + break; + case EBinaryOp::Equal: + checkTypeIfNotNull(); + return MakeUnversionedBooleanValue(CompareRowValuesCheckingNan(lhs, rhs) == 0); + break; + case EBinaryOp::NotEqual: + checkTypeIfNotNull(); + return MakeUnversionedBooleanValue(CompareRowValuesCheckingNan(lhs, rhs) != 0); + break; + case EBinaryOp::Less: + checkTypeIfNotNull(); + return MakeUnversionedBooleanValue(CompareRowValuesCheckingNan(lhs, rhs) < 0); + break; + case EBinaryOp::Greater: + checkTypeIfNotNull(); + return MakeUnversionedBooleanValue(CompareRowValuesCheckingNan(lhs, rhs) > 0); + break; + case EBinaryOp::LessOrEqual: + checkTypeIfNotNull(); + return MakeUnversionedBooleanValue(CompareRowValuesCheckingNan(lhs, rhs) <= 0); + break; + case EBinaryOp::GreaterOrEqual: + checkTypeIfNotNull(); + return MakeUnversionedBooleanValue(CompareRowValuesCheckingNan(lhs, rhs) >= 0); + break; + default: + break; + } + } + return std::nullopt; +} + +struct TNotExpressionPropagator + : TRewriter<TNotExpressionPropagator> +{ + using TBase = TRewriter<TNotExpressionPropagator>; + + TConstExpressionPtr OnUnary(const TUnaryOpExpression* unaryExpr) + { + auto& operand = unaryExpr->Operand; + if (unaryExpr->Opcode == EUnaryOp::Not) { + if (auto operandUnaryOp = operand->As<TUnaryOpExpression>()) { + if (operandUnaryOp->Opcode == EUnaryOp::Not) { + return Visit(operandUnaryOp->Operand); + } + } else if (auto operandBinaryOp = operand->As<TBinaryOpExpression>()) { + if (operandBinaryOp->Opcode == EBinaryOp::And) { + 
return Visit(New<TBinaryOpExpression>( + EValueType::Boolean, + EBinaryOp::Or, + New<TUnaryOpExpression>( + operandBinaryOp->Lhs->GetWireType(), + EUnaryOp::Not, + operandBinaryOp->Lhs), + New<TUnaryOpExpression>( + operandBinaryOp->Rhs->GetWireType(), + EUnaryOp::Not, + operandBinaryOp->Rhs))); + } else if (operandBinaryOp->Opcode == EBinaryOp::Or) { + return Visit(New<TBinaryOpExpression>( + EValueType::Boolean, + EBinaryOp::And, + New<TUnaryOpExpression>( + operandBinaryOp->Lhs->GetWireType(), + EUnaryOp::Not, + operandBinaryOp->Lhs), + New<TUnaryOpExpression>( + operandBinaryOp->Rhs->GetWireType(), + EUnaryOp::Not, + operandBinaryOp->Rhs))); + } else if (IsRelationalBinaryOp(operandBinaryOp->Opcode)) { + return Visit(New<TBinaryOpExpression>( + operandBinaryOp->GetWireType(), + GetInversedBinaryOpcode(operandBinaryOp->Opcode), + operandBinaryOp->Lhs, + operandBinaryOp->Rhs)); + } + } else if (auto literal = operand->As<TLiteralExpression>()) { + TUnversionedValue value = literal->Value; + value.Data.Boolean = !value.Data.Boolean; + return New<TLiteralExpression>( + literal->GetWireType(), + value); + } + } + + return TBase::OnUnary(unaryExpr); + } +}; + +struct TCastEliminator + : TRewriter<TCastEliminator> +{ + using TBase = TRewriter<TCastEliminator>; + + TConstExpressionPtr OnFunction(const TFunctionExpression* functionExpr) + { + if (IsUserCastFunction(functionExpr->FunctionName)) { + YT_VERIFY(functionExpr->Arguments.size() == 1); + + if (*functionExpr->LogicalType == *functionExpr->Arguments[0]->LogicalType) { + return Visit(functionExpr->Arguments[0]); + } + } + + return TBase::OnFunction(functionExpr); + } +}; + +struct TExpressionSimplifier + : TRewriter<TExpressionSimplifier> +{ + using TBase = TRewriter<TExpressionSimplifier>; + + TConstExpressionPtr OnFunction(const TFunctionExpression* functionExpr) + { + if (functionExpr->FunctionName == "if") { + if (auto functionCondition = functionExpr->Arguments[0]->As<TFunctionExpression>()) { + auto 
reference1 = functionExpr->Arguments[2]->As<TReferenceExpression>(); + if (functionCondition->FunctionName == "is_null" && reference1) { + auto reference0 = functionCondition->Arguments[0]->As<TReferenceExpression>(); + if (reference0 && reference1->ColumnName == reference0->ColumnName) { + return New<TFunctionExpression>( + functionExpr->GetWireType(), + "if_null", + std::vector<TConstExpressionPtr>{ + functionCondition->Arguments[0], + functionExpr->Arguments[1]}); + + } + } + } + } + + return TBase::OnFunction(functionExpr); + } +}; + +bool Unify(TTypeSet* genericAssignments, const TTypeSet& types) +{ + auto intersection = *genericAssignments & types; + + if (intersection.IsEmpty()) { + return false; + } else { + *genericAssignments = intersection; + return true; + } +} + +EValueType GetFrontWithCheck(const TTypeSet& typeSet, TStringBuf source) +{ + auto result = typeSet.GetFront(); + if (result == EValueType::Null) { + THROW_ERROR_EXCEPTION("Type inference failed") + << TErrorAttribute("actual_type", EValueType::Null) + << TErrorAttribute("source", source); + } + return result; +} + +TTypeSet InferFunctionTypes( + const TFunctionTypeInferrer* inferrer, + const std::vector<TTypeSet>& effectiveTypes, + std::vector<TTypeSet>* genericAssignments, + TStringBuf functionName, + TStringBuf source) +{ + std::vector<TTypeSet> typeConstraints; + std::vector<int> formalArguments; + std::optional<std::pair<int, bool>> repeatedType; + int formalResultType = inferrer->GetNormalizedConstraints( + &typeConstraints, + &formalArguments, + &repeatedType); + + *genericAssignments = typeConstraints; + + int argIndex = 1; + auto arg = effectiveTypes.begin(); + auto formalArg = formalArguments.begin(); + for (; + formalArg != formalArguments.end() && arg != effectiveTypes.end(); + arg++, formalArg++, argIndex++) + { + auto& constraints = (*genericAssignments)[*formalArg]; + if (!Unify(&constraints, *arg)) { + THROW_ERROR_EXCEPTION( + "Wrong type for argument %v to function %Qv: 
expected %Qv, got %Qv", + argIndex, + functionName, + constraints, + *arg) + << TErrorAttribute("expression", source); + } + } + + bool hasNoRepeatedArgument = !repeatedType.operator bool(); + + if (formalArg != formalArguments.end() || + (arg != effectiveTypes.end() && hasNoRepeatedArgument)) + { + THROW_ERROR_EXCEPTION( + "Wrong number of arguments to function %Qv: expected %v, got %v", + functionName, + formalArguments.size(), + effectiveTypes.size()) + << TErrorAttribute("expression", source); + } + + for (; arg != effectiveTypes.end(); arg++) { + int constraintIndex = repeatedType->first; + if (repeatedType->second) { + constraintIndex = genericAssignments->size(); + genericAssignments->push_back((*genericAssignments)[repeatedType->first]); + } + auto& constraints = (*genericAssignments)[constraintIndex]; + if (!Unify(&constraints, *arg)) { + THROW_ERROR_EXCEPTION( + "Wrong type for repeated argument to function %Qv: expected %Qv, got %Qv", + functionName, + constraints, + *arg) + << TErrorAttribute("expression", source); + } + } + + return (*genericAssignments)[formalResultType]; +} + +std::vector<EValueType> RefineFunctionTypes( + const TFunctionTypeInferrer* inferrer, + EValueType resultType, + int argumentCount, + std::vector<TTypeSet>* genericAssignments, + TStringBuf source) +{ + std::vector<TTypeSet> typeConstraints; + std::vector<int> formalArguments; + std::optional<std::pair<int, bool>> repeatedType; + int formalResultType = inferrer->GetNormalizedConstraints( + &typeConstraints, + &formalArguments, + &repeatedType); + + (*genericAssignments)[formalResultType] = TTypeSet({resultType}); + + std::vector<EValueType> genericAssignmentsMin; + for (auto& constraint : *genericAssignments) { + genericAssignmentsMin.push_back(GetFrontWithCheck(constraint, source)); + } + + std::vector<EValueType> effectiveTypes; + int argIndex = 0; + auto formalArg = formalArguments.begin(); + for (; + formalArg != formalArguments.end() && argIndex < argumentCount; + 
++formalArg, ++argIndex) + { + effectiveTypes.push_back(genericAssignmentsMin[*formalArg]); + } + + for (; argIndex < argumentCount; ++argIndex) { + int constraintIndex = repeatedType->first; + if (repeatedType->second) { + constraintIndex = genericAssignments->size() - (argumentCount - argIndex); + } + + effectiveTypes.push_back(genericAssignmentsMin[constraintIndex]); + } + + return effectiveTypes; +} + +// 1. Init generic assignments with constraints +// Intersect generic assignments with argument types and save them +// Infer feasible result types +// 2. Apply result types and restrict generic assignments and argument types + +void IntersectGenericsWithArgumentTypes( + const std::vector<TTypeSet>& effectiveTypes, + std::vector<TTypeSet>* genericAssignments, + const std::vector<int>& formalArguments, + TStringBuf functionName, + TStringBuf source) +{ + if (formalArguments.size() != effectiveTypes.size()) { + THROW_ERROR_EXCEPTION("Expected %v number of arguments to function %Qv, got %v", + formalArguments.size(), + functionName, + effectiveTypes.size()); + } + + for (int argIndex = 0; argIndex < std::ssize(formalArguments); ++argIndex) + { + auto& constraints = (*genericAssignments)[formalArguments[argIndex]]; + if (!Unify(&constraints, effectiveTypes[argIndex])) { + THROW_ERROR_EXCEPTION("Wrong type for argument %v to function %Qv: expected %Qv, got %Qv", + argIndex + 1, + functionName, + constraints, + effectiveTypes[argIndex]) + << TErrorAttribute("expression", source); + } + } +} + +std::vector<EValueType> RefineFunctionTypes( + int formalResultType, + int formalStateType, + const std::vector<int>& formalArguments, + EValueType resultType, + EValueType* stateType, + std::vector<TTypeSet>* genericAssignments, + TStringBuf source) +{ + (*genericAssignments)[formalResultType] = TTypeSet({resultType}); + + std::vector<EValueType> genericAssignmentsMin; + for (auto& constraint : *genericAssignments) { + 
genericAssignmentsMin.push_back(GetFrontWithCheck(constraint, source)); + } + + *stateType = genericAssignmentsMin[formalStateType]; + + std::vector<EValueType> effectiveTypes; + for (int formalArgConstraint : formalArguments) + { + effectiveTypes.push_back(genericAssignmentsMin[formalArgConstraint]); + } + + return effectiveTypes; +} + +struct TOperatorTyper +{ + TTypeSet Constraint; + std::optional<EValueType> ResultType; +}; + +TEnumIndexedVector<EBinaryOp, TOperatorTyper> BuildBinaryOperatorTypers() +{ + TEnumIndexedVector<EBinaryOp, TOperatorTyper> result; + + for (auto op : { + EBinaryOp::Plus, + EBinaryOp::Minus, + EBinaryOp::Multiply, + EBinaryOp::Divide}) + { + result[op] = { + TTypeSet({EValueType::Int64, EValueType::Uint64, EValueType::Double}), + std::nullopt + }; + } + + for (auto op : { + EBinaryOp::Modulo, + EBinaryOp::LeftShift, + EBinaryOp::RightShift, + EBinaryOp::BitOr, + EBinaryOp::BitAnd}) + { + result[op] = { + TTypeSet({EValueType::Int64, EValueType::Uint64}), + std::nullopt + }; + } + + for (auto op : { + EBinaryOp::And, + EBinaryOp::Or}) + { + result[op] = { + TTypeSet({EValueType::Boolean}), + EValueType::Boolean + }; + } + + for (auto op : { + EBinaryOp::Equal, + EBinaryOp::NotEqual, + EBinaryOp::Less, + EBinaryOp::Greater, + EBinaryOp::LessOrEqual, + EBinaryOp::GreaterOrEqual}) + { + result[op] = { + TTypeSet({ + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Boolean, + EValueType::String, + EValueType::Any}), + EValueType::Boolean + }; + } + + for (auto op : {EBinaryOp::Concatenate}) { + result[op] = { + TTypeSet({ EValueType::String, }), + EValueType::String + }; + } + + return result; +} + +const TEnumIndexedVector<EBinaryOp, TOperatorTyper>& GetBinaryOperatorTypers() +{ + static auto result = BuildBinaryOperatorTypers(); + return result; +} + +TEnumIndexedVector<EUnaryOp, TOperatorTyper> BuildUnaryOperatorTypers() +{ + TEnumIndexedVector<EUnaryOp, TOperatorTyper> result; + + for (auto op : { + 
EUnaryOp::Plus, + EUnaryOp::Minus}) + { + result[op] = { + TTypeSet({EValueType::Int64, EValueType::Uint64, EValueType::Double}), + std::nullopt + }; + } + + result[EUnaryOp::BitNot] = { + TTypeSet({EValueType::Int64, EValueType::Uint64}), + std::nullopt + }; + + result[EUnaryOp::Not] = { + TTypeSet({EValueType::Boolean}), + std::nullopt + }; + + return result; +} + +const TEnumIndexedVector<EUnaryOp, TOperatorTyper>& GetUnaryOperatorTypers() +{ + static auto result = BuildUnaryOperatorTypers(); + return result; +} + +TTypeSet InferBinaryExprTypes( + EBinaryOp opCode, + const TTypeSet& lhsTypes, + const TTypeSet& rhsTypes, + TTypeSet* genericAssignments, + TStringBuf lhsSource, + TStringBuf rhsSource) +{ + if (IsRelationalBinaryOp(opCode) && (lhsTypes & rhsTypes).IsEmpty()) { + return TTypeSet{EValueType::Boolean}; + } + + const auto& binaryOperators = GetBinaryOperatorTypers(); + + *genericAssignments = binaryOperators[opCode].Constraint; + + if (!Unify(genericAssignments, lhsTypes)) { + THROW_ERROR_EXCEPTION("Type mismatch in expression %Qv: expected %Qv, got %Qv", + opCode, + *genericAssignments, + lhsTypes) + << TErrorAttribute("lhs_source", lhsSource) + << TErrorAttribute("rhs_source", rhsSource); + } + + if (!Unify(genericAssignments, rhsTypes)) { + THROW_ERROR_EXCEPTION("Type mismatch in expression %Qv: expected %Qv, got %Qv", + opCode, + *genericAssignments, + rhsTypes) + << TErrorAttribute("lhs_source", lhsSource) + << TErrorAttribute("rhs_source", rhsSource); + } + + TTypeSet resultTypes; + if (binaryOperators[opCode].ResultType) { + resultTypes = TTypeSet({*binaryOperators[opCode].ResultType}); + } else { + resultTypes = *genericAssignments; + } + + return resultTypes; +} + +std::pair<EValueType, EValueType> RefineBinaryExprTypes( + EBinaryOp opCode, + EValueType resultType, + const TTypeSet& lhsTypes, + const TTypeSet& rhsTypes, + TTypeSet* genericAssignments, + TStringBuf lhsSource, + TStringBuf rhsSource, + TStringBuf source) +{ + if 
(IsRelationalBinaryOp(opCode) && (lhsTypes & rhsTypes).IsEmpty()) { + // Empty intersection (Any, alpha) || (alpha, Any), where alpha = {bool, int, uint, double, string} + if (lhsTypes.Get(EValueType::Any)) { + return std::make_pair(EValueType::Any, GetFrontWithCheck(rhsTypes, rhsSource)); + } + + if (rhsTypes.Get(EValueType::Any)) { + return std::make_pair(GetFrontWithCheck(lhsTypes, lhsSource), EValueType::Any); + } + + THROW_ERROR_EXCEPTION("Type mismatch in expression") + << TErrorAttribute("lhs_source", lhsSource) + << TErrorAttribute("rhs_source", rhsSource); + } + + const auto& binaryOperators = GetBinaryOperatorTypers(); + + EValueType argType; + if (binaryOperators[opCode].ResultType) { + argType = GetFrontWithCheck(*genericAssignments, source); + } else { + YT_VERIFY(genericAssignments->Get(resultType)); + argType = resultType; + } + + return std::make_pair(argType, argType); +} + +TTypeSet InferUnaryExprTypes( + EUnaryOp opCode, + const TTypeSet& argTypes, + TTypeSet* genericAssignments, + TStringBuf opSource) +{ + const auto& unaryOperators = GetUnaryOperatorTypers(); + + *genericAssignments = unaryOperators[opCode].Constraint; + + if (!Unify(genericAssignments, argTypes)) { + THROW_ERROR_EXCEPTION("Type mismatch in expression %Qv: expected %Qv, got %Qv", + opCode, + *genericAssignments, + argTypes) + << TErrorAttribute("op_source", opSource); + } + + TTypeSet resultTypes; + if (unaryOperators[opCode].ResultType) { + resultTypes = TTypeSet({*unaryOperators[opCode].ResultType}); + } else { + resultTypes = *genericAssignments; + } + + return resultTypes; +} + +EValueType RefineUnaryExprTypes( + EUnaryOp opCode, + EValueType resultType, + TTypeSet* genericAssignments, + TStringBuf opSource) +{ + const auto& unaryOperators = GetUnaryOperatorTypers(); + + EValueType argType; + if (unaryOperators[opCode].ResultType) { + argType = GetFrontWithCheck(*genericAssignments, opSource); + } else { + YT_VERIFY(genericAssignments->Get(resultType)); + argType = 
resultType; + } + + return argType; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct TBaseColumn +{ + TBaseColumn(const TString& name, TLogicalTypePtr type) + : Name(name) + , LogicalType(type) + { } + + TString Name; + TLogicalTypePtr LogicalType; +}; + + +struct TBuilderCtxBase +{ +private: + struct TTable + { + const TTableSchema& Schema; + std::optional<TString> Alias; + std::vector<TColumnDescriptor>* Mapping = nullptr; + }; + + // TODO: Enrich TMappedSchema with alias and keep here pointers to TMappedSchema. + std::vector<TTable> Tables; + +protected: + // TODO: Combine in Structure? Move out? + const TNamedItemList* GroupItems = nullptr; + TAggregateItemList* AggregateItems = nullptr; + + bool AfterGroupBy = false; + +public: + struct TColumnEntry + { + TBaseColumn Column; + + size_t LastTableIndex; + size_t OriginTableIndex; + }; + + THashMap<NAst::TReference, TColumnEntry> Lookup; + + TBuilderCtxBase( + const TTableSchema& schema, + std::optional<TString> alias, + std::vector<TColumnDescriptor>* mapping) + { + Tables.push_back(TTable{schema, alias, mapping}); + } + + // Columns already presented in Lookup are shared. + // In mapping presented all columns needed for read and renamed schema. + // SelfJoinedColumns and ForeignJoinedColumns are builded from Lookup using OriginTableIndex and LastTableIndex. + void Merge(TBuilderCtxBase& other) + { + size_t otherTablesCount = other.Tables.size(); + size_t tablesCount = Tables.size(); + size_t lastTableIndex = tablesCount + otherTablesCount - 1; + + std::move(other.Tables.begin(), other.Tables.end(), std::back_inserter(Tables)); + + for (const auto& [reference, entry] : other.Lookup) { + auto [it, emplaced] = Lookup.emplace( + reference, + TColumnEntry{ + entry.Column, + 0, // Consider not used yet. + tablesCount + entry.OriginTableIndex}); + + if (!emplaced) { + // Column is shared. Increment LastTableIndex to prevent search in new (other merged) tables. 
+ it->second.LastTableIndex = lastTableIndex; + } + } + } + + void PopulateAllColumns() + { + for (const auto& table : Tables) { + for (const auto& column : table.Schema.Columns()) { + GetColumnPtr(NAst::TReference(column.Name(), table.Alias)); + } + } + } + + void SetGroupData(const TNamedItemList* groupItems, TAggregateItemList* aggregateItems) + { + YT_VERIFY(!GroupItems && !AggregateItems); + + GroupItems = groupItems; + AggregateItems = aggregateItems; + AfterGroupBy = true; + } + + void CheckNoOtherColumn(const NAst::TReference& reference, size_t startTableIndex) const + { + for (int index = startTableIndex; index < std::ssize(Tables); ++index) { + auto& [schema, alias, mapping] = Tables[index]; + + if (alias == reference.TableName && schema.FindColumn(reference.ColumnName)) { + THROW_ERROR_EXCEPTION("Ambiguous resolution for column %Qv", + NAst::InferColumnName(reference)); + } + } + } + + std::pair<const TTable*, TLogicalTypePtr> ResolveColumn(const NAst::TReference& reference) const + { + const TTable* result = nullptr; + TLogicalTypePtr type; + + int index = 0; + for (; index < std::ssize(Tables); ++index) { + auto& [schema, alias, mapping] = Tables[index]; + + if (alias != reference.TableName) { + continue; + } + + if (auto* column = schema.FindColumn(reference.ColumnName)) { + auto formattedName = NAst::InferColumnName(reference); + + if (mapping) { + mapping->push_back(TColumnDescriptor{ + formattedName, + schema.GetColumnIndex(*column) + }); + } + result = &Tables[index]; + type = column->LogicalType(); + ++index; + break; + } + } + + CheckNoOtherColumn(reference, index); + + return {result, type}; + } + + static const std::optional<TBaseColumn> FindColumn(const TNamedItemList& schema, const TString& name) + { + for (int index = 0; index < std::ssize(schema); ++index) { + if (schema[index].Name == name) { + return TBaseColumn(name, schema[index].Expression->LogicalType); + } + } + return std::nullopt; + } + + std::optional<TBaseColumn> 
GetColumnPtr(const NAst::TReference& reference) + { + if (AfterGroupBy) { + // Search other way after group by. + if (reference.TableName) { + return std::nullopt; + } + + return FindColumn(*GroupItems, reference.ColumnName); + } + + size_t lastTableIndex = Tables.size() - 1; + + auto found = Lookup.find(reference); + if (found != Lookup.end()) { + // Provide column from max table index till end. + + size_t nextTableIndex = std::max(found->second.OriginTableIndex, found->second.LastTableIndex) + 1; + + CheckNoOtherColumn(reference, nextTableIndex); + + // Update LastTableIndex after check. + found->second.LastTableIndex = lastTableIndex; + + return found->second.Column; + } else if (auto [table, type] = ResolveColumn(reference); table) { + auto formattedName = NAst::InferColumnName(reference); + auto column = TBaseColumn(formattedName, type); + + auto emplaced = Lookup.emplace( + reference, + TColumnEntry{ + column, + lastTableIndex, + size_t(table - Tables.data())}); + + YT_VERIFY(emplaced.second); + return column; + } else { + return std::nullopt; + } + } +}; + +using TExpressionGenerator = std::function<TConstExpressionPtr(EValueType)>; + +struct TUntypedExpression +{ + TTypeSet FeasibleTypes; + TExpressionGenerator Generator; + bool IsConstant; +}; + +struct TBuilderCtx + : public TBuilderCtxBase +{ +public: + const TString& Source; + const TConstTypeInferrerMapPtr Functions; + const NAst::TAliasMap& AliasMap; + +private: + std::set<TString> UsedAliases; + size_t Depth = 0; + + THashMap<std::pair<TString, EValueType>, TConstAggregateFunctionExpressionPtr> AggregateLookup; + +public: + TBuilderCtx( + const TString& source, + const TConstTypeInferrerMapPtr& functions, + const NAst::TAliasMap& aliasMap, + const TTableSchema& schema, + std::optional<TString> alias, + std::vector<TColumnDescriptor>* mapping) + : TBuilderCtxBase(schema, alias, mapping) + , Source(source) + , Functions(functions) + , AliasMap(aliasMap) + { } + + // TODO: Move ProvideAggregateColumn 
and GetAggregateColumnPtr to TBuilderCtxBase and provide callback + // OnExpression. + // Or split into two functions. GetAggregate and SetAggregate. + std::pair<TTypeSet, std::function<TConstExpressionPtr(EValueType)>> ProvideAggregateColumn( + const TString& name, + const TAggregateTypeInferrer* aggregateItem, + const NAst::TExpression* argument, + const TString& subexpressionName) + { + YT_VERIFY(AfterGroupBy); + + // TODO: Use guard. + AfterGroupBy = false; + auto untypedOperand = OnExpression(argument); + AfterGroupBy = true; + + TTypeSet constraint; + std::optional<EValueType> stateType; + std::optional<EValueType> resultType; + + aggregateItem->GetNormalizedConstraints(&constraint, &stateType, &resultType, name); + + TTypeSet resultTypes; + TTypeSet genericAssignments = constraint; + + if (!Unify(&genericAssignments, untypedOperand.FeasibleTypes)) { + THROW_ERROR_EXCEPTION("Type mismatch in function %Qv: expected %v, actual %v", + name, + genericAssignments, + untypedOperand.FeasibleTypes) + << TErrorAttribute("source", subexpressionName); + } + + if (resultType) { + resultTypes = TTypeSet({*resultType}); + } else { + resultTypes = genericAssignments; + } + + return std::make_pair(resultTypes, [=, this] (EValueType type) { + EValueType argType; + if (resultType) { + YT_VERIFY(!genericAssignments.IsEmpty()); + argType = GetFrontWithCheck(genericAssignments, argument->GetSource(Source)); + } else { + argType = type; + } + + EValueType effectiveStateType; + if (stateType) { + effectiveStateType = *stateType; + } else { + effectiveStateType = argType; + } + + auto typedOperand = untypedOperand.Generator(argType); + + typedOperand = TCastEliminator().Visit(typedOperand); + typedOperand = TExpressionSimplifier().Visit(typedOperand); + typedOperand = TNotExpressionPropagator().Visit(typedOperand); + + AggregateItems->emplace_back( + std::vector<TConstExpressionPtr>{typedOperand}, + name, + subexpressionName, + effectiveStateType, + type); + + return typedOperand; + 
}); + } + + TUntypedExpression GetAggregateColumnPtr( + const TString& functionName, + const TAggregateTypeInferrer* aggregateItem, + const NAst::TExpression* argument, + const TString& subexpressionName) + { + if (!AfterGroupBy) { + THROW_ERROR_EXCEPTION("Misuse of aggregate function %Qv", functionName); + } + + auto typer = ProvideAggregateColumn( + functionName, + aggregateItem, + argument, + subexpressionName); + + TExpressionGenerator generator = [=, this] (EValueType type) { + auto key = std::make_pair(subexpressionName, type); + auto found = AggregateLookup.find(key); + if (found != AggregateLookup.end()) { + return found->second; + } else { + auto argExpression = typer.second(type); + TConstAggregateFunctionExpressionPtr expr = New<TAggregateFunctionExpression>( + MakeLogicalType(GetLogicalType(type), false), + subexpressionName, + std::vector{argExpression}, + type, + type, + functionName); + YT_VERIFY(AggregateLookup.emplace(key, expr).second); + return expr; + } + }; + + return TUntypedExpression{typer.first, std::move(generator), false}; + } + + + TUntypedExpression OnExpression( + const NAst::TExpression* expr); + +private: + TUntypedExpression OnReference( + const NAst::TReference& reference); + + TUntypedExpression OnFunction( + const NAst::TFunctionExpression* functionExpr); + + TUntypedExpression OnUnaryOp( + const NAst::TUnaryOpExpression* unaryExpr); + + TUntypedExpression MakeBinaryExpr( + const NAst::TBinaryOpExpression* binaryExpr, + EBinaryOp op, + TUntypedExpression lhs, + TUntypedExpression rhs, + std::optional<size_t> offset); + + friend struct TBinaryOpGenerator; + + TUntypedExpression OnBinaryOp( + const NAst::TBinaryOpExpression* binaryExpr); + + void InferArgumentTypes( + std::vector<TConstExpressionPtr>* typedArguments, + std::vector<EValueType>* argTypes, + const NAst::TExpressionList& expressions, + TStringBuf operatorName, + TStringBuf source); + + TUntypedExpression OnInOp( + const NAst::TInExpression* inExpr); + + 
TUntypedExpression OnBetweenOp( + const NAst::TBetweenExpression* betweenExpr); + + TUntypedExpression OnTransformOp( + const NAst::TTransformExpression* transformExpr); + +public: + TConstExpressionPtr BuildTypedExpression( + const NAst::TExpression* expr, + TTypeSet feasibleTypes = TTypeSet({ + EValueType::Null, + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Boolean, + EValueType::String, + EValueType::Any, + EValueType::Composite})) + { + auto expressionTyper = OnExpression(expr); + YT_VERIFY(!expressionTyper.FeasibleTypes.IsEmpty()); + + if (!Unify(&feasibleTypes, expressionTyper.FeasibleTypes)) { + THROW_ERROR_EXCEPTION("Type mismatch in expression: expected %Qv, got %Qv", + feasibleTypes, + expressionTyper.FeasibleTypes) + << TErrorAttribute("source", expr->GetSource(Source)); + } + + auto result = expressionTyper.Generator( + GetFrontWithCheck(feasibleTypes, expr->GetSource(Source))); + + result = TCastEliminator().Visit(result); + result = TExpressionSimplifier().Visit(result); + result = TNotExpressionPropagator().Visit(result); + return result; + } + +}; + +TUntypedExpression TBuilderCtx::OnExpression( + const NAst::TExpression* expr) +{ + CheckStackDepth(); + + ++Depth; + auto depthGuard = Finally([&] { + --Depth; + }); + + if (Depth > MaxExpressionDepth) { + THROW_ERROR_EXCEPTION("Maximum expression depth exceeded") + << TErrorAttribute("max_expression_depth", MaxExpressionDepth); + } + + if (auto literalExpr = expr->As<NAst::TLiteralExpression>()) { + const auto& literalValue = literalExpr->Value; + + auto resultTypes = GetTypes(literalValue); + TExpressionGenerator generator = [literalValue] (EValueType type) { + return New<TLiteralExpression>( + type, + CastValueWithCheck(GetValue(literalValue), type)); + }; + return TUntypedExpression{resultTypes, std::move(generator), true}; + } else if (auto aliasExpr = expr->As<NAst::TAliasExpression>()) { + return OnReference(NAst::TReference(aliasExpr->Name)); + } else if (auto 
referenceExpr = expr->As<NAst::TReferenceExpression>()) { + return OnReference(referenceExpr->Reference); + } else if (auto functionExpr = expr->As<NAst::TFunctionExpression>()) { + return OnFunction(functionExpr); + } else if (auto unaryExpr = expr->As<NAst::TUnaryOpExpression>()) { + return OnUnaryOp(unaryExpr); + } else if (auto binaryExpr = expr->As<NAst::TBinaryOpExpression>()) { + return OnBinaryOp(binaryExpr); + } else if (auto inExpr = expr->As<NAst::TInExpression>()) { + return OnInOp(inExpr); + } else if (auto betweenExpr = expr->As<NAst::TBetweenExpression>()) { + return OnBetweenOp(betweenExpr); + } else if (auto transformExpr = expr->As<NAst::TTransformExpression>()) { + return OnTransformOp(transformExpr); + } + + YT_ABORT(); +} + + +TUntypedExpression TBuilderCtx::OnReference(const NAst::TReference& reference) +{ + if (AfterGroupBy) { + if (auto column = GetColumnPtr(reference)) { + TTypeSet resultTypes({GetWireType(column->LogicalType)}); + TExpressionGenerator generator = [column = *column] (EValueType) { + return New<TReferenceExpression>(column.LogicalType, column.Name); + }; + return TUntypedExpression{resultTypes, std::move(generator), false}; + } + } + + if (!reference.TableName) { + const auto& columnName = reference.ColumnName; + auto found = AliasMap.find(columnName); + + if (found != AliasMap.end()) { + // try InferName(found, expand aliases = true) + + if (UsedAliases.insert(columnName).second) { + auto aliasExpr = OnExpression(found->second); + UsedAliases.erase(columnName); + return aliasExpr; + } + } + } + + if (!AfterGroupBy) { + if (auto column = GetColumnPtr(reference)) { + TTypeSet resultTypes({GetWireType(column->LogicalType)}); + TExpressionGenerator generator = [column = *column] (EValueType) { + return New<TReferenceExpression>(column.LogicalType, column.Name); + }; + return TUntypedExpression{resultTypes, std::move(generator), false}; + } + } + + THROW_ERROR_EXCEPTION("Undefined reference %Qv", + 
NAst::InferColumnName(reference)); +} + +TUntypedExpression TBuilderCtx::OnFunction(const NAst::TFunctionExpression* functionExpr) +{ + auto functionName = functionExpr->FunctionName; + functionName.to_lower(); + + const auto& descriptor = Functions->GetFunction(functionName); + + if (const auto* aggregateFunction = descriptor->As<TAggregateFunctionTypeInferrer>()) { + auto subexpressionName = InferColumnName(*functionExpr); + + std::vector<TTypeSet> argTypes; + std::vector<TTypeSet> genericAssignments; + std::vector<TExpressionGenerator> operandTypers; + std::vector<int> formalArguments; + + YT_VERIFY(AfterGroupBy); + + AfterGroupBy = false; + for (const auto& argument : functionExpr->Arguments) { + auto untypedArgument = OnExpression(argument); + argTypes.push_back(untypedArgument.FeasibleTypes); + operandTypers.push_back(untypedArgument.Generator); + } + AfterGroupBy = true; + + int stateConstraintIndex; + int resultConstraintIndex; + + std::tie(stateConstraintIndex, resultConstraintIndex) = aggregateFunction->GetNormalizedConstraints( + &genericAssignments, + &formalArguments); + IntersectGenericsWithArgumentTypes( + argTypes, + &genericAssignments, + formalArguments, + functionName, + functionExpr->GetSource(Source)); + + auto resultTypes = genericAssignments[resultConstraintIndex]; + + TExpressionGenerator generator = [ + this, + stateConstraintIndex, + resultConstraintIndex, + functionName = std::move(functionName), + subexpressionName = std::move(subexpressionName), + operandTypers = std::move(operandTypers), + genericAssignments = std::move(genericAssignments), + formalArguments = std::move(formalArguments), + source = functionExpr->GetSource(Source) + ] (EValueType type) mutable { + auto key = std::make_pair(subexpressionName, type); + auto foundCached = AggregateLookup.find(key); + if (foundCached != AggregateLookup.end()) { + return foundCached->second; + } + + EValueType stateType; + auto effectiveTypes = RefineFunctionTypes( + resultConstraintIndex, + 
stateConstraintIndex, + formalArguments, + type, + &stateType, + &genericAssignments, + source); + + std::vector<TConstExpressionPtr> typedOperands; + for (int index = 0; index < std::ssize(effectiveTypes); ++index) { + typedOperands.push_back(operandTypers[index](effectiveTypes[index])); + typedOperands.back() = TCastEliminator().Visit(typedOperands.back()); + typedOperands.back() = TExpressionSimplifier().Visit(typedOperands.back()); + typedOperands.back() = TNotExpressionPropagator().Visit(typedOperands.back()); + } + + AggregateItems->emplace_back( + typedOperands, + functionName, + subexpressionName, + stateType, + type); + + TConstAggregateFunctionExpressionPtr expr = New<TAggregateFunctionExpression>( + MakeLogicalType(GetLogicalType(type), false), + subexpressionName, + typedOperands, + stateType, + type, + functionName); + AggregateLookup.emplace(key, expr); + + return expr; + }; + + return TUntypedExpression{resultTypes, std::move(generator), false}; + } else if (const auto* aggregateItem = descriptor->As<TAggregateTypeInferrer>()) { + auto subexpressionName = InferColumnName(*functionExpr); + + try { + if (functionExpr->Arguments.size() != 1) { + THROW_ERROR_EXCEPTION("Aggregate function %Qv must have exactly one argument", functionName); + } + + auto aggregateColumn = GetAggregateColumnPtr( + functionName, + aggregateItem, + functionExpr->Arguments.front(), + subexpressionName); + + return aggregateColumn; + } catch (const std::exception& ex) { + THROW_ERROR_EXCEPTION("Error creating aggregate") + << TErrorAttribute("source", functionExpr->GetSource(Source)) + << ex; + } + } else if (const auto* regularFunction = descriptor->As<TFunctionTypeInferrer>()) { + std::vector<TTypeSet> argTypes; + std::vector<TExpressionGenerator> operandTypers; + for (const auto& argument : functionExpr->Arguments) { + auto untypedArgument = OnExpression(argument); + argTypes.push_back(untypedArgument.FeasibleTypes); + operandTypers.push_back(untypedArgument.Generator); + } + 
+ std::vector<TTypeSet> genericAssignments; + auto resultTypes = InferFunctionTypes( + regularFunction, + argTypes, + &genericAssignments, + functionName, + functionExpr->GetSource(Source)); + + TExpressionGenerator generator = [ + functionName, + regularFunction, + operandTypers, + genericAssignments, + source = functionExpr->GetSource(Source) + ] (EValueType type) mutable { + auto effectiveTypes = RefineFunctionTypes( + regularFunction, + type, + operandTypers.size(), + &genericAssignments, + source); + + std::vector<TConstExpressionPtr> typedOperands; + for (int index = 0; index < std::ssize(effectiveTypes); ++index) { + typedOperands.push_back(operandTypers[index](effectiveTypes[index])); + } + + return New<TFunctionExpression>(type, functionName, typedOperands); + }; + + return TUntypedExpression{resultTypes, std::move(generator), false}; + } else { + YT_ABORT(); + } +} + +TUntypedExpression TBuilderCtx::OnUnaryOp(const NAst::TUnaryOpExpression* unaryExpr) +{ + if (unaryExpr->Operand.size() != 1) { + THROW_ERROR_EXCEPTION( + "Unary operator %Qv must have exactly one argument", + unaryExpr->Opcode); + } + + auto untypedOperand = OnExpression(unaryExpr->Operand.front()); + + TTypeSet genericAssignments; + auto resultTypes = InferUnaryExprTypes( + unaryExpr->Opcode, + untypedOperand.FeasibleTypes, + &genericAssignments, + unaryExpr->Operand.front()->GetSource(Source)); + + if (untypedOperand.IsConstant) { + auto value = untypedOperand.Generator(untypedOperand.FeasibleTypes.GetFront()); + if (auto foldedExpr = FoldConstants(unaryExpr->Opcode, value)) { + TExpressionGenerator generator = [foldedExpr] (EValueType type) { + return New<TLiteralExpression>( + type, + CastValueWithCheck(*foldedExpr, type)); + }; + return TUntypedExpression{resultTypes, std::move(generator), true}; + } + } + + TExpressionGenerator generator = [ + op = unaryExpr->Opcode, + untypedOperand, + genericAssignments, + opSource = unaryExpr->Operand.front()->GetSource(Source) + ] (EValueType 
type) mutable { + auto argType = RefineUnaryExprTypes( + op, + type, + &genericAssignments, + opSource); + return New<TUnaryOpExpression>(type, op, untypedOperand.Generator(argType)); + }; + return TUntypedExpression{resultTypes, std::move(generator), false}; +} + +TUntypedExpression TBuilderCtx::MakeBinaryExpr( + const NAst::TBinaryOpExpression* binaryExpr, + EBinaryOp op, + TUntypedExpression lhs, + TUntypedExpression rhs, + std::optional<size_t> offset) +{ + TTypeSet genericAssignments; + + auto lhsSource = offset ? binaryExpr->Lhs[*offset]->GetSource(Source) : ""; + auto rhsSource = offset ? binaryExpr->Rhs[*offset]->GetSource(Source) : ""; + + auto resultTypes = InferBinaryExprTypes( + op, + lhs.FeasibleTypes, + rhs.FeasibleTypes, + &genericAssignments, + lhsSource, + rhsSource); + + if (lhs.IsConstant && rhs.IsConstant) { + auto lhsValue = lhs.Generator(lhs.FeasibleTypes.GetFront()); + auto rhsValue = rhs.Generator(rhs.FeasibleTypes.GetFront()); + if (auto foldedExpr = FoldConstants(op, lhsValue, rhsValue)) { + TExpressionGenerator generator = [foldedExpr] (EValueType type) { + return New<TLiteralExpression>( + type, + CastValueWithCheck(*foldedExpr, type)); + }; + return TUntypedExpression{resultTypes, std::move(generator), true}; + } + } + + TExpressionGenerator generator = [ + op, + lhs, + rhs, + genericAssignments, + lhsSource, + rhsSource, + source = binaryExpr->GetSource(Source) + ] (EValueType type) mutable { + auto argTypes = RefineBinaryExprTypes( + op, + type, + lhs.FeasibleTypes, + rhs.FeasibleTypes, + &genericAssignments, + lhsSource, + rhsSource, + source); + + return New<TBinaryOpExpression>( + type, + op, + lhs.Generator(argTypes.first), + rhs.Generator(argTypes.second)); + }; + return TUntypedExpression{resultTypes, std::move(generator), false}; +} + +struct TBinaryOpGenerator +{ + TBuilderCtx& Builder; + const NAst::TBinaryOpExpression* BinaryExpr; + + TUntypedExpression Do(size_t keySize, EBinaryOp op) + { + YT_VERIFY(keySize > 0); + size_t 
offset = keySize - 1; + + auto untypedLhs = Builder.OnExpression(BinaryExpr->Lhs[offset]); + auto untypedRhs = Builder.OnExpression(BinaryExpr->Rhs[offset]); + + auto result = Builder.MakeBinaryExpr(BinaryExpr, op, std::move(untypedLhs), std::move(untypedRhs), offset); + + while (offset > 0) { + --offset; + auto untypedLhs = Builder.OnExpression(BinaryExpr->Lhs[offset]); + auto untypedRhs = Builder.OnExpression(BinaryExpr->Rhs[offset]); + + auto eq = Builder.MakeBinaryExpr( + BinaryExpr, + op == EBinaryOp::NotEqual ? EBinaryOp::Or : EBinaryOp::And, + Builder.MakeBinaryExpr( + BinaryExpr, + op == EBinaryOp::NotEqual ? EBinaryOp::NotEqual : EBinaryOp::Equal, + untypedLhs, + untypedRhs, + offset), + std::move(result), + std::nullopt); + + if (op == EBinaryOp::Equal || op == EBinaryOp::NotEqual) { + result = eq; + continue; + } + + EBinaryOp strongOp = op; + if (op == EBinaryOp::LessOrEqual) { + strongOp = EBinaryOp::Less; + } else if (op == EBinaryOp::GreaterOrEqual) { + strongOp = EBinaryOp::Greater; + } + + result = Builder.MakeBinaryExpr( + BinaryExpr, + EBinaryOp::Or, + Builder.MakeBinaryExpr( + BinaryExpr, + strongOp, + std::move(untypedLhs), + std::move(untypedRhs), + offset), + std::move(eq), + std::nullopt); + } + + return result; + } +}; + +TUntypedExpression TBuilderCtx::OnBinaryOp( + const NAst::TBinaryOpExpression* binaryExpr) +{ + if (IsRelationalBinaryOp(binaryExpr->Opcode)) { + if (binaryExpr->Lhs.size() != binaryExpr->Rhs.size()) { + THROW_ERROR_EXCEPTION("Tuples of same size are expected but got %v vs %v", + binaryExpr->Lhs.size(), + binaryExpr->Rhs.size()) + << TErrorAttribute("source", binaryExpr->GetSource(Source)); + } + + int keySize = binaryExpr->Lhs.size(); + return TBinaryOpGenerator{*this, binaryExpr}.Do(keySize, binaryExpr->Opcode); + } else { + if (binaryExpr->Lhs.size() != 1) { + THROW_ERROR_EXCEPTION("Expecting scalar expression") + << TErrorAttribute("source", FormatExpression(binaryExpr->Lhs)); + } + + if (binaryExpr->Rhs.size() != 1) { 
+ THROW_ERROR_EXCEPTION("Expecting scalar expression") + << TErrorAttribute("source", FormatExpression(binaryExpr->Rhs)); + } + + auto untypedLhs = OnExpression(binaryExpr->Lhs.front()); + auto untypedRhs = OnExpression(binaryExpr->Rhs.front()); + + return MakeBinaryExpr(binaryExpr, binaryExpr->Opcode, std::move(untypedLhs), std::move(untypedRhs), 0); + } +} + +void TBuilderCtx::InferArgumentTypes( + std::vector<TConstExpressionPtr>* typedArguments, + std::vector<EValueType>* argTypes, + const NAst::TExpressionList& expressions, + TStringBuf operatorName, + TStringBuf source) +{ + std::unordered_set<TString> columnNames; + + for (const auto& argument : expressions) { + auto untypedArgument = OnExpression(argument); + + EValueType argType = GetFrontWithCheck(untypedArgument.FeasibleTypes, argument->GetSource(Source)); + auto typedArgument = untypedArgument.Generator(argType); + + typedArguments->push_back(typedArgument); + argTypes->push_back(argType); + if (auto reference = typedArgument->As<TReferenceExpression>()) { + if (!columnNames.insert(reference->ColumnName).second) { + THROW_ERROR_EXCEPTION("%v operator has multiple references to column %Qv", + operatorName, + reference->ColumnName) + << TErrorAttribute("source", source); + } + } + } +} + +TUntypedExpression TBuilderCtx::OnInOp( + const NAst::TInExpression* inExpr) +{ + std::vector<TConstExpressionPtr> typedArguments; + std::vector<EValueType> argTypes; + + auto source = inExpr->GetSource(Source); + + InferArgumentTypes( + &typedArguments, + &argTypes, + inExpr->Expr, + "IN", + inExpr->GetSource(Source)); + + auto capturedRows = LiteralTupleListToRows(inExpr->Values, argTypes, source); + auto result = New<TInExpression>(std::move(typedArguments), std::move(capturedRows)); + + TTypeSet resultTypes({EValueType::Boolean}); + TExpressionGenerator generator = [result] (EValueType /*type*/) mutable { + return result; + }; + return TUntypedExpression{resultTypes, std::move(generator), false}; +} + 
+TUntypedExpression TBuilderCtx::OnBetweenOp( + const NAst::TBetweenExpression* betweenExpr) +{ + std::vector<TConstExpressionPtr> typedArguments; + std::vector<EValueType> argTypes; + + auto source = betweenExpr->GetSource(Source); + + InferArgumentTypes( + &typedArguments, + &argTypes, + betweenExpr->Expr, + "BETWEEN", + source); + + auto capturedRows = LiteralRangesListToRows(betweenExpr->Values, argTypes, source); + auto result = New<TBetweenExpression>(std::move(typedArguments), std::move(capturedRows)); + + TTypeSet resultTypes({EValueType::Boolean}); + TExpressionGenerator generator = [result] (EValueType /*type*/) mutable { + return result; + }; + return TUntypedExpression{resultTypes, std::move(generator), false}; +} + +TUntypedExpression TBuilderCtx::OnTransformOp( + const NAst::TTransformExpression* transformExpr) +{ + std::vector<TConstExpressionPtr> typedArguments; + std::vector<EValueType> argTypes; + + auto source = transformExpr->GetSource(Source); + + InferArgumentTypes( + &typedArguments, + &argTypes, + transformExpr->Expr, + "TRANSFORM", + source); + + if (transformExpr->From.size() != transformExpr->To.size()) { + THROW_ERROR_EXCEPTION("Size mismatch for source and result arrays in TRANSFORM operator") + << TErrorAttribute("source", source); + } + + TTypeSet resultTypes({ + EValueType::Null, + EValueType::Int64, + EValueType::Uint64, + EValueType::Double, + EValueType::Boolean, + EValueType::String, + EValueType::Any}); + + for (const auto& tuple : transformExpr->To) { + if (tuple.size() != 1) { + THROW_ERROR_EXCEPTION("Expecting scalar expression") + << TErrorAttribute("source", source); + } + + auto valueTypes = GetTypes(tuple.front()); + + if (!Unify(&resultTypes, valueTypes)) { + THROW_ERROR_EXCEPTION("Types mismatch in tuple") + << TErrorAttribute("source", source) + << TErrorAttribute("actual_type", ToString(valueTypes)) + << TErrorAttribute("expected_type", ToString(resultTypes)); + } + } + + const auto& defaultExpr = 
transformExpr->DefaultExpr; + + TConstExpressionPtr defaultTypedExpr; + + EValueType resultType; + if (defaultExpr) { + if (defaultExpr->size() != 1) { + THROW_ERROR_EXCEPTION("Default expression must scalar") + << TErrorAttribute("source", source); + } + + auto untypedArgument = OnExpression(defaultExpr->front()); + + if (!Unify(&resultTypes, untypedArgument.FeasibleTypes)) { + THROW_ERROR_EXCEPTION("Type mismatch in default expression: expected %Qlv, got %Qlv", + resultTypes, + untypedArgument.FeasibleTypes) + << TErrorAttribute("source", source); + } + + resultType = GetFrontWithCheck(resultTypes, source); + + defaultTypedExpr = untypedArgument.Generator(resultType); + } else { + resultType = GetFrontWithCheck(resultTypes, source); + } + + auto rowBuffer = New<TRowBuffer>(TQueryPreparerBufferTag()); + TUnversionedRowBuilder rowBuilder; + std::vector<TRow> rows; + + for (int index = 0; index < std::ssize(transformExpr->From); ++index) { + const auto& sourceTuple = transformExpr->From[index]; + if (sourceTuple.size() != argTypes.size()) { + THROW_ERROR_EXCEPTION("Arguments size mismatch in tuple") + << TErrorAttribute("source", source); + } + for (int i = 0; i < std::ssize(sourceTuple); ++i) { + auto valueType = GetType(sourceTuple[i]); + auto value = GetValue(sourceTuple[i]); + + if (valueType == EValueType::Null) { + value = MakeUnversionedSentinelValue(EValueType::Null); + } else if (valueType != argTypes[i]) { + if (IsArithmeticType(valueType) && IsArithmeticType(argTypes[i])) { + value = CastValueWithCheck(value, argTypes[i]); + } else { + THROW_ERROR_EXCEPTION("Types mismatch in tuple") + << TErrorAttribute("source", source) + << TErrorAttribute("actual_type", valueType) + << TErrorAttribute("expected_type", argTypes[i]); + } + } + rowBuilder.AddValue(value); + } + + const auto& resultTuple = transformExpr->To[index]; + + YT_VERIFY(resultTuple.size() == 1); + auto value = CastValueWithCheck(GetValue(resultTuple.front()), resultType); + 
rowBuilder.AddValue(value); + + rows.push_back(rowBuffer->CaptureRow(rowBuilder.GetRow())); + rowBuilder.Reset(); + } + + std::sort(rows.begin(), rows.end(), [argCount = argTypes.size()] (TRow lhs, TRow rhs) { + return CompareRows(lhs, rhs, argCount) < 0; + }); + + auto capturedRows = MakeSharedRange(std::move(rows), std::move(rowBuffer)); + auto result = New<TTransformExpression>( + resultType, + std::move(typedArguments), + std::move(capturedRows), + std::move(defaultTypedExpr)); + + TExpressionGenerator generator = [result] (EValueType /*type*/) mutable { + return result; + }; + return TUntypedExpression{TTypeSet({resultType}), std::move(generator), false}; +} + +//////////////////////////////////////////////////////////////////////////////// + +TConstExpressionPtr BuildPredicate( + const NAst::TExpressionList& expressionAst, + TBuilderCtx& builder, + TStringBuf name) +{ + if (expressionAst.size() != 1) { + THROW_ERROR_EXCEPTION("Expecting scalar expression") + << TErrorAttribute("source", FormatExpression(expressionAst)); + } + + auto typedPredicate = builder.BuildTypedExpression(expressionAst.front()); + + auto actualType = typedPredicate->GetWireType(); + EValueType expectedType(EValueType::Boolean); + if (actualType != expectedType) { + THROW_ERROR_EXCEPTION("%v is not a boolean expression", name) + << TErrorAttribute("source", expressionAst.front()->GetSource(builder.Source)) + << TErrorAttribute("actual_type", actualType) + << TErrorAttribute("expected_type", expectedType); + } + + return typedPredicate; +} + +TGroupClausePtr BuildGroupClause( + const NAst::TExpressionList& expressionsAst, + ETotalsMode totalsMode, + TBuilderCtx& builder) +{ + auto groupClause = New<TGroupClause>(); + groupClause->TotalsMode = totalsMode; + + for (const auto& expressionAst : expressionsAst) { + auto typedExpr = builder.BuildTypedExpression(expressionAst, ComparableTypes); + + groupClause->AddGroupItem(typedExpr, InferColumnName(*expressionAst)); + } + + 
builder.SetGroupData( + &groupClause->GroupItems, + &groupClause->AggregateItems); + + return groupClause; +} + +TConstProjectClausePtr BuildProjectClause( + const NAst::TExpressionList& expressionsAst, + TBuilderCtx& builder) +{ + auto projectClause = New<TProjectClause>(); + for (const auto& expressionAst : expressionsAst) { + auto typedExpr = builder.BuildTypedExpression(expressionAst); + + projectClause->AddProjection(typedExpr, InferColumnName(*expressionAst)); + } + + return projectClause; +} + +void PrepareQuery( + const TQueryPtr& query, + const NAst::TQuery& ast, + TBuilderCtx& builder) +{ + if (ast.WherePredicate) { + auto wherePredicate = BuildPredicate(*ast.WherePredicate, builder, "WHERE-clause"); + query->WhereClause = IsTrue(wherePredicate) ? nullptr : wherePredicate; + } + + if (ast.GroupExprs) { + auto groupClause = BuildGroupClause(ast.GroupExprs->first, ast.GroupExprs->second, builder); + + auto keyColumns = query->GetKeyColumns(); + + TNamedItemList groupItems = std::move(groupClause->GroupItems); + + std::vector<int> touchedKeyColumns(keyColumns.size(), -1); + for (int index = 0; index < std::ssize(groupItems); ++index) { + const auto& item = groupItems[index]; + if (auto referenceExpr = item.Expression->As<TReferenceExpression>()) { + int keyPartIndex = ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName); + if (keyPartIndex >= 0) { + touchedKeyColumns[keyPartIndex] = index; + } + } + } + + size_t keyPrefix = 0; + for (; keyPrefix < touchedKeyColumns.size(); ++keyPrefix) { + if (touchedKeyColumns[keyPrefix] >= 0) { + continue; + } + + const auto& expression = query->Schema.Original->Columns()[keyPrefix].Expression(); + + if (!expression) { + break; + } + + // Call PrepareExpression to extract references only. 
+ THashSet<TString> references; + PrepareExpression(*expression, *query->Schema.Original, builder.Functions, &references); + + auto canEvaluate = true; + for (const auto& reference : references) { + int referenceIndex = query->Schema.Original->GetColumnIndexOrThrow(reference); + if (touchedKeyColumns[referenceIndex] < 0) { + canEvaluate = false; + } + } + + if (!canEvaluate) { + break; + } + } + + touchedKeyColumns.resize(keyPrefix); + for (int index : touchedKeyColumns) { + if (index >= 0) { + groupClause->GroupItems.push_back(std::move(groupItems[index])); + } + } + + groupClause->CommonPrefixWithPrimaryKey = groupClause->GroupItems.size(); + + for (auto& item : groupItems) { + if (item.Expression) { + groupClause->GroupItems.push_back(std::move(item)); + } + } + + query->GroupClause = groupClause; + + // not prefix, because of equal prefixes near borders + bool containsPrimaryKey = keyPrefix == query->GetKeyColumns().size(); + // COMPAT(lukyan) + query->UseDisjointGroupBy = containsPrimaryKey; + } + + if (ast.HavingPredicate) { + if (!query->GroupClause) { + THROW_ERROR_EXCEPTION("Expected GROUP BY before HAVING"); + } + query->HavingClause = BuildPredicate( + *ast.HavingPredicate, + builder, + "HAVING-clause"); + } + + if (!ast.OrderExpressions.empty()) { + auto orderClause = New<TOrderClause>(); + + for (const auto& orderExpr : ast.OrderExpressions) { + for (const auto& expressionAst : orderExpr.first) { + auto typedExpr = builder.BuildTypedExpression( + expressionAst, + ComparableTypes); + + orderClause->OrderItems.push_back({typedExpr, orderExpr.second}); + } + } + + ssize_t keyPrefix = 0; + while (keyPrefix < std::ssize(orderClause->OrderItems)) { + const auto& item = orderClause->OrderItems[keyPrefix]; + + if (item.Descending) { + break; + } + + const auto* referenceExpr = item.Expression->As<TReferenceExpression>(); + + if (!referenceExpr) { + break; + } + + auto columnIndex = ColumnNameToKeyPartIndex(query->GetKeyColumns(), referenceExpr->ColumnName); + 
+ if (keyPrefix != columnIndex) { + break; + } + ++keyPrefix; + } + + if (keyPrefix < std::ssize(orderClause->OrderItems)) { + query->OrderClause = std::move(orderClause); + } + + // Use ordered scan otherwise + } + + if (ast.SelectExprs) { + query->ProjectClause = BuildProjectClause( + *ast.SelectExprs, + builder); + } else { + // Select all columns. + builder.PopulateAllColumns(); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +class TYsonToQueryExpressionConvertVisitor + : public TYsonConsumerBase +{ +public: + explicit TYsonToQueryExpressionConvertVisitor(TStringBuilder* builder) + : Builder_(builder) + { } + + void OnStringScalar(TStringBuf value) override + { + Builder_->AppendChar('"'); + Builder_->AppendString(EscapeC(value)); + Builder_->AppendChar('"'); + } + + void OnInt64Scalar(i64 value) override + { + Builder_->AppendFormat("%v", value); + } + + void OnUint64Scalar(ui64 value) override + { + Builder_->AppendFormat("%vu", value); + } + + void OnDoubleScalar(double value) override + { + Builder_->AppendFormat("%lf", value); + } + + void OnBooleanScalar(bool value) override + { + Builder_->AppendFormat("%lv", value); + } + + void OnEntity() override + { + Builder_->AppendString("null"); + } + + void OnBeginList() override + { + Builder_->AppendChar('('); + InListBeginning_ = true; + } + + void OnListItem() override + { + if (!InListBeginning_) { + Builder_->AppendString(", "); + } + InListBeginning_ = false; + } + + void OnEndList() override + { + Builder_->AppendChar(')'); + } + + void OnBeginMap() override + { + THROW_ERROR_EXCEPTION("Maps inside YSON placeholder are not allowed"); + } + + void OnKeyedItem(TStringBuf) override + { + THROW_ERROR_EXCEPTION("Maps inside YSON placeholder are not allowed"); + } + + void OnEndMap() override + { + THROW_ERROR_EXCEPTION("Maps inside YSON placeholder are not allowed"); + } + + void OnBeginAttributes() override + { + THROW_ERROR_EXCEPTION("Attributes inside YSON 
placeholder are not allowed"); + } + + void OnEndAttributes() override + { + THROW_ERROR_EXCEPTION("Attributes inside YSON placeholder are not allowed"); + } + +private: + TStringBuilder* Builder_; + bool InListBeginning_; +}; + +void YsonParseError(TStringBuf message, TYsonStringBuf source) +{ + THROW_ERROR_EXCEPTION("%v", message) + << TErrorAttribute("context", Format("%v", source.AsStringBuf())); +} + +THashMap<TString, TString> ConvertYsonPlaceholdersToQueryLiterals(TYsonStringBuf placeholders) +{ + TMemoryInput input{placeholders.AsStringBuf()}; + TYsonPullParser ysonParser{&input, EYsonType::Node}; + TYsonPullParserCursor ysonCursor{&ysonParser}; + + if (ysonCursor->GetType() != EYsonItemType::BeginMap) { + YsonParseError("Incorrect placeholder argument: YSON map expected", placeholders); + } + + ysonCursor.Next(); + + THashMap<TString, TString> queryLiterals; + while (ysonCursor->GetType() != EYsonItemType::EndMap) { + if (ysonCursor->GetType() != EYsonItemType::StringValue) { + YsonParseError("Incorrect YSON map placeholder: keys should be strings", placeholders); + } + auto key = TString(ysonCursor->UncheckedAsString()); + + ysonCursor.Next(); + switch (ysonCursor->GetType()) { + case EYsonItemType::EntityValue: + case EYsonItemType::BooleanValue: + case EYsonItemType::Int64Value: + case EYsonItemType::Uint64Value: + case EYsonItemType::DoubleValue: + case EYsonItemType::StringValue: + case EYsonItemType::BeginList: { + TStringBuilder valueBuilder; + TYsonToQueryExpressionConvertVisitor ysonValueTransferrer{&valueBuilder}; + ysonCursor.TransferComplexValue(&ysonValueTransferrer); + queryLiterals.emplace(std::move(key), valueBuilder.Flush()); + break; + } + default: + YsonParseError("Incorrect placeholder map: values should be plain types or lists", placeholders); + } + } + + return queryLiterals; +} + +void ParseQueryString( + NAst::TAstHead* astHead, + const TString& source, + NAst::TParser::token::yytokentype strayToken, + TYsonStringBuf 
placeholderValues = {}) +{ + THashMap<TString, TString> queryLiterals; + if (placeholderValues) { + queryLiterals = ConvertYsonPlaceholdersToQueryLiterals(placeholderValues); + } + + NAst::TLexer lexer(source, strayToken, std::move(queryLiterals)); + NAst::TParser parser(lexer, astHead, source); + + int result = parser.parse(); + + if (result != 0) { + THROW_ERROR_EXCEPTION("Parse failure") + << TErrorAttribute("source", source); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +NAst::TParser::token::yytokentype GetStrayToken(EParseMode mode) +{ + switch (mode) { + case EParseMode::Query: return NAst::TParser::token::StrayWillParseQuery; + case EParseMode::JobQuery: return NAst::TParser::token::StrayWillParseJobQuery; + case EParseMode::Expression: return NAst::TParser::token::StrayWillParseExpression; + default: YT_ABORT(); + } +} + +NAst::TAstHead MakeAstHead(EParseMode mode) +{ + switch (mode) { + case EParseMode::Query: + case EParseMode::JobQuery: return NAst::TAstHead::MakeQuery(); + case EParseMode::Expression: return NAst::TAstHead::MakeExpression(); + default: YT_ABORT(); + } +} + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +void DefaultFetchFunctions(const std::vector<TString>& /*names*/, const TTypeInferrerMapPtr& typeInferrers) +{ + MergeFrom(typeInferrers.Get(), *GetBuiltinTypeInferrers()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TParsedSource::TParsedSource(const TString& source, NAst::TAstHead astHead) + : Source(source) + , AstHead(std::move(astHead)) +{ } + +std::unique_ptr<TParsedSource> ParseSource( + const TString& source, + EParseMode mode, + TYsonStringBuf placeholderValues) +{ + auto parsedSource = std::make_unique<TParsedSource>( + source, + MakeAstHead(mode)); + ParseQueryString( + &parsedSource->AstHead, + source, + GetStrayToken(mode), + placeholderValues); + return parsedSource; +} + 
+//////////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<TPlanFragment> PreparePlanFragment( + IPrepareCallbacks* callbacks, + const TString& source, + const TFunctionsFetcher& functionsFetcher, + TYsonStringBuf placeholderValues) +{ + return PreparePlanFragment( + callbacks, + *ParseSource(source, EParseMode::Query, placeholderValues), + functionsFetcher); +} + +std::unique_ptr<TPlanFragment> PreparePlanFragment( + IPrepareCallbacks* callbacks, + const TParsedSource& parsedSource, + const TFunctionsFetcher& functionsFetcher) +{ + auto query = New<TQuery>(TGuid::Create()); + + auto Logger = MakeQueryLogger(query); + + const auto& ast = std::get<NAst::TQuery>(parsedSource.AstHead.Ast); + const auto& aliasMap = parsedSource.AstHead.AliasMap; + + auto functionNames = ExtractFunctionNames(ast, aliasMap); + + auto functions = New<TTypeInferrerMap>(); + functionsFetcher(functionNames, functions); + + const auto& table = ast.Table; + + YT_LOG_DEBUG("Getting initial data splits (PrimaryPath: %v, ForeignPaths: %v)", + table.Path, + MakeFormattableView(ast.Joins, [] (TStringBuilderBase* builder, const auto& join) { + FormatValue(builder, join.Table.Path, TStringBuf()); + })); + + std::vector<TFuture<TDataSplit>> asyncDataSplits; + asyncDataSplits.push_back(callbacks->GetInitialSplit(table.Path)); + for (const auto& join : ast.Joins) { + asyncDataSplits.push_back(callbacks->GetInitialSplit(join.Table.Path)); + } + + auto dataSplits = WaitFor(AllSucceeded(asyncDataSplits)) + .ValueOrThrow(); + + YT_LOG_DEBUG("Initial data splits received"); + + const auto& selfDataSplit = dataSplits[0]; + + auto tableSchema = selfDataSplit.TableSchema; + query->Schema.Original = tableSchema; + + TBuilderCtx builder{ + parsedSource.Source, + functions, + aliasMap, + *query->Schema.Original, + table.Alias, + &query->Schema.Mapping}; + + size_t commonKeyPrefix = std::numeric_limits<size_t>::max(); + + std::vector<TJoinClausePtr> joinClauses; + for 
(size_t joinIndex = 0; joinIndex < ast.Joins.size(); ++joinIndex) { + const auto& join = ast.Joins[joinIndex]; + const auto& foreignDataSplit = dataSplits[joinIndex + 1]; + + auto foreignTableSchema = foreignDataSplit.TableSchema; + auto foreignKeyColumnsCount = foreignTableSchema->GetKeyColumns().size(); + + auto joinClause = New<TJoinClause>(); + joinClause->Schema.Original = foreignTableSchema; + joinClause->ForeignObjectId = foreignDataSplit.ObjectId; + joinClause->ForeignCellId = foreignDataSplit.CellId; + joinClause->IsLeft = join.IsLeft; + + // BuildPredicate and BuildTypedExpression are used with foreignBuilder. + TBuilderCtx foreignBuilder{ + parsedSource.Source, + functions, + aliasMap, + *joinClause->Schema.Original, + join.Table.Alias, + &joinClause->Schema.Mapping}; + + std::vector<TSelfEquation> selfEquations; + std::vector<TConstExpressionPtr> foreignEquations; + // Merge columns. + for (const auto& referenceExpr : join.Fields) { + auto selfColumn = builder.GetColumnPtr(referenceExpr->Reference); + auto foreignColumn = foreignBuilder.GetColumnPtr(referenceExpr->Reference); + + if (!selfColumn || !foreignColumn) { + THROW_ERROR_EXCEPTION("Column %Qv not found", + NAst::InferColumnName(referenceExpr->Reference)); + } + + if (!NTableClient::IsV1Type(selfColumn->LogicalType) || !NTableClient::IsV1Type(foreignColumn->LogicalType)) { + THROW_ERROR_EXCEPTION("Cannot join column %Qv of nonsimple type", + NAst::InferColumnName(referenceExpr->Reference)) + << TErrorAttribute("self_type", selfColumn->LogicalType) + << TErrorAttribute("foreign_type", foreignColumn->LogicalType); + } + + // N.B. When we try join optional<int32> and int16 columns it must work. 
+ if (NTableClient::GetWireType(selfColumn->LogicalType) != NTableClient::GetWireType(foreignColumn->LogicalType)) { + THROW_ERROR_EXCEPTION("Column %Qv type mismatch in join", + NAst::InferColumnName(referenceExpr->Reference)) + << TErrorAttribute("self_type", selfColumn->LogicalType) + << TErrorAttribute("foreign_type", foreignColumn->LogicalType); + } + + selfEquations.push_back({New<TReferenceExpression>(selfColumn->LogicalType, selfColumn->Name), false}); + foreignEquations.push_back(New<TReferenceExpression>(foreignColumn->LogicalType, foreignColumn->Name)); + } + + for (const auto& argument : join.Lhs) { + selfEquations.push_back({builder.BuildTypedExpression(argument, ComparableTypes), false}); + } + + for (const auto& argument : join.Rhs) { + foreignEquations.push_back( + foreignBuilder.BuildTypedExpression(argument, ComparableTypes)); + } + + if (selfEquations.size() != foreignEquations.size()) { + THROW_ERROR_EXCEPTION("Tuples of same size are expected but got %v vs %v", + selfEquations.size(), + foreignEquations.size()) + << TErrorAttribute("lhs_source", FormatExpression(join.Lhs)) + << TErrorAttribute("rhs_source", FormatExpression(join.Rhs)); + } + + for (int index = 0; index < std::ssize(selfEquations); ++index) { + if (*selfEquations[index].Expression->LogicalType != *foreignEquations[index]->LogicalType) { + THROW_ERROR_EXCEPTION("Types mismatch in join equation \"%v = %v\"", + InferName(selfEquations[index].Expression), + InferName(foreignEquations[index])) + << TErrorAttribute("self_type", selfEquations[index].Expression->LogicalType) + << TErrorAttribute("foreign_type", foreignEquations[index]->LogicalType); + } + } + + // If can use ranges, rearrange equations according to key columns and enrich with evaluated columns + + std::vector<TSelfEquation> keySelfEquations(foreignKeyColumnsCount); + std::vector<TConstExpressionPtr> keyForeignEquations(foreignKeyColumnsCount); + + for (size_t equationIndex = 0; equationIndex < foreignEquations.size(); 
++equationIndex) { + const auto& expr = foreignEquations[equationIndex]; + + if (const auto* referenceExpr = expr->As<TReferenceExpression>()) { + int index = ColumnNameToKeyPartIndex(joinClause->GetKeyColumns(), referenceExpr->ColumnName); + + if (index >= 0) { + keySelfEquations[index] = selfEquations[equationIndex]; + keyForeignEquations[index] = foreignEquations[equationIndex]; + continue; + } + } + + keySelfEquations.push_back(selfEquations[equationIndex]); + keyForeignEquations.push_back(foreignEquations[equationIndex]); + } + + size_t keyPrefix = 0; + for (; keyPrefix < foreignKeyColumnsCount; ++keyPrefix) { + if (keyForeignEquations[keyPrefix]) { + YT_VERIFY(keySelfEquations[keyPrefix].Expression); + + if (const auto* referenceExpr = keySelfEquations[keyPrefix].Expression->As<TReferenceExpression>()) { + if (ColumnNameToKeyPartIndex(query->GetKeyColumns(), referenceExpr->ColumnName) != static_cast<ssize_t>(keyPrefix)) { + commonKeyPrefix = std::min(commonKeyPrefix, keyPrefix); + } + } else { + commonKeyPrefix = std::min(commonKeyPrefix, keyPrefix); + } + + continue; + } + + const auto& foreignColumnExpression = foreignTableSchema->Columns()[keyPrefix].Expression(); + + if (!foreignColumnExpression) { + break; + } + + THashSet<TString> references; + auto evaluatedColumnExpression = PrepareExpression( + *foreignColumnExpression, + *foreignTableSchema, + functions, + &references); + + auto canEvaluate = true; + for (const auto& reference : references) { + int referenceIndex = foreignTableSchema->GetColumnIndexOrThrow(reference); + if (!keySelfEquations[referenceIndex].Expression) { + YT_VERIFY(!keyForeignEquations[referenceIndex]); + canEvaluate = false; + } + } + + if (!canEvaluate) { + break; + } + + keySelfEquations[keyPrefix] = {evaluatedColumnExpression, true}; + + auto reference = NAst::TReference( + foreignTableSchema->Columns()[keyPrefix].Name(), + join.Table.Alias); + + auto foreignColumn = foreignBuilder.GetColumnPtr(reference); + + 
keyForeignEquations[keyPrefix] = New<TReferenceExpression>( + foreignColumn->LogicalType, + foreignColumn->Name); + } + + commonKeyPrefix = std::min(commonKeyPrefix, keyPrefix); + + for (size_t index = 0; index < keyPrefix; ++index) { + if (keySelfEquations[index].Evaluated) { + const auto& evaluatedColumnExpression = keySelfEquations[index].Expression; + + if (const auto& selfColumnExpression = tableSchema->Columns()[index].Expression()) { + auto evaluatedSelfColumnExpression = PrepareExpression( + *selfColumnExpression, + *tableSchema, + functions); + + if (!Compare( + evaluatedColumnExpression, + *foreignTableSchema, + evaluatedSelfColumnExpression, + *tableSchema, + commonKeyPrefix)) + { + commonKeyPrefix = std::min(commonKeyPrefix, index); + } + } else { + commonKeyPrefix = std::min(commonKeyPrefix, index); + } + } + } + + YT_VERIFY(keyForeignEquations.size() == keySelfEquations.size()); + + size_t lastEmptyIndex = keyPrefix; + for (int index = keyPrefix; index < std::ssize(keyForeignEquations); ++index) { + if (keyForeignEquations[index]) { + YT_VERIFY(keySelfEquations[index].Expression); + keyForeignEquations[lastEmptyIndex] = std::move(keyForeignEquations[index]); + keySelfEquations[lastEmptyIndex] = std::move(keySelfEquations[index]); + ++lastEmptyIndex; + } + } + + keyForeignEquations.resize(lastEmptyIndex); + keySelfEquations.resize(lastEmptyIndex); + + joinClause->SelfEquations = std::move(keySelfEquations); + joinClause->ForeignEquations = std::move(keyForeignEquations); + joinClause->ForeignKeyPrefix = keyPrefix; + joinClause->CommonKeyPrefix = commonKeyPrefix; + + YT_LOG_DEBUG("Creating join (CommonKeyPrefix: %v, ForeignKeyPrefix: %v)", + commonKeyPrefix, + keyPrefix); + + if (join.Predicate) { + joinClause->Predicate = BuildPredicate( + *join.Predicate, + foreignBuilder, + "JOIN-PREDICATE-clause"); + } + + builder.Merge(foreignBuilder); + + joinClauses.push_back(std::move(joinClause)); + } + + PrepareQuery(query, ast, builder); + + // Must be filled 
after builder.Finish() + for (const auto& [reference, entry] : builder.Lookup) { + auto formattedName = NAst::InferColumnName(reference); + + for (size_t index = entry.OriginTableIndex; index < entry.LastTableIndex; ++index) { + YT_VERIFY(index < joinClauses.size()); + joinClauses[index]->SelfJoinedColumns.push_back(formattedName); + } + + if (entry.OriginTableIndex > 0 && entry.LastTableIndex > 0) { + joinClauses[entry.OriginTableIndex - 1]->ForeignJoinedColumns.push_back(formattedName); + } + } + + // Why after PrepareQuery? GetTableSchema is called inside PrepareQuery? + query->JoinClauses.assign(joinClauses.begin(), joinClauses.end()); + + if (ast.Limit) { + if (*ast.Limit > MaxQueryLimit) { + THROW_ERROR_EXCEPTION("Maximum LIMIT exceeded") + << TErrorAttribute("limit", *ast.Limit) + << TErrorAttribute("max_limit", MaxQueryLimit); + } + + query->Limit = *ast.Limit; + + if (!query->OrderClause && query->HavingClause) { + THROW_ERROR_EXCEPTION("HAVING with LIMIT is not allowed"); + } + } else if (!ast.OrderExpressions.empty()) { + THROW_ERROR_EXCEPTION("ORDER BY used without LIMIT"); + } + + if (ast.Offset) { + if (!query->OrderClause && query->HavingClause) { + THROW_ERROR_EXCEPTION("HAVING with OFFSET is not allowed"); + } + + query->Offset = *ast.Offset; + + if (!ast.Limit) { + THROW_ERROR_EXCEPTION("OFFSET used without LIMIT"); + } + } + + auto queryFingerprint = InferName(query, {.OmitValues = true}); + YT_LOG_DEBUG("Prepared query (Fingerprint: %v, ReadSchema: %v, ResultSchema: %v)", + queryFingerprint, + *query->GetReadSchema(), + *query->GetTableSchema()); + + auto fragment = std::make_unique<TPlanFragment>(); + fragment->Query = query; + fragment->DataSource.ObjectId = selfDataSplit.ObjectId; + fragment->DataSource.CellId = selfDataSplit.CellId; + fragment->DataSource.Ranges = MakeSingletonRowRange(selfDataSplit.LowerBound, selfDataSplit.UpperBound); + + return fragment; +} + +TQueryPtr PrepareJobQuery( + const TString& source, + const TTableSchemaPtr& 
tableSchema, + const TFunctionsFetcher& functionsFetcher) +{ + auto astHead = NAst::TAstHead::MakeQuery(); + ParseQueryString( + &astHead, + source, + NAst::TParser::token::StrayWillParseJobQuery); + + const auto& ast = std::get<NAst::TQuery>(astHead.Ast); + const auto& aliasMap = astHead.AliasMap; + + if (ast.Offset) { + THROW_ERROR_EXCEPTION("OFFSET is not supported in map-reduce queries"); + } + + if (ast.Limit) { + THROW_ERROR_EXCEPTION("LIMIT is not supported in map-reduce queries"); + } + + if (ast.GroupExprs) { + THROW_ERROR_EXCEPTION("GROUP BY is not supported in map-reduce queries"); + } + + auto query = New<TQuery>(TGuid::Create()); + query->Schema.Original = tableSchema; + + auto functionNames = ExtractFunctionNames(ast, aliasMap); + + auto functions = New<TTypeInferrerMap>(); + functionsFetcher(functionNames, functions); + + TBuilderCtx builder{ + source, + functions, + aliasMap, + *tableSchema, + std::nullopt, + &query->Schema.Mapping}; + + PrepareQuery( + query, + ast, + builder); + + return query; +} + +TConstExpressionPtr PrepareExpression( + const TString& source, + const TTableSchema& tableSchema, + const TConstTypeInferrerMapPtr& functions, + THashSet<TString>* references) +{ + return PrepareExpression( + *ParseSource(source, EParseMode::Expression), + tableSchema, + functions, + references); +} + +TConstExpressionPtr PrepareExpression( + const TParsedSource& parsedSource, + const TTableSchema& tableSchema, + const TConstTypeInferrerMapPtr& functions, + THashSet<TString>* references) +{ + auto expr = std::get<NAst::TExpressionPtr>(parsedSource.AstHead.Ast); + const auto& aliasMap = parsedSource.AstHead.AliasMap; + + std::vector<TColumnDescriptor> mapping; + + TBuilderCtx builder{ + parsedSource.Source, + functions, + aliasMap, + tableSchema, + std::nullopt, + &mapping}; + + auto result = builder.BuildTypedExpression(expr); + + if (references) { + for (const auto& item : mapping) { + references->insert(item.Name); + } + } + + return result; +} + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/query_preparer.h b/yt/yt/library/query/base/query_preparer.h new file mode 100644 index 0000000000..4604e9358a --- /dev/null +++ b/yt/yt/library/query/base/query_preparer.h @@ -0,0 +1,82 @@ +#pragma once + +#include "query.h" +#include "ast.h" +#include "callbacks.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +using TFunctionsFetcher = std::function<void( + const std::vector<TString>& names, + const TTypeInferrerMapPtr& typeInferrers)>; + +void DefaultFetchFunctions( + const std::vector<TString>& names, + const TTypeInferrerMapPtr& typeInferrers); + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(EParseMode, + (Query) + (JobQuery) + (Expression) +); + +struct TParsedSource +{ + TParsedSource( + const TString& source, + NAst::TAstHead astHead); + + TString Source; + NAst::TAstHead AstHead; +}; + +std::unique_ptr<TParsedSource> ParseSource( + const TString& source, + EParseMode mode, + NYson::TYsonStringBuf placeholderValues = {}); + +//////////////////////////////////////////////////////////////////////////////// + +struct TPlanFragment +{ + TQueryPtr Query; + TDataSource DataSource; +}; + +std::unique_ptr<TPlanFragment> PreparePlanFragment( + IPrepareCallbacks* callbacks, + const TString& source, + const TFunctionsFetcher& functionsFetcher = DefaultFetchFunctions, + NYson::TYsonStringBuf placeholderValues = {}); + +std::unique_ptr<TPlanFragment> PreparePlanFragment( + IPrepareCallbacks* callbacks, + const TParsedSource& parsedSource, + const TFunctionsFetcher& functionsFetcher = DefaultFetchFunctions); + +//////////////////////////////////////////////////////////////////////////////// + +TQueryPtr PrepareJobQuery( + const TString& source, + const TTableSchemaPtr& tableSchema, + const 
TFunctionsFetcher& functionsFetcher); + +TConstExpressionPtr PrepareExpression( + const TString& source, + const TTableSchema& tableSchema, + const TConstTypeInferrerMapPtr& functions = GetBuiltinTypeInferrers(), + THashSet<TString>* references = nullptr); + +TConstExpressionPtr PrepareExpression( + const TParsedSource& parsedSource, + const TTableSchema& tableSchema, + const TConstTypeInferrerMapPtr& functions = GetBuiltinTypeInferrers(), + THashSet<TString>* references = nullptr); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/base/ya.make b/yt/yt/library/query/base/ya.make new file mode 100644 index 0000000000..8c1e8213fb --- /dev/null +++ b/yt/yt/library/query/base/ya.make @@ -0,0 +1,32 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +PROTO_NAMESPACE(yt) + +SRCS( + ast.cpp + constraints.cpp + coordination_helpers.cpp + functions.cpp + builtin_function_registry.cpp + builtin_function_types.cpp + functions_common.cpp + key_trie.cpp + lexer.rl6 + parser.ypp + public.cpp + query.cpp + query_common.cpp + query_helpers.cpp + query_preparer.cpp +) + +PEERDIR( + yt/yt/core + yt/yt/client + yt/yt/library/query/misc + yt/yt/library/query/proto +) + +END() diff --git a/yt/yt/library/query/engine_api/append_function_implementation.cpp b/yt/yt/library/query/engine_api/append_function_implementation.cpp new file mode 100644 index 0000000000..fa0ee8acf1 --- /dev/null +++ b/yt/yt/library/query/engine_api/append_function_implementation.cpp @@ -0,0 +1,28 @@ +#include "append_function_implementation.h" + +#include <yt/yt/library/query/base/query_helpers.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK void AppendFunctionImplementation( + const TFunctionProfilerMapPtr& /*functionProfilers*/, + const TAggregateProfilerMapPtr& /*aggregateProfilers*/, + bool /*functionIsAggregate*/, + 
const TString& /*functionName*/, + const TString& /*functionSymbolName*/, + ECallingConvention /*functionCallingConvention*/, + TSharedRef /*functionChunkSpecsFingerprint*/, + TType /*functionRepeatedArgType*/, + int /*functionRepeatedArgIndex*/, + bool /*functionUseFunctionContext*/, + const TSharedRef& /*functionImpl*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/append_function_implementation.cpp + YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/append_function_implementation.h b/yt/yt/library/query/engine_api/append_function_implementation.h new file mode 100644 index 0000000000..1a8f7c2494 --- /dev/null +++ b/yt/yt/library/query/engine_api/append_function_implementation.h @@ -0,0 +1,26 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/query/base/functions_common.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +void AppendFunctionImplementation( + const TFunctionProfilerMapPtr& functionProfilers, + const TAggregateProfilerMapPtr& aggregateProfilers, + bool functionIsAggregate, + const TString& functionName, + const TString& functionSymbolName, + ECallingConvention functionCallingConvention, + TSharedRef functionChunkSpecsFingerprint, + TType functionRepeatedArgType, + int functionRepeatedArgIndex, + bool functionUseFunctionContext, + const TSharedRef& functionImpl); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/builtin_function_profiler.cpp b/yt/yt/library/query/engine_api/builtin_function_profiler.cpp new file mode 100644 index 0000000000..5b30871c55 --- /dev/null +++ b/yt/yt/library/query/engine_api/builtin_function_profiler.cpp @@ -0,0 +1,64 @@ +#include "builtin_function_profiler.h" + +#include 
"public.h" + +#include <yt/yt/core/misc/error.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +const IFunctionCodegenPtr& TFunctionProfilerMap::GetFunction(const TString& functionName) const +{ + auto found = this->find(functionName); + if (found == this->end()) { + THROW_ERROR_EXCEPTION("Code generator not found for regular function %Qv", + functionName); + } + return found->second; +} + +const IAggregateCodegenPtr& TAggregateProfilerMap::GetAggregate(const TString& functionName) const +{ + auto found = this->find(functionName); + if (found == this->end()) { + THROW_ERROR_EXCEPTION("Code generator not found for aggregate function %Qv", + functionName); + } + return found->second; +} + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK const TConstFunctionProfilerMapPtr GetBuiltinFunctionProfilers() +{ + // Proper implementation resides in yt/yt/library/query/engine/builtin_function_profiler.cpp. + YT_ABORT(); +} + +Y_WEAK const TConstAggregateProfilerMapPtr GetBuiltinAggregateProfilers() +{ + // Proper implementation resides in yt/yt/library/query/engine/builtin_function_profiler.cpp. + YT_ABORT(); +} + +Y_WEAK const TConstRangeExtractorMapPtr GetBuiltinRangeExtractors() +{ + // Proper implementation resides in yt/yt/library/query/engine/builtin_function_profiler.cpp. + YT_ABORT(); +} + +Y_WEAK const TConstConstraintExtractorMapPtr GetBuiltinConstraintExtractors() +{ + // Proper implementation resides in yt/yt/library/query/engine/builtin_function_profiler.cpp. 
+ YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_WEAK_REFCOUNTED_TYPE(IFunctionCodegen) +DEFINE_WEAK_REFCOUNTED_TYPE(IAggregateCodegen) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/builtin_function_profiler.h b/yt/yt/library/query/engine_api/builtin_function_profiler.h new file mode 100644 index 0000000000..ebe1861a53 --- /dev/null +++ b/yt/yt/library/query/engine_api/builtin_function_profiler.h @@ -0,0 +1,39 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/query/base/builtin_function_registry.h> +#include <yt/yt/library/query/base/functions.h> +#include <yt/yt/library/query/base/functions_builder.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(TFunctionProfilerMap) + +struct TFunctionProfilerMap + : public TRefCounted + , public std::unordered_map<TString, IFunctionCodegenPtr> +{ + const IFunctionCodegenPtr& GetFunction(const TString& functionName) const; +}; + +DEFINE_REFCOUNTED_TYPE(TFunctionProfilerMap) + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(TFunctionProfilerMap) + +struct TAggregateProfilerMap + : public TRefCounted + , public std::unordered_map<TString, IAggregateCodegenPtr> +{ + const IAggregateCodegenPtr& GetAggregate(const TString& functionName) const; +}; + +DEFINE_REFCOUNTED_TYPE(TAggregateProfilerMap) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/column_evaluator-inl.h b/yt/yt/library/query/engine_api/column_evaluator-inl.h new file mode 100644 index 0000000000..e8917015ae --- /dev/null +++ b/yt/yt/library/query/engine_api/column_evaluator-inl.h @@ -0,0 +1,18 @@ +#ifndef 
COLUMN_EVALUATOR_INL_H_ +#error "Direct inclusion of this file is not allowed, include column_evaluator.h" +// For the sake of sane code completion. +#include "column_evaluator.h" +#endif + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +Y_FORCE_INLINE bool TColumnEvaluator::IsAggregate(int index) const +{ + return IsAggregate_[index]; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/column_evaluator.cpp b/yt/yt/library/query/engine_api/column_evaluator.cpp new file mode 100644 index 0000000000..9c7b6c8b29 --- /dev/null +++ b/yt/yt/library/query/engine_api/column_evaluator.cpp @@ -0,0 +1,138 @@ +#include "column_evaluator.h" + +#include <yt/yt/client/table_client/row_buffer.h> + +namespace NYT::NQueryClient { + +using NTableClient::TMutableVersionedRow; +using NTableClient::TMutableUnversionedRow; +using NTableClient::TUnversionedValue; + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK TColumnEvaluatorPtr TColumnEvaluator::Create( + const TTableSchemaPtr& /*schema*/, + const TConstTypeInferrerMapPtr& /*typeInferrers*/, + const TConstFunctionProfilerMapPtr& /*profilers*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/column_evaluator.cpp. 
+ YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TColumnEvaluator::TColumnEvaluator( + std::vector<TColumn> columns, + std::vector<bool> isAggregate) + : Columns_(std::move(columns)) + , IsAggregate_(std::move(isAggregate)) +{ } + +void TColumnEvaluator::EvaluateKey(TMutableRow fullRow, const TRowBufferPtr& buffer, int index) const +{ + YT_VERIFY(index < static_cast<int>(fullRow.GetCount())); + YT_VERIFY(index < std::ssize(Columns_)); + + const auto& column = Columns_[index]; + const auto& evaluator = column.Evaluator; + YT_VERIFY(evaluator); + + // Zero row to avoid garbage after evaluator. + fullRow[index] = MakeUnversionedSentinelValue(EValueType::Null); + + evaluator( + column.Variables.GetLiteralValues(), + column.Variables.GetOpaqueData(), + &fullRow[index], + fullRow.Elements(), + buffer.Get()); + + fullRow[index].Id = index; +} + +void TColumnEvaluator::EvaluateKeys(TMutableRow fullRow, const TRowBufferPtr& buffer) const +{ + for (int index = 0; index < std::ssize(Columns_); ++index) { + if (Columns_[index].Evaluator) { + EvaluateKey(fullRow, buffer, index); + } + } +} + +void TColumnEvaluator::EvaluateKeys( + TMutableVersionedRow fullRow, + const TRowBufferPtr& buffer) const +{ + auto row = buffer->CaptureRow(fullRow.Keys(), /*captureValues*/ false); + EvaluateKeys(row, buffer); + + for (int index = 0; index < fullRow.GetKeyCount(); ++index) { + if (Columns_[index].Evaluator) { + fullRow.Keys()[index] = row[index]; + } + } +} + +const std::vector<int>& TColumnEvaluator::GetReferenceIds(int index) const +{ + return Columns_[index].ReferenceIds; +} + +TConstExpressionPtr TColumnEvaluator::GetExpression(int index) const +{ + return Columns_[index].Expression; +} + +void TColumnEvaluator::InitAggregate( + int index, + TUnversionedValue* state, + const TRowBufferPtr& buffer) const +{ + Columns_[index].Aggregate.Init(buffer.Get(), state); + state->Id = index; +} + +void TColumnEvaluator::UpdateAggregate( + 
int index, + TUnversionedValue* state, + const TRange<TUnversionedValue> update, + const TRowBufferPtr& buffer) const +{ + Columns_[index].Aggregate.Update(buffer.Get(), state, update); + state->Id = index; +} + +void TColumnEvaluator::MergeAggregate( + int index, + TUnversionedValue* state, + const TUnversionedValue& mergeeState, + const TRowBufferPtr& buffer) const +{ + Columns_[index].Aggregate.Merge(buffer.Get(), state, &mergeeState); + state->Id = index; +} + +void TColumnEvaluator::FinalizeAggregate( + int index, + TUnversionedValue* result, + const TUnversionedValue& state, + const TRowBufferPtr& buffer) const +{ + Columns_[index].Aggregate.Finalize(buffer.Get(), result, &state); + result->Id = index; +} + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK IColumnEvaluatorCachePtr CreateColumnEvaluatorCache( + TColumnEvaluatorCacheConfigPtr /*config*/, + TConstTypeInferrerMapPtr /*typeInferrers*/, + TConstFunctionProfilerMapPtr /*profilers*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/column_evaluator.cpp. 
+ YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/column_evaluator.h b/yt/yt/library/query/engine_api/column_evaluator.h new file mode 100644 index 0000000000..4261c89df5 --- /dev/null +++ b/yt/yt/library/query/engine_api/column_evaluator.h @@ -0,0 +1,107 @@ +#pragma once + +#include "evaluation_helpers.h" +#include "public.h" + +#include <yt/yt/client/table_client/unversioned_row.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +class TColumnEvaluator + : public TRefCounted +{ +public: + static TColumnEvaluatorPtr Create( + const TTableSchemaPtr& schema, + const TConstTypeInferrerMapPtr& typeInferrers, + const TConstFunctionProfilerMapPtr& profilers); + + void EvaluateKey( + TMutableRow fullRow, + const TRowBufferPtr& buffer, + int index) const; + + void EvaluateKeys( + TMutableRow fullRow, + const TRowBufferPtr& buffer) const; + + void EvaluateKeys( + NTableClient::TMutableVersionedRow fullRow, + const TRowBufferPtr& buffer) const; + + const std::vector<int>& GetReferenceIds(int index) const; + TConstExpressionPtr GetExpression(int index) const; + + void InitAggregate( + int schemaId, + NTableClient::TUnversionedValue* state, + const TRowBufferPtr& buffer) const; + + void UpdateAggregate( + int index, + NTableClient::TUnversionedValue* state, + const TRange<NTableClient::TUnversionedValue> update, + const TRowBufferPtr& buffer) const; + + void MergeAggregate( + int index, + NTableClient::TUnversionedValue* state, + const NTableClient::TUnversionedValue& mergeeState, + const TRowBufferPtr& buffer) const; + + void FinalizeAggregate( + int index, + NTableClient::TUnversionedValue* result, + const NTableClient::TUnversionedValue& state, + const TRowBufferPtr& buffer) const; + + bool IsAggregate(int index) const; + +private: + struct TColumn + { + 
TCGExpressionCallback Evaluator; + TCGVariables Variables; + std::vector<int> ReferenceIds; + TConstExpressionPtr Expression; + TCGAggregateCallbacks Aggregate; + }; + + std::vector<TColumn> Columns_; + std::vector<bool> IsAggregate_; + + TColumnEvaluator( + std::vector<TColumn> columns, + std::vector<bool> isAggregate); + + DECLARE_NEW_FRIEND() +}; + +DEFINE_REFCOUNTED_TYPE(TColumnEvaluator) + +//////////////////////////////////////////////////////////////////////////////// + +struct IColumnEvaluatorCache + : public virtual TRefCounted +{ + virtual TColumnEvaluatorPtr Find(const TTableSchemaPtr& schema) = 0; + + virtual void Configure(const TColumnEvaluatorCacheDynamicConfigPtr& config) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(IColumnEvaluatorCache) + +IColumnEvaluatorCachePtr CreateColumnEvaluatorCache( + TColumnEvaluatorCacheConfigPtr config, + TConstTypeInferrerMapPtr typeInferrers = GetBuiltinTypeInferrers(), + TConstFunctionProfilerMapPtr profilers = GetBuiltinFunctionProfilers()); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + +#define COLUMN_EVALUATOR_INL_H_ +#include "column_evaluator-inl.h" +#undef COLUMN_EVALUATOR_INL_H_ diff --git a/yt/yt/library/query/engine_api/config.cpp b/yt/yt/library/query/engine_api/config.cpp new file mode 100644 index 0000000000..87e9a5bd3d --- /dev/null +++ b/yt/yt/library/query/engine_api/config.cpp @@ -0,0 +1,43 @@ +#include "config.h" + +#include <yt/yt/core/misc/cache_config.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +void TExecutorConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cg_cache", &TThis::CGCache) + .DefaultNew(); + + registrar.Preprocessor([] (TThis* config) { + config->CGCache->Capacity = 512; + config->CGCache->ShardCount = 1; + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +void 
TColumnEvaluatorCacheConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cg_cache", &TThis::CGCache) + .DefaultNew(); + + registrar.Preprocessor([] (TThis* config) { + config->CGCache->Capacity = 512; + config->CGCache->ShardCount = 1; + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TColumnEvaluatorCacheDynamicConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("cg_cache", &TThis::CGCache) + .DefaultNew(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/config.h b/yt/yt/library/query/engine_api/config.h new file mode 100644 index 0000000000..f45d4ea4ec --- /dev/null +++ b/yt/yt/library/query/engine_api/config.h @@ -0,0 +1,58 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/misc/public.h> + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +class TExecutorConfig + : public NYTree::TYsonStruct +{ +public: + TSlruCacheConfigPtr CGCache; + + REGISTER_YSON_STRUCT(TExecutorConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TExecutorConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TColumnEvaluatorCacheConfig + : public NYTree::TYsonStruct +{ +public: + TSlruCacheConfigPtr CGCache; + + REGISTER_YSON_STRUCT(TColumnEvaluatorCacheConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TColumnEvaluatorCacheConfig) + +//////////////////////////////////////////////////////////////////////////////// + +class TColumnEvaluatorCacheDynamicConfig + : public NYTree::TYsonStruct +{ +public: + TSlruCacheDynamicConfigPtr CGCache; + + REGISTER_YSON_STRUCT(TColumnEvaluatorCacheDynamicConfig); + + static void Register(TRegistrar registrar); +}; + 
+DEFINE_REFCOUNTED_TYPE(TColumnEvaluatorCacheDynamicConfig) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/coordinator.cpp b/yt/yt/library/query/engine_api/coordinator.cpp new file mode 100644 index 0000000000..7a6d24ab68 --- /dev/null +++ b/yt/yt/library/query/engine_api/coordinator.cpp @@ -0,0 +1,57 @@ +#include "coordinator.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK std::pair<TConstFrontQueryPtr, std::vector<TConstQueryPtr>> CoordinateQuery( + const TConstQueryPtr& /*query*/, + const std::vector<TRefiner>& /*refiners*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/coordinator.cpp. + YT_ABORT(); +} + +Y_WEAK TRowRanges GetPrunedRanges( + const TConstExpressionPtr& /*predicate*/, + const TTableSchemaPtr& /*tableSchema*/, + const TKeyColumns& /*keyColumns*/, + NObjectClient::TObjectId /*tableId*/, + const TSharedRange<TRowRange>& /*ranges*/, + const TRowBufferPtr& /*rowBuffer*/, + const IColumnEvaluatorCachePtr& /*evaluatorCache*/, + const TConstRangeExtractorMapPtr& /*rangeExtractors*/, + const TQueryOptions& /*options*/, + TGuid /*queryId*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/coordinator.cpp. + YT_ABORT(); +} + +Y_WEAK TRowRanges GetPrunedRanges( + const TConstQueryPtr& /*query*/, + NObjectClient::TObjectId /*tableId*/, + const TSharedRange<TRowRange>& /*ranges*/, + const TRowBufferPtr& /*rowBuffer*/, + const IColumnEvaluatorCachePtr& /*evaluatorCache*/, + const TConstRangeExtractorMapPtr& /*rangeExtractors*/, + const TQueryOptions& /*options*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/coordinator.cpp. 
+ YT_ABORT(); +} + +Y_WEAK TQueryStatistics CoordinateAndExecute( + const TConstQueryPtr& /*query*/, + const IUnversionedRowsetWriterPtr& /*writer*/, + const std::vector<TRefiner>& /*ranges*/, + std::function<TEvaluateResult(const TConstQueryPtr&, int)> /*evaluateSubquery*/, + std::function<TQueryStatistics(const TConstFrontQueryPtr&, const ISchemafulUnversionedReaderPtr&, const IUnversionedRowsetWriterPtr&)> /*evaluateTop*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/coordinator.cpp. + YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/coordinator.h b/yt/yt/library/query/engine_api/coordinator.h new file mode 100644 index 0000000000..e61c6631a9 --- /dev/null +++ b/yt/yt/library/query/engine_api/coordinator.h @@ -0,0 +1,55 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/client/query_client/query_statistics.h> + +#include <yt/yt/core/actions/future.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +using TRefiner = std::function<TConstExpressionPtr( + const TConstExpressionPtr& expr, + const TKeyColumns& keyColumns)>; + +std::pair<TConstFrontQueryPtr, std::vector<TConstQueryPtr>> CoordinateQuery( + const TConstQueryPtr& query, + const std::vector<TRefiner>& refiners); + +TRowRanges GetPrunedRanges( + const TConstExpressionPtr& predicate, + const TTableSchemaPtr& tableSchema, + const TKeyColumns& keyColumns, + NObjectClient::TObjectId tableId, + const TSharedRange<TRowRange>& ranges, + const TRowBufferPtr& rowBuffer, + const IColumnEvaluatorCachePtr& evaluatorCache, + const TConstRangeExtractorMapPtr& rangeExtractors, + const TQueryOptions& options, + TGuid queryId = {}); + +TRowRanges GetPrunedRanges( + const TConstQueryPtr& query, + NObjectClient::TObjectId tableId, + const TSharedRange<TRowRange>& ranges, + const TRowBufferPtr& 
rowBuffer, + const IColumnEvaluatorCachePtr& evaluatorCache, + const TConstRangeExtractorMapPtr& rangeExtractors, + const TQueryOptions& options); + +using TEvaluateResult = std::pair< + ISchemafulUnversionedReaderPtr, + TFuture<TQueryStatistics>>; + +TQueryStatistics CoordinateAndExecute( + const TConstQueryPtr& query, + const IUnversionedRowsetWriterPtr& writer, + const std::vector<TRefiner>& ranges, + std::function<TEvaluateResult(const TConstQueryPtr&, int)> evaluateSubquery, + std::function<TQueryStatistics(const TConstFrontQueryPtr&, const ISchemafulUnversionedReaderPtr&, const IUnversionedRowsetWriterPtr&)> evaluateTop); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/evaluation_helpers-inl.h b/yt/yt/library/query/engine_api/evaluation_helpers-inl.h new file mode 100644 index 0000000000..35a57b25bb --- /dev/null +++ b/yt/yt/library/query/engine_api/evaluation_helpers-inl.h @@ -0,0 +1,24 @@ +#ifndef EVALUATION_HELPERS_INL_H_ +#error "Direct inclusion of this file is not allowed, include evaluation_helpers.h" +// For the sake of sane code completion. +#include "evaluation_helpers.h" +#endif + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +template <class T, class... TArgs> +int TCGVariables::AddOpaque(TArgs&& ... 
args) +{ + auto pointer = Holder_.Register(new T(std::forward<TArgs>(args)...)); + + int index = static_cast<int>(OpaquePointers_.size()); + OpaquePointers_.push_back(pointer); + + return index; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/evaluation_helpers.cpp b/yt/yt/library/query/engine_api/evaluation_helpers.cpp new file mode 100644 index 0000000000..18b3f57a28 --- /dev/null +++ b/yt/yt/library/query/engine_api/evaluation_helpers.cpp @@ -0,0 +1,286 @@ +#include "evaluation_helpers.h" + +#include "position_independent_value_transfer.h" + +#include <yt/yt/library/query/base/private.h> +#include <yt/yt/library/query/base/query.h> +#include <yt/yt/library/query/base/query_helpers.h> + +#include <yt/yt/client/query_client/query_statistics.h> + +namespace NYT::NQueryClient { + +using namespace NConcurrency; +using namespace NTableClient; + +static const auto& Logger = QueryClientLogger; + +//////////////////////////////////////////////////////////////////////////////// + +constexpr ssize_t BufferLimit = 512_KB; + +struct TTopCollectorBufferTag +{ }; + +//////////////////////////////////////////////////////////////////////////////// + +TTopCollector::TTopCollector( + i64 limit, + TComparerFunction* comparer, + size_t rowSize, + IMemoryChunkProviderPtr memoryChunkProvider) + : Comparer_(comparer) + , RowSize_(rowSize) + , MemoryChunkProvider_(std::move(memoryChunkProvider)) +{ + Rows_.reserve(limit); +} + +std::pair<const TPIValue*, int> TTopCollector::Capture(const TPIValue* row) +{ + if (EmptyBufferIds_.empty()) { + if (GarbageMemorySize_ > TotalMemorySize_ / 2) { + // Collect garbage. 
+ + std::vector<std::vector<size_t>> buffersToRows(Buffers_.size()); + for (size_t rowId = 0; rowId < Rows_.size(); ++rowId) { + buffersToRows[Rows_[rowId].second].push_back(rowId); + } + + auto buffer = New<TRowBuffer>(TTopCollectorBufferTag(), MemoryChunkProvider_); + + TotalMemorySize_ = 0; + AllocatedMemorySize_ = 0; + GarbageMemorySize_ = 0; + + for (size_t bufferId = 0; bufferId < buffersToRows.size(); ++bufferId) { + for (auto rowId : buffersToRows[bufferId]) { + auto& row = Rows_[rowId].first; + + auto savedSize = buffer->GetSize(); + row = CapturePIValueRange(buffer.Get(), MakeRange(row, RowSize_)).Begin(); + AllocatedMemorySize_ += buffer->GetSize() - savedSize; + } + + TotalMemorySize_ += buffer->GetCapacity(); + + if (buffer->GetSize() < BufferLimit) { + EmptyBufferIds_.push_back(bufferId); + } + + std::swap(buffer, Buffers_[bufferId]); + buffer->Clear(); + } + } else { + // Allocate buffer and add to emptyBufferIds. + EmptyBufferIds_.push_back(Buffers_.size()); + Buffers_.push_back(New<TRowBuffer>(TTopCollectorBufferTag(), MemoryChunkProvider_)); + } + } + + YT_VERIFY(!EmptyBufferIds_.empty()); + + auto bufferId = EmptyBufferIds_.back(); + auto buffer = Buffers_[bufferId]; + + auto savedSize = buffer->GetSize(); + auto savedCapacity = buffer->GetCapacity(); + + TPIValue* capturedRow = CapturePIValueRange(buffer.Get(), MakeRange(row, RowSize_)).Begin(); + + AllocatedMemorySize_ += buffer->GetSize() - savedSize; + TotalMemorySize_ += buffer->GetCapacity() - savedCapacity; + + if (buffer->GetSize() >= BufferLimit) { + EmptyBufferIds_.pop_back(); + } + + return std::make_pair(capturedRow, bufferId); +} + +void TTopCollector::AccountGarbage(const TPIValue* row) +{ + GarbageMemorySize_ += GetUnversionedRowByteSize(RowSize_); + for (int index = 0; index < static_cast<int>(RowSize_); ++index) { + const auto& value = row[index]; + + if (IsStringLikeType(EValueType(value.Type))) { + GarbageMemorySize_ += value.Length; + } + } +} + +void 
TTopCollector::AddRow(const TPIValue* row) +{ + if (Rows_.size() < Rows_.capacity()) { + auto capturedRow = Capture(row); + Rows_.emplace_back(capturedRow); + std::push_heap(Rows_.begin(), Rows_.end(), Comparer_); + } else if (!Rows_.empty() && !Comparer_(Rows_.front().first, row)) { + auto capturedRow = Capture(row); + std::pop_heap(Rows_.begin(), Rows_.end(), Comparer_); + AccountGarbage(Rows_.back().first); + Rows_.back() = capturedRow; + std::push_heap(Rows_.begin(), Rows_.end(), Comparer_); + } +} + +std::vector<const TPIValue*> TTopCollector::GetRows() const +{ + std::vector<const TPIValue*> result; + result.reserve(Rows_.size()); + for (const auto& [value, _] : Rows_) { + result.push_back(value); + } + std::sort(result.begin(), result.end(), Comparer_); + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +TMultiJoinClosure::TItem::TItem( + IMemoryChunkProviderPtr chunkProvider, + size_t keySize, + TComparerFunction* prefixEqComparer, + THasherFunction* lookupHasher, + TComparerFunction* lookupEqComparer) + : Buffer(New<TRowBuffer>(TPermanentBufferTag(), std::move(chunkProvider))) + , KeySize(keySize) + , PrefixEqComparer(prefixEqComparer) + , Lookup( + InitialGroupOpHashtableCapacity, + lookupHasher, + lookupEqComparer) +{ + Lookup.set_empty_key(nullptr); +} + +TGroupByClosure::TGroupByClosure( + IMemoryChunkProviderPtr chunkProvider, + TComparerFunction* prefixEqComparer, + THasherFunction* groupHasher, + TComparerFunction* groupComparer, + int keySize, + int valuesCount, + bool checkNulls) + : Buffer(New<TRowBuffer>(TPermanentBufferTag(), std::move(chunkProvider))) + , PrefixEqComparer(prefixEqComparer) + , Lookup( + InitialGroupOpHashtableCapacity, + groupHasher, + groupComparer) + , KeySize(keySize) + , ValuesCount(valuesCount) + , CheckNulls(checkNulls) +{ + Lookup.set_empty_key(nullptr); +} + +TWriteOpClosure::TWriteOpClosure(IMemoryChunkProviderPtr chunkProvider) + : 
OutputBuffer(New<TRowBuffer>(TOutputBufferTag(), std::move(chunkProvider))) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +std::pair<TQueryPtr, TDataSource> GetForeignQuery( + TQueryPtr subquery, + TConstJoinClausePtr joinClause, + std::vector<TRow> keys, + TRowBufferPtr permanentBuffer) +{ + auto foreignKeyPrefix = joinClause->ForeignKeyPrefix; + const auto& foreignEquations = joinClause->ForeignEquations; + + auto newQuery = New<TQuery>(*subquery); + + TDataSource dataSource; + dataSource.ObjectId = joinClause->ForeignObjectId; + dataSource.CellId = joinClause->ForeignCellId; + + if (foreignKeyPrefix > 0) { + if (foreignKeyPrefix == foreignEquations.size()) { + YT_LOG_DEBUG("Using join via source ranges"); + dataSource.Keys = MakeSharedRange(std::move(keys), std::move(permanentBuffer)); + } else { + YT_LOG_DEBUG("Using join via prefix ranges"); + std::vector<TRow> prefixKeys; + for (auto key : keys) { + prefixKeys.push_back(permanentBuffer->CaptureRow(MakeRange(key.Begin(), foreignKeyPrefix), false)); + } + prefixKeys.erase(std::unique(prefixKeys.begin(), prefixKeys.end()), prefixKeys.end()); + dataSource.Keys = MakeSharedRange(std::move(prefixKeys), std::move(permanentBuffer)); + } + + for (size_t index = 0; index < foreignKeyPrefix; ++index) { + dataSource.Schema.push_back(foreignEquations[index]->LogicalType); + } + + newQuery->InferRanges = false; + // COMPAT(lukyan): Use ordered read without modification of protocol + newQuery->Limit = std::numeric_limits<i64>::max() - 1; + } else { + TRowRanges ranges; + + YT_LOG_DEBUG("Using join via IN clause"); + ranges.emplace_back( + permanentBuffer->CaptureRow(NTableClient::MinKey().Get()), + permanentBuffer->CaptureRow(NTableClient::MaxKey().Get())); + + auto inClause = New<TInExpression>( + foreignEquations, + MakeSharedRange(std::move(keys), permanentBuffer)); + + dataSource.Ranges = MakeSharedRange(std::move(ranges), std::move(permanentBuffer)); + + 
newQuery->WhereClause = newQuery->WhereClause + ? MakeAndExpression(inClause, newQuery->WhereClause) + : inClause; + } + + return std::make_pair(newQuery, dataSource); +} + +//////////////////////////////////////////////////////////////////////////////// + +TRange<void*> TCGVariables::GetOpaqueData() const +{ + return OpaquePointers_; +} + +void TCGVariables::Clear() +{ + OpaquePointers_.clear(); + Holder_.Clear(); + OwningLiteralValues_.clear(); + LiteralValues_.reset(); +} + +int TCGVariables::AddLiteralValue(TOwningValue value) +{ + YT_ASSERT(!LiteralValues_); + int index = static_cast<int>(OwningLiteralValues_.size()); + OwningLiteralValues_.emplace_back(std::move(value)); + return index; +} + +TRange<TPIValue> TCGVariables::GetLiteralValues() const +{ + InitLiteralValuesIfNeeded(this); + return {LiteralValues_.get(), OwningLiteralValues_.size()}; +} + +void TCGVariables::InitLiteralValuesIfNeeded(const TCGVariables* variables) +{ + if (!variables->LiteralValues_) { + variables->LiteralValues_ = std::make_unique<TPIValue[]>(variables->OwningLiteralValues_.size()); + size_t index = 0; + for (const auto& value : variables->OwningLiteralValues_) { + MakePositionIndependentFromUnversioned(&variables->LiteralValues_[index], value); + ++index; + } + } +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/evaluation_helpers.h b/yt/yt/library/query/engine_api/evaluation_helpers.h new file mode 100644 index 0000000000..aacbf626e7 --- /dev/null +++ b/yt/yt/library/query/engine_api/evaluation_helpers.h @@ -0,0 +1,369 @@ +#pragma once + +#include "position_independent_value.h" + +#include "public.h" + +#include <yt/yt/library/query/base/callbacks.h> + +#include <yt/yt/library/query/misc/objects_holder.h> +#include <yt/yt/library/query/misc/function_context.h> + +#include <yt/yt/client/api/rowset.h> + +#include <yt/yt/client/table_client/unversioned_row.h> + 
+#include <library/cpp/yt/memory/chunked_memory_pool.h> + +#include <deque> +#include <unordered_map> +#include <unordered_set> + +#include <sparsehash/dense_hash_set> +#include <sparsehash/dense_hash_map> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +constexpr i64 RowsetProcessingSize = 1024; +constexpr i64 WriteRowsetSize = 64 * RowsetProcessingSize; +constexpr i64 MaxJoinBatchSize = 1024 * RowsetProcessingSize; + +class TInterruptedIncompleteException +{ }; + +struct TOutputBufferTag +{ }; + +struct TIntermediateBufferTag +{ }; + +struct TPermanentBufferTag +{ }; + +//////////////////////////////////////////////////////////////////////////////// + +constexpr const size_t InitialGroupOpHashtableCapacity = 1024; + +using THasherFunction = ui64(const TPIValue*); +using TComparerFunction = char(const TPIValue*, const TPIValue*); +using TTernaryComparerFunction = i64(const TPIValue*, const TPIValue*); + +namespace NDetail { + +class TGroupHasher +{ +public: + // Intentionally implicit. + TGroupHasher(THasherFunction* ptr) + : Ptr_(ptr) + { } + + ui64 operator () (const TPIValue* row) const + { + return Ptr_(row); + } + +private: + THasherFunction* Ptr_; +}; + +class TRowComparer +{ +public: + // Intentionally implicit. 
+ TRowComparer(TComparerFunction* ptr) + : Ptr_(ptr) + { } + + bool operator () (const TPIValue* a, const TPIValue* b) const + { + return a == b || a && b && Ptr_(a, b); + } + +private: + TComparerFunction* Ptr_; +}; + +} // namespace NDetail + +using TLookupRows = google::dense_hash_set< + const TPIValue*, + NDetail::TGroupHasher, + NDetail::TRowComparer>; + +using TJoinLookup = google::dense_hash_map< + const TPIValue*, + std::pair<int, bool>, + NDetail::TGroupHasher, + NDetail::TRowComparer>; + +using TJoinLookupRows = std::unordered_multiset< + const TPIValue*, + NDetail::TGroupHasher, + NDetail::TRowComparer>; + +struct TExecutionContext; + +struct TSingleJoinParameters +{ + size_t KeySize; + bool IsLeft; + bool IsPartiallySorted; + std::vector<size_t> ForeignColumns; + TJoinSubqueryEvaluator ExecuteForeign; +}; + +struct TMultiJoinParameters +{ + TCompactVector<TSingleJoinParameters, 10> Items; + size_t PrimaryRowSize; + size_t BatchSize; +}; + +struct TMultiJoinClosure +{ + TRowBufferPtr Buffer; + + using THashJoinLookup = google::dense_hash_set< + TPIValue*, + NDetail::TGroupHasher, + NDetail::TRowComparer>; // + slot after row + + std::vector<TPIValue*> PrimaryRows; + + struct TItem + { + TRowBufferPtr Buffer; + size_t KeySize; + TComparerFunction* PrefixEqComparer; + + THashJoinLookup Lookup; + std::vector<TPIValue*> OrderedKeys; // + slot after row + const TPIValue* LastKey = nullptr; + + TItem( + IMemoryChunkProviderPtr chunkProvider, + size_t keySize, + TComparerFunction* prefixEqComparer, + THasherFunction* lookupHasher, + TComparerFunction* lookupEqComparer); + }; + + TCompactVector<TItem, 32> Items; + + size_t PrimaryRowSize; + size_t BatchSize; + std::function<void(size_t)> ProcessSegment; + std::function<bool()> ProcessJoinBatch; +}; + +struct TGroupByClosure +{ + TRowBufferPtr Buffer; + TComparerFunction* PrefixEqComparer; + TLookupRows Lookup; + const TPIValue* LastKey = nullptr; + std::vector<const TPIValue*> GroupedRows; + int KeySize; + int 
ValuesCount; + bool CheckNulls; + + // GroupedRows can be flushed and cleared during aggregation. + // So we have to count grouped rows separately. + size_t GroupedRowCount = 0; + + TGroupByClosure( + IMemoryChunkProviderPtr chunkProvider, + TComparerFunction* prefixEqComparer, + THasherFunction* groupHasher, + TComparerFunction* groupComparer, + int keySize, + int valuesCount, + bool checkNulls); + + std::function<void()> ProcessSegment; +}; + +struct TWriteOpClosure +{ + TRowBufferPtr OutputBuffer; + + // Rows stored in OutputBuffer + std::vector<TRow> OutputRowsBatch; + size_t RowSize; + + explicit TWriteOpClosure(IMemoryChunkProviderPtr chunkProvider); +}; + +using TExpressionContext = TRowBuffer; + +#define CHECK_STACK() (void) 0; + +struct TExecutionContext +{ + ISchemafulUnversionedReaderPtr Reader; + IUnversionedRowsetWriterPtr Writer; + + TQueryStatistics* Statistics = nullptr; + + // These limits prevent full scan. + i64 InputRowLimit = std::numeric_limits<i64>::max(); + i64 OutputRowLimit = std::numeric_limits<i64>::max(); + i64 GroupRowLimit = std::numeric_limits<i64>::max(); + i64 JoinRowLimit = std::numeric_limits<i64>::max(); + + // Offset from OFFSET clause. + i64 Offset = 0; + // Limit from LIMIT clause. 
+ i64 Limit = std::numeric_limits<i64>::max(); + + bool Ordered = false; + bool IsMerge = false; + + IMemoryChunkProviderPtr MemoryChunkProvider; + + TExecutionContext() + { + auto context = this; + Y_UNUSED(context); + CHECK_STACK(); + } +}; + +class TTopCollector +{ +public: + TTopCollector( + i64 limit, + TComparerFunction* comparer, + size_t rowSize, + IMemoryChunkProviderPtr memoryChunkProvider); + + std::vector<const TPIValue*> GetRows() const; + + void AddRow(const TPIValue* row); + +private: + // GarbageMemorySize <= AllocatedMemorySize <= TotalMemorySize + size_t TotalMemorySize_ = 0; + size_t AllocatedMemorySize_ = 0; + size_t GarbageMemorySize_ = 0; + + class TComparer + { + public: + explicit TComparer(TComparerFunction* ptr) + : Ptr_(ptr) + { } + + bool operator() (const std::pair<const TPIValue*, int>& lhs, const std::pair<const TPIValue*, int>& rhs) const + { + return (*this)(lhs.first, rhs.first); + } + + bool operator () (const TPIValue* a, const TPIValue* b) const + { + return Ptr_(a, b); + } + + private: + TComparerFunction* const Ptr_; + }; + + TComparer Comparer_; + size_t RowSize_; + IMemoryChunkProviderPtr MemoryChunkProvider_; + + std::vector<TRowBufferPtr> Buffers_; + std::vector<int> EmptyBufferIds_; + std::vector<std::pair<const TPIValue*, int>> Rows_; + + std::pair<const TPIValue*, int> Capture(const TPIValue* row); + + void AccountGarbage(const TPIValue* row); +}; + +class TCGVariables +{ +public: + template <class T, class... TArgs> + int AddOpaque(TArgs&&... 
args); + + TRange<void*> GetOpaqueData() const; + + void Clear(); + + int AddLiteralValue(TOwningValue value); + + TRange<TPIValue> GetLiteralValues() const; + +private: + TObjectsHolder Holder_; + std::vector<void*> OpaquePointers_; + std::vector<TOwningValue> OwningLiteralValues_; + mutable std::unique_ptr<TPIValue[]> LiteralValues_; + + static void InitLiteralValuesIfNeeded(const TCGVariables* variables); +}; + +using TCGPIQuerySignature = void(const TPIValue*, void* const*, TExecutionContext*); +using TCGPIExpressionSignature = void(const TPIValue*, void* const*, TPIValue*, const TPIValue*, TExpressionContext*); +using TCGPIAggregateInitSignature = void(TExpressionContext*, TPIValue*); +using TCGPIAggregateUpdateSignature = void(TExpressionContext*, TPIValue*, const TPIValue*); +using TCGPIAggregateMergeSignature = void(TExpressionContext*, TPIValue*, const TPIValue*); +using TCGPIAggregateFinalizeSignature = void(TExpressionContext*, TPIValue*, const TPIValue*); + +using TCGQuerySignature = void(TRange<TPIValue>, TRange<void*>, TExecutionContext*); +using TCGExpressionSignature = void(TRange<TPIValue>, TRange<void*>, TValue*, TRange<TValue>, TRowBuffer*); +using TCGAggregateInitSignature = void(TExpressionContext*, TValue*); +using TCGAggregateUpdateSignature = void(TExpressionContext*, TValue*, TRange<TValue>); +using TCGAggregateMergeSignature = void(TExpressionContext*, TValue*, const TValue*); +using TCGAggregateFinalizeSignature = void(TExpressionContext*, TValue*, const TValue*); + +using TCGQueryCallback = TCallback<TCGQuerySignature>; +using TCGExpressionCallback = TCallback<TCGExpressionSignature>; +using TCGAggregateInitCallback = TCallback<TCGAggregateInitSignature>; +using TCGAggregateUpdateCallback = TCallback<TCGAggregateUpdateSignature>; +using TCGAggregateMergeCallback = TCallback<TCGAggregateMergeSignature>; +using TCGAggregateFinalizeCallback = TCallback<TCGAggregateFinalizeSignature>; + +struct TCGAggregateCallbacks +{ + 
TCGAggregateInitCallback Init; + TCGAggregateUpdateCallback Update; + TCGAggregateMergeCallback Merge; + TCGAggregateFinalizeCallback Finalize; +}; + +//////////////////////////////////////////////////////////////////////////////// + +std::pair<TQueryPtr, TDataSource> GetForeignQuery( + TQueryPtr subquery, + TConstJoinClausePtr joinClause, + std::vector<TRow> keys, + TRowBufferPtr permanentBuffer); + +//////////////////////////////////////////////////////////////////////////////// + +struct TExpressionClosure; + +struct TJoinComparers +{ + TComparerFunction* PrefixEqComparer; + THasherFunction* SuffixHasher; + TComparerFunction* SuffixEqComparer; + TComparerFunction* SuffixLessComparer; + TComparerFunction* ForeignPrefixEqComparer; + TComparerFunction* ForeignSuffixLessComparer; + TTernaryComparerFunction* FullTernaryComparer; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + +#define EVALUATION_HELPERS_INL_H_ +#include "evaluation_helpers-inl.h" +#undef EVALUATION_HELPERS_INL_H_ diff --git a/yt/yt/library/query/engine_api/evaluator.cpp b/yt/yt/library/query/engine_api/evaluator.cpp new file mode 100644 index 0000000000..a7238a6873 --- /dev/null +++ b/yt/yt/library/query/engine_api/evaluator.cpp @@ -0,0 +1,15 @@ +#include "evaluator.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK IEvaluatorPtr CreateEvaluator(TExecutorConfigPtr /*config*/, const NProfiling::TProfiler& /*profiler*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/evaluator.cpp. 
+ YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/evaluator.h b/yt/yt/library/query/engine_api/evaluator.h new file mode 100644 index 0000000000..18c389bd9a --- /dev/null +++ b/yt/yt/library/query/engine_api/evaluator.h @@ -0,0 +1,36 @@ +#pragma once + +#include "builtin_function_profiler.h" +#include "public.h" + +#include <yt/yt/library/query/base/callbacks.h> + +#include <yt/yt/library/profiling/sensor.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +struct IEvaluator + : public virtual TRefCounted +{ + virtual TQueryStatistics Run( + const TConstBaseQueryPtr& query, + const ISchemafulUnversionedReaderPtr& reader, + const IUnversionedRowsetWriterPtr& writer, + const TJoinSubqueryProfiler& joinProfiler, + const TConstFunctionProfilerMapPtr& functionProfilers, + const TConstAggregateProfilerMapPtr& aggregateProfilers, + const IMemoryChunkProviderPtr& memoryChunkProvider, + const TQueryBaseOptions& options) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(IEvaluator) + +IEvaluatorPtr CreateEvaluator( + TExecutorConfigPtr config, + const NProfiling::TProfiler& profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/new_range_inferrer.cpp b/yt/yt/library/query/engine_api/new_range_inferrer.cpp new file mode 100644 index 0000000000..fd15cd54e9 --- /dev/null +++ b/yt/yt/library/query/engine_api/new_range_inferrer.cpp @@ -0,0 +1,382 @@ +#include "new_range_inferrer.h" + +#include <yt/yt/library/query/base/query_helpers.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +// Build mapping from schema key index to index of reference in tuple. 
+std::vector<int> BuildKeyMapping(const TKeyColumns& keyColumns, TRange<TConstExpressionPtr> expressions) +{ + std::vector<int> keyMapping(keyColumns.size(), -1); + for (int index = 0; index < std::ssize(expressions); ++index) { + const auto* referenceExpr = expressions[index]->As<TReferenceExpression>(); + if (referenceExpr) { + int keyPartIndex = ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName); + if (keyPartIndex >= 0 && keyMapping[keyPartIndex] == -1) { + keyMapping[keyPartIndex] = index; + } + } + } + return keyMapping; +} + +int CompareRowUsingMapping(TRow lhs, TRow rhs, TRange<int> mapping) +{ + for (auto index : mapping) { + if (index == -1) { + continue; + } + + int result = CompareRowValuesCheckingNan(lhs.Begin()[index], rhs.Begin()[index]); + + if (result != 0) { + return result; + } + } + return 0; +} + +TConstraintRef TConstraintsHolder::ExtractFromExpression( + const TConstExpressionPtr& expr, + const TKeyColumns& keyColumns, + const TRowBufferPtr& rowBuffer, + const TConstConstraintExtractorMapPtr& constraintExtractors) +{ + YT_VERIFY(!keyColumns.empty()); + + if (!expr) { + return TConstraintRef::Universal(); + } + + if (const auto* binaryOpExpr = expr->As<TBinaryOpExpression>()) { + auto opcode = binaryOpExpr->Opcode; + auto lhsExpr = binaryOpExpr->Lhs; + auto rhsExpr = binaryOpExpr->Rhs; + + if (opcode == EBinaryOp::And) { + auto lhsConstraint = ExtractFromExpression(lhsExpr, keyColumns, rowBuffer, constraintExtractors); + auto rhsConstraint = ExtractFromExpression(rhsExpr, keyColumns, rowBuffer, constraintExtractors); + return TConstraintsHolder::Intersect(lhsConstraint, rhsConstraint); + } else if (opcode == EBinaryOp::Or) { + auto lhsConstraint = ExtractFromExpression(lhsExpr, keyColumns, rowBuffer, constraintExtractors); + auto rhsConstraint = ExtractFromExpression(rhsExpr, keyColumns, rowBuffer, constraintExtractors); + return TConstraintsHolder::Unite(lhsConstraint, rhsConstraint); + } else { + if 
(rhsExpr->As<TReferenceExpression>()) { + // Ensure that references are on the left. + std::swap(lhsExpr, rhsExpr); + opcode = GetReversedBinaryOpcode(opcode); + } + + const auto* referenceExpr = lhsExpr->As<TReferenceExpression>(); + const auto* constantExpr = rhsExpr->As<TLiteralExpression>(); + + if (referenceExpr && constantExpr) { + int keyPartIndex = ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName); + if (keyPartIndex >= 0) { + auto value = TValue(constantExpr->Value); + switch (opcode) { + case EBinaryOp::Equal: + return Interval( + TValueBound{value, false}, + TValueBound{value, true}, + keyPartIndex); + case EBinaryOp::NotEqual: + return TConstraintsHolder::Append( + { + TConstraint::Make( + MinBound, + TValueBound{value, false}), + TConstraint::Make( + TValueBound{value, true}, + MaxBound) + }, + keyPartIndex); + case EBinaryOp::Less: + return TConstraintsHolder::Interval( + MinBound, + TValueBound{value, false}, + keyPartIndex); + case EBinaryOp::LessOrEqual: + return TConstraintsHolder::Interval( + MinBound, + TValueBound{value, true}, + keyPartIndex); + case EBinaryOp::Greater: + return TConstraintsHolder::Interval( + TValueBound{value, true}, + MaxBound, + keyPartIndex); + case EBinaryOp::GreaterOrEqual: + return TConstraintsHolder::Interval( + TValueBound{value, false}, + MaxBound, + keyPartIndex); + default: + break; + } + } + } + + return TConstraintRef::Universal(); + } + } else if (const auto* functionExpr = expr->As<TFunctionExpression>()) { + auto foundIt = constraintExtractors->find(functionExpr->FunctionName); + if (foundIt == constraintExtractors->end()) { + return TConstraintRef::Universal(); + } + + const auto& constraintExtractor = foundIt->second; + + return constraintExtractor( + this, + functionExpr, + keyColumns, + rowBuffer); + } else if (const auto* inExpr = expr->As<TInExpression>()) { + TRange<TRow> values = inExpr->Values; + auto rowCount = std::ssize(values); + + std::vector<ui32> startOffsets; + 
startOffsets.reserve(size()); + for (const auto& columnConstraints : *this) { + startOffsets.push_back(columnConstraints.size()); + } + + auto keyMapping = BuildKeyMapping(keyColumns, inExpr->Arguments); + + bool orderedMapping = true; + for (int index = 1; index < std::ssize(keyMapping); ++index) { + if (keyMapping[index] <= keyMapping[index - 1]) { + orderedMapping = false; + break; + } + } + + std::vector<TRow> sortedValues; + if (!orderedMapping) { + sortedValues = values.ToVector(); + std::sort(sortedValues.begin(), sortedValues.end(), [&] (TRow lhs, TRow rhs) { + return CompareRowUsingMapping(lhs, rhs, keyMapping) < 0; + }); + values = sortedValues; + } + + int lastKeyPart = -1; + for (int keyIndex = keyMapping.size() - 1; keyIndex >= 0; --keyIndex) { + auto index = keyMapping[keyIndex]; + if (index >= 0) { + auto& columnConstraints = (*this)[keyIndex]; + + for (int rowIndex = 0; rowIndex < rowCount; ++rowIndex) { + auto next = TConstraintRef::Universal(); + if (lastKeyPart >= 0) { + next.ColumnId = lastKeyPart; + next.StartIndex = startOffsets[lastKeyPart] + rowIndex; + next.EndIndex = next.StartIndex + 1; + } + + const auto& value = values[rowIndex][index]; + + columnConstraints.push_back(TConstraint::Make( + TValueBound{value, false}, + TValueBound{value, true}, + next)); + } + + lastKeyPart = keyIndex; + } + } + + auto result = TConstraintRef::Universal(); + if (lastKeyPart >= 0) { + result.ColumnId = lastKeyPart; + result.StartIndex = startOffsets[lastKeyPart]; + result.EndIndex = result.StartIndex + rowCount; + } + + return result; + } else if (const auto* betweenExpr = expr->As<TBetweenExpression>()) { + const auto& expressions = betweenExpr->Arguments; + std::vector<int> keyColumnIds; + + for (int index = 0; index < std::ssize(expressions); ++index) { + const auto* referenceExpr = expressions[index]->As<TReferenceExpression>(); + if (!referenceExpr) { + break; + } + + int keyPartIndex = ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName); 
+ + if (keyPartIndex < 0 || !keyColumnIds.empty() && keyColumnIds.back() >= keyPartIndex) { + break; + } + + keyColumnIds.push_back(keyPartIndex); + } + + if (keyColumnIds.empty()) { + return TConstraintRef::Universal(); + } + + size_t startOffsetForFirstColumn = (*this)[keyColumnIds.front()].size(); + + // BETWEEN (a, b, c) and (k, l, m) generates the following constraints: + // [a-, a+] [b-, b+] [c-, c+] + // [c+, +inf] + // [b+, +inf] + // [a+, k-] + // [k-, k+] [-inf, l-] + // [l-, l+] [-inf, m-] + // [m-, m+] + + for (int rowIndex = 0; rowIndex < std::ssize(betweenExpr->Ranges); ++rowIndex) { + auto literalRange = betweenExpr->Ranges[rowIndex]; + + auto lower = literalRange.first; + auto upper = literalRange.second; + + size_t equalPrefix = 0; + while (equalPrefix < lower.GetCount() && + equalPrefix < upper.GetCount() && + lower[equalPrefix] == upper[equalPrefix]) + { + ++equalPrefix; + } + + // Lower and upper bounds are included. + auto currentLower = TConstraintRef::Universal(); + auto currentUpper = TConstraintRef::Universal(); + + size_t expressionIndex = keyColumnIds.size(); + + while (expressionIndex > equalPrefix + 1) { + --expressionIndex; + + auto keyColumnIndex = keyColumnIds[expressionIndex]; + + if (expressionIndex < lower.GetCount()) { + const auto& lowerValue = lower[expressionIndex]; + currentLower = Append({ + TConstraint::Make( + TValueBound{lowerValue, false}, + TValueBound{lowerValue, true}, + currentLower), + TConstraint::Make(TValueBound{lowerValue, true}, MaxBound) + }, + keyColumnIndex); + } + + if (expressionIndex < upper.GetCount()) { + const auto& upperValue = upper[expressionIndex]; + currentUpper = Append({ + TConstraint::Make(MinBound, TValueBound{upperValue, false}), + TConstraint::Make( + TValueBound{upperValue, false}, + TValueBound{upperValue, true}, + currentUpper) + }, + keyColumnIndex); + } + } + + auto current = TConstraintRef::Universal(); + if (expressionIndex == equalPrefix + 1) { + --expressionIndex; + auto 
keyColumnIndex = keyColumnIds[expressionIndex]; + + if (expressionIndex < lower.GetCount() && expressionIndex < upper.GetCount()) { + const auto& lowerValue = lower[expressionIndex]; + const auto& upperValue = upper[expressionIndex]; + + current = Append({ + TConstraint::Make( + TValueBound{lowerValue, false}, + TValueBound{lowerValue, true}, + currentLower), + TConstraint::Make( + TValueBound{lowerValue, true}, + TValueBound{upperValue, false} + ), + TConstraint::Make( + TValueBound{upperValue, false}, + TValueBound{upperValue, true}, + currentUpper) + }, + keyColumnIndex); + } else if (expressionIndex < lower.GetCount()) { + const auto& lowerValue = lower[expressionIndex]; + current = Append({ + TConstraint::Make( + TValueBound{lowerValue, false}, + TValueBound{lowerValue, true}, + currentLower), + TConstraint::Make(TValueBound{lowerValue, true}, MaxBound) + }, + keyColumnIndex); + } else if (expressionIndex < upper.GetCount()) { + const auto& upperValue = upper[expressionIndex]; + current = Append({ + TConstraint::Make(MinBound, TValueBound{upperValue, false}), + TConstraint::Make( + TValueBound{upperValue, false}, + TValueBound{upperValue, true}, + currentUpper) + }, + keyColumnIndex); + } + } + + while (expressionIndex > 0) { + --expressionIndex; + + auto keyColumnIndex = keyColumnIds[expressionIndex]; + + const auto& value = lower[expressionIndex]; + YT_VERIFY(value == upper[expressionIndex]); + + current = Append({ + TConstraint::Make( + TValueBound{value, false}, + TValueBound{value, true}, + current) + }, + keyColumnIndex); + } + } + + TConstraintRef result; + result.StartIndex = startOffsetForFirstColumn; + result.EndIndex = (*this)[keyColumnIds.front()].size(); + result.ColumnId = keyColumnIds.front(); + + return result; + } else if (const auto* literalExpr = expr->As<TLiteralExpression>()) { + TValue value = literalExpr->Value; + if (value.Type == EValueType::Boolean) { + return value.Data.Boolean ? 
TConstraintRef::Universal() : TConstraintRef::Empty(); + } + } + + return TConstraintRef::Universal(); +} + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK TRangeInferrer CreateNewRangeInferrer( + TConstExpressionPtr /*predicate*/, + const TTableSchemaPtr& /*schema*/, + const TKeyColumns& /*keyColumns*/, + const IColumnEvaluatorCachePtr& /*evaluatorCache*/, + const TConstConstraintExtractorMapPtr& /*constraintExtractors*/, + const TQueryOptions& /*options*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/new_range_inferrer.cpp. + YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/new_range_inferrer.h b/yt/yt/library/query/engine_api/new_range_inferrer.h new file mode 100644 index 0000000000..957da923dd --- /dev/null +++ b/yt/yt/library/query/engine_api/new_range_inferrer.h @@ -0,0 +1,37 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/query/base/functions.h> +#include <yt/yt/library/query/base/query.h> + +#include <functional> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +struct TConstraintExtractorMap + : public TRefCounted + , public std::unordered_map<TString, TConstraintExtractor> +{ }; + +DEFINE_REFCOUNTED_TYPE(TConstraintExtractorMap) + +//////////////////////////////////////////////////////////////////////////////// + +using TRangeInferrer = std::function<std::vector<TMutableRowRange>( + const TRowRange& keyRange, + const TRowBufferPtr& rowBuffer)>; + +TRangeInferrer CreateNewRangeInferrer( + TConstExpressionPtr predicate, + const TTableSchemaPtr& schema, + const TKeyColumns& keyColumns, + const IColumnEvaluatorCachePtr& evaluatorCache, + const TConstConstraintExtractorMapPtr& constraintExtractors, + const TQueryOptions& options); + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/position_independent_value-inl.h b/yt/yt/library/query/engine_api/position_independent_value-inl.h new file mode 100644 index 0000000000..643f0c039a --- /dev/null +++ b/yt/yt/library/query/engine_api/position_independent_value-inl.h @@ -0,0 +1,78 @@ +#ifndef POSITION_INDEPENDENT_VALUE_INL_H +#error "Direct inclusion of this file is not allowed, position_independent_value.h" +// For the sake of sane code completion. +#include "position_independent_value.h" +#endif + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> +T FromPositionIndependentValue(const TPIValue& positionIndependentValue) +{ + TUnversionedValue asUnversioned; + MakeUnversionedFromPositionIndependent(&asUnversioned, positionIndependentValue); + return FromUnversionedValue<T>(asUnversioned); +} + +//////////////////////////////////////////////////////////////////////////////// + +Y_FORCE_INLINE const char* GetStringPosition(const TPIValue& value) +{ + return reinterpret_cast<char*>(value.Data.StringOffset + reinterpret_cast<int64_t>(&value.Data.StringOffset)); +} + +Y_FORCE_INLINE void SetStringPosition(TPIValue* value, const char* string) +{ + value->Data.StringOffset = reinterpret_cast<int64_t>(string) - reinterpret_cast<int64_t>(&value->Data.StringOffset); +} + +//////////////////////////////////////////////////////////////////////////////// + +Y_FORCE_INLINE void MakeUnversionedFromPositionIndependent(TUnversionedValue* destination, const TPIValue& source) +{ + destination->Id = source.Id; + destination->Type = source.Type; + destination->Flags = source.Flags; + destination->Length = source.Length; + + if (IsStringLikeType(source.Type)) { + destination->Data.String = GetStringPosition(source); + } else { + destination->Data.Uint64 = source.Data.Uint64; + 
} +} + +Y_FORCE_INLINE void MakePositionIndependentFromUnversioned(TPIValue* destination, const TUnversionedValue& source) +{ + destination->Id = source.Id; + destination->Type = source.Type; + destination->Flags = source.Flags; + destination->Length = source.Length; + + if (IsStringLikeType(source.Type)) { + SetStringPosition(destination, source.Data.String); + } else { + destination->Data.Uint64 = source.Data.Uint64; + } +} + +Y_FORCE_INLINE void CopyPositionIndependent(TPIValue* destination, const TPIValue& source) +{ + destination->Id = source.Id; + destination->Type = source.Type; + destination->Flags = source.Flags; + destination->Length = source.Length; + + if (IsStringLikeType(source.Type)) { + SetStringPosition(destination, GetStringPosition(source)); + } else { + destination->Data.Uint64 = source.Data.Uint64; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + diff --git a/yt/yt/library/query/engine_api/position_independent_value.cpp b/yt/yt/library/query/engine_api/position_independent_value.cpp new file mode 100644 index 0000000000..0f5f7031ad --- /dev/null +++ b/yt/yt/library/query/engine_api/position_independent_value.cpp @@ -0,0 +1,177 @@ +#include "position_independent_value.h" + +#include "position_independent_value_transfer.h" + +#ifndef YT_COMPILING_UDF + +#include <yt/yt/client/table_client/unversioned_row.h> +#include <yt/yt/client/table_client/public.h> +#include <yt/yt/client/table_client/helpers.h> + +#include <yt/yt/core/misc/error.h> + +#include <yt/yt/core/ytree/convert.h> + +#include <library/cpp/yt/string/format.h> + +#endif + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +void TPIValue::SetStringPosition(const char* string) +{ + NQueryClient::SetStringPosition(this, string); +} + +TStringBuf TPIValue::AsStringBuf() const +{ + return TStringBuf(GetStringPosition(*this), Length); +} + 
+TFingerprint GetFarmFingerprint(const TPIValue& value) +{ + TUnversionedValue asUnversioned{}; + MakeUnversionedFromPositionIndependent(&asUnversioned, value); + return GetFarmFingerprint(asUnversioned); +} + +TFingerprint GetFarmFingerprint(const TPIValue* begin, const TPIValue* end) +{ + auto asUnversionedRange = BorrowFromPI(TPIValueRange(begin, static_cast<size_t>(end - begin))); + return GetFarmFingerprint(NTableClient::TUnversionedValueRange( + asUnversionedRange.Begin(), + asUnversionedRange.Size())); +} + +//////////////////////////////////////////////////////////////////////////////// + +void PrintTo(const TPIValue& value, ::std::ostream* os) +{ + TUnversionedValue asUnversioned{}; + MakeUnversionedFromPositionIndependent(&asUnversioned, value); + *os << ToString(asUnversioned); +} + +//////////////////////////////////////////////////////////////////////////////// + +TString ToString(const TPIValue& value, bool valueOnly) +{ + TUnversionedValue asUnversioned{}; + MakeUnversionedFromPositionIndependent(&asUnversioned, value); + return ToString(asUnversioned, valueOnly); +} + +//////////////////////////////////////////////////////////////////////////////// + +static_assert(sizeof(TUnversionedValue) == sizeof(TPIValue), "Structs must have equal size"); + +//////////////////////////////////////////////////////////////////////////////// + +void MakePositionIndependentSentinelValue(TPIValue* result, EValueType type, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedSentinelValue(type, id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentNullValue(TPIValue* result, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedNullValue(id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentInt64Value(TPIValue* result, i64 value, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedInt64Value(value, id, flags); + 
MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentUint64Value(TPIValue* result, ui64 value, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedUint64Value(value, id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentDoubleValue(TPIValue* result, double value, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedDoubleValue(value, id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentBooleanValue(TPIValue* result, bool value, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedBooleanValue(value, id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentStringLikeValue(TPIValue* result, EValueType valueType, TStringBuf value, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedStringLikeValue(valueType, value, id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentStringValue(TPIValue* result, TStringBuf value, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedStringValue(value, id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentAnyValue(TPIValue* result, TStringBuf value, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedAnyValue(value, id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentCompositeValue(TPIValue* result, TStringBuf value, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedCompositeValue(value, id, flags); + MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +void MakePositionIndependentValueHeader(TPIValue* result, EValueType type, int id, EValueFlags flags) +{ + auto asUnversioned = MakeUnversionedValueHeader(type, id, flags); + 
MakePositionIndependentFromUnversioned(result, asUnversioned); +} + +//////////////////////////////////////////////////////////////////////////////// + +int CompareRowValues(const TPIValue& lhs, const TPIValue& rhs) +{ + TUnversionedValue lhsAsUnversioned{}; + MakeUnversionedFromPositionIndependent(&lhsAsUnversioned, lhs); + + TUnversionedValue rhsAsUnversioned{}; + MakeUnversionedFromPositionIndependent(&rhsAsUnversioned, rhs); + + return CompareRowValues(lhsAsUnversioned, rhsAsUnversioned); +} + +int CompareRows(const TPIValue* lhsBegin, const TPIValue* lhsEnd, const TPIValue* rhsBegin, const TPIValue* rhsEnd) +{ + auto* lhsCurrent = lhsBegin; + auto* rhsCurrent = rhsBegin; + while (lhsCurrent != lhsEnd && rhsCurrent != rhsEnd) { + int result = CompareRowValues(*lhsCurrent++, *rhsCurrent++); + if (result != 0) { + return result; + } + } + return static_cast<int>(lhsEnd - lhsBegin) - static_cast<int>(rhsEnd - rhsBegin); +} + +//////////////////////////////////////////////////////////////////////////////// + +void ToAny(TPIValue* result, TPIValue* value, TRowBuffer* rowBuffer) +{ + auto unversionedResult = BorrowFromPI(result); + auto unversionedValue = BorrowFromPI(value); + *unversionedResult.GetValue() = EncodeUnversionedAnyValue( + *unversionedValue.GetValue(), + rowBuffer->GetPool()); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/position_independent_value.h b/yt/yt/library/query/engine_api/position_independent_value.h new file mode 100644 index 0000000000..bd05404e53 --- /dev/null +++ b/yt/yt/library/query/engine_api/position_independent_value.h @@ -0,0 +1,148 @@ +#pragma once + +#include <yt/yt/client/table_client/row_base.h> +#include <yt/yt/client/table_client/unversioned_value.h> + +#include <library/cpp/yt/farmhash/farm_hash.h> +#include <library/cpp/yt/yson_string/public.h> + +#include <util/system/defaults.h> + +namespace 
NYT::NQueryClient { + +using namespace NYT::NTableClient; + +//////////////////////////////////////////////////////////////////////////////// + +union TPositionIndependentValueData +{ + //! |Int64| value. + i64 Int64; + //! |Uint64| value. + ui64 Uint64; + //! |Double| value. + double Double; + //! |Boolean| value. + bool Boolean; + //! Offset for |String| type or YSON-encoded value for |Any| type. + //! NB: string is not zero-terminated, so never use it as a TString. + //! Use #TPositionIndependentValue::AsStringBuf() instead. + ptrdiff_t StringOffset; +}; + +static_assert( + sizeof(TPositionIndependentValueData) == 8, + "TPositionIndependentValueData has to be exactly 8 bytes."); + +struct TPositionIndependentValue + : public TNonCopyable +{ + //! Column id w.r.t. the name table. + ui16 Id; + + //! Column type. + EValueType Type; + + //! Various bit-packed flags. + EValueFlags Flags; + + //! Length of a variable-sized value (only meaningful for string-like types). + ui32 Length; + + //! Payload. + TPositionIndependentValueData Data; + + //! Assuming #IsStringLikeType(Type), return string data as a TStringBuf. + TStringBuf AsStringBuf() const; + + TPositionIndependentValue() = default; + + void SetStringPosition(const char* string); +}; + +static_assert( + sizeof(TPositionIndependentValue) == 16, + "TPositionIndependentValue has to be exactly 16 bytes."); +static_assert( + std::is_trivial_v<TPositionIndependentValue>, + "TPositionIndependentValue must be a POD type."); + +using TPIValue = TPositionIndependentValue; +using TPIValueData = TPositionIndependentValueData; +using TPIValueRange = TRange<TPIValue>; +using TMutablePIValueRange = TMutableRange<TPIValue>; +using TPIRowRange = std::pair<TPIValueRange, TPIValueRange>; + +//////////////////////////////////////////////////////////////////////////////// + +//! Computes FarmHash forever-fixed fingerprint for a given TPIValue. +TFingerprint GetFarmFingerprint(const TPIValue& value); + +//! 
Computes FarmHash forever-fixed fingerprint for a given set of values. +TFingerprint GetFarmFingerprint(const TPIValue* begin, const TPIValue* end); + +//////////////////////////////////////////////////////////////////////////////// + +//! Debug printer for Gtest unittests. +void PrintTo(const TPIValue& value, ::std::ostream* os); + +//////////////////////////////////////////////////////////////////////////////// + +void FormatValue(TStringBuilderBase* builder, const TPIValue& value, TStringBuf format); + +//////////////////////////////////////////////////////////////////////////////// + +void MakeUnversionedFromPositionIndependent(TUnversionedValue* destination, const TPIValue& source); + +void MakePositionIndependentFromUnversioned(TPIValue* destination, const TUnversionedValue& source); + +void CopyPositionIndependent(TPIValue* destination, const TPIValue& source); + +//////////////////////////////////////////////////////////////////////////////// + +void MakePositionIndependentSentinelValue(TPIValue* result, EValueType type, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentNullValue(TPIValue* result, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentInt64Value(TPIValue* result, i64 value, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentUint64Value(TPIValue* result, ui64 value, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentDoubleValue(TPIValue* result, double value, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentBooleanValue(TPIValue* result, bool value, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentStringLikeValue(TPIValue* result, EValueType valueType, TStringBuf value, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentStringValue(TPIValue* result, TStringBuf value, int id = 0, EValueFlags flags = EValueFlags::None); + 
+void MakePositionIndependentAnyValue(TPIValue* result, TStringBuf value, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentCompositeValue(TPIValue* result, TStringBuf value, int id = 0, EValueFlags flags = EValueFlags::None); + +void MakePositionIndependentValueHeader(TPIValue* result, EValueType type, int id = 0, EValueFlags flags = EValueFlags::None); + +//////////////////////////////////////////////////////////////////////////////// + +int CompareRowValues(const TPIValue& lhs, const TPIValue& rhs); + +int CompareRows(const TPIValue* lhsBegin, const TPIValue* lhsEnd, const TPIValue* rhsBegin, const TPIValue* rhsEnd); + +int CompareRowValues(const TPIValue& lhs, const TPIValue& rhs); + +//////////////////////////////////////////////////////////////////////////////// + +void ToAny(TPIValue* result, TPIValue* value, TRowBuffer* rowBuffer); + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> +T FromPositionIndependentValue(const TPIValue& positionIndependentValue); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + +#define POSITION_INDEPENDENT_VALUE_INL_H +#include "position_independent_value-inl.h" +#undef POSITION_INDEPENDENT_VALUE_INL_H diff --git a/yt/yt/library/query/engine_api/position_independent_value_transfer-inl.h b/yt/yt/library/query/engine_api/position_independent_value_transfer-inl.h new file mode 100644 index 0000000000..b997aeb7c7 --- /dev/null +++ b/yt/yt/library/query/engine_api/position_independent_value_transfer-inl.h @@ -0,0 +1,27 @@ +#ifndef POSITION_INDEPENDENT_VALUE_TRANSFER_INL_H +#error "Direct inclusion of this file is not allowed, position_independent_value_transfer.h" +// For the sake of sane code completion. 
+#include "position_independent_value_transfer.h" +#endif + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +// NB(dtorilov): in WebAssembly case this function should use compartment's memory base. +template <class TNonPI> +TBorrowingPIValueGuard<TNonPI> BorrowFromNonPI(TNonPI value) +{ + return TBorrowingPIValueGuard<TNonPI>(value); +} + +// NB(dtorilov): in WebAssembly case this function should use compartment's memory base. +template <class TPI> +TBorrowingNonPIValueGuard<TPI> BorrowFromPI(TPI value) +{ + return TBorrowingNonPIValueGuard<TPI>(value); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/position_independent_value_transfer.cpp b/yt/yt/library/query/engine_api/position_independent_value_transfer.cpp new file mode 100644 index 0000000000..3b235a4812 --- /dev/null +++ b/yt/yt/library/query/engine_api/position_independent_value_transfer.cpp @@ -0,0 +1,372 @@ +#include "position_independent_value_transfer.h" + +#include <yt/yt/client/table_client/row_buffer.h> + +#include <library/cpp/yt/memory/range.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +TMutablePIValueRange AllocatePIValueRange(TRowBuffer* buffer, int valueCount) +{ + auto* data = buffer->GetPool()->AllocateAligned(sizeof(TPIValue) * valueCount); + return TMutablePIValueRange( + reinterpret_cast<TPIValue*>(data), + static_cast<size_t>(valueCount)); +} + +void CapturePIValue(TRowBuffer* buffer, TPIValue* value) +{ + if (IsStringLikeType(value->Type)) { + char* dst = buffer->GetPool()->AllocateUnaligned(value->Length); + ::memcpy(dst, value->AsStringBuf().Data(), value->AsStringBuf().Size()); + value->SetStringPosition(dst); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +TMutablePIValueRange 
CapturePIValueRange( + TRowBuffer* buffer, + TPIValueRange values, + bool captureValues) +{ + int count = static_cast<int>(values.Size()); + + auto capturedRange = AllocatePIValueRange(buffer, values.Size()); + + for (size_t index = 0; index < values.Size(); ++index) { + CopyPositionIndependent(&capturedRange[index], values[index]); + } + + if (captureValues) { + for (int index = 0; index < count; ++index) { + CapturePIValue(buffer, &capturedRange[index]); + } + } + + return capturedRange; +} + +TMutablePIValueRange CapturePIValueRange( + TRowBuffer* buffer, + TUnversionedValueRange values, + bool captureValues) +{ + auto captured = buffer->CaptureRow(values, captureValues); + InplaceConvertToPI(captured); + return TMutablePIValueRange( + reinterpret_cast<TPIValue*>(captured.Begin()), + static_cast<size_t>(captured.GetCount())); +} + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(TRowBufferHolder) + +struct TRowBufferHolder + : public TSharedRangeHolder +{ + explicit TRowBufferHolder(TRowBufferPtr rowBuffer) + : RowBuffer(rowBuffer) + { } + + const TRowBufferPtr RowBuffer; +}; + +DEFINE_REFCOUNTED_TYPE(TRowBufferHolder) + +TRowBufferHolderPtr MakeRowBufferHolder(TRowBufferPtr rowBuffer) +{ + return New<TRowBufferHolder>(rowBuffer); +} + +//////////////////////////////////////////////////////////////////////////////// + +struct TPIValueTransferBufferTag +{ }; + +TSharedRange<TRange<TPIValue>> CopyAndConvertToPI( + const TSharedRange<TUnversionedRow>& rows, + bool captureValues) +{ + auto buffer = New<TRowBuffer>(TPIValueTransferBufferTag()); + + auto holder = TSharedRangeHolderPtr(MakeRowBufferHolder(buffer)); + if (!captureValues) { + holder = MakeCompositeSharedRangeHolder({holder, rows.GetHolder()}); + } + + auto rowRange = TSharedMutableRange<TRange<TPIValue>>( + reinterpret_cast<TRange<TPIValue>*>( + buffer->GetPool()->AllocateAligned(sizeof(TRange<TPIValue>) * rows.Size())), + rows.Size(), + 
holder); + + for (size_t rowIndex = 0; rowIndex < rows.Size(); ++rowIndex) { + auto captured = CapturePIValueRange( + buffer.Get(), + TUnversionedValueRange( + rows[rowIndex].Begin(), + rows[rowIndex].GetCount()), + captureValues); + rowRange[rowIndex] = captured; + } + + return TSharedRange<TRange<TPIValue>>( + rowRange.Begin(), + rowRange.Size(), + rowRange.GetHolder()); +} + +TSharedRange<TPIRowRange> CopyAndConvertToPI( + const TSharedRange<TRowRange>& range, + bool captureValues) +{ + auto buffer = New<TRowBuffer>(TPIValueTransferBufferTag()); + + auto holder = TSharedRangeHolderPtr(MakeRowBufferHolder(buffer)); + if (!captureValues) { + holder = MakeCompositeSharedRangeHolder({holder, range.GetHolder()}); + } + + auto mutableRange = TSharedMutableRange<TPIRowRange>( + reinterpret_cast<TPIRowRange*>( + buffer->GetPool()->AllocateAligned( + sizeof(TPIRowRange) * range.Size())), + range.Size(), + holder); + + for (size_t rowIndex = 0; rowIndex < range.Size(); ++rowIndex) { + { + auto captured = CapturePIValueRange( + buffer.Get(), + TUnversionedValueRange( + range[rowIndex].first.Begin(), + range[rowIndex].first.GetCount()), + captureValues); + + mutableRange[rowIndex].first = TRange<TPIValue>( + captured.Begin(), + captured.Size()); + } + { + auto captured = CapturePIValueRange( + buffer.Get(), + TUnversionedValueRange( + range[rowIndex].second.Begin(), + range[rowIndex].second.GetCount()), + captureValues); + + mutableRange[rowIndex].second = TRange<TPIValue>( + captured.Begin(), + captured.Size()); + } + } + + return TSharedRange<TPIRowRange>( + mutableRange.Begin(), + mutableRange.Size(), + mutableRange.GetHolder()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TMutableUnversionedRow CopyAndConvertFromPI( + TRowBuffer* buffer, + TPIValueRange values, + bool captureValues) +{ + auto capturedRow = TMutableUnversionedRow::Allocate( + buffer->GetPool(), + values.Size()); + + for (size_t index = 0; index < 
values.Size(); ++index) { + MakeUnversionedFromPositionIndependent(&capturedRow[index], values[index]); + } + + if (captureValues) { + buffer->CaptureValues(capturedRow); + } + + return capturedRow; +} + +std::vector<TUnversionedRow> CopyAndConvertFromPI( + TRowBuffer* buffer, + const std::vector<TPIValueRange>& rows, + bool captureValues) +{ + std::vector<TUnversionedRow> result; + result.reserve(rows.size()); + + for (auto& row : rows) { + result.push_back(CopyAndConvertFromPI(buffer, row, captureValues)); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +TMutablePIValueRange InplaceConvertToPI(TMutableUnversionedValueRange range) +{ + auto positionIndependent = TMutablePIValueRange( + reinterpret_cast<TPIValue*>(range.Begin()), + range.Size()); + + for (size_t index = 0; index < range.Size(); ++index) { + MakePositionIndependentFromUnversioned(&positionIndependent[index], range[index]); + } + + return positionIndependent; +} + +TMutablePIValueRange InplaceConvertToPI(const TUnversionedRow& row) +{ + return InplaceConvertToPI( + TMutableUnversionedValueRange( + const_cast<TUnversionedValue*>(row.Begin()), + static_cast<size_t>(row.GetCount()))); +} + +//////////////////////////////////////////////////////////////////////////////// + +TMutableUnversionedValueRange InplaceConvertFromPI(TMutablePIValueRange range) +{ + auto unversioned = TMutableUnversionedValueRange( + reinterpret_cast<TUnversionedValue*>(range.Begin()), + range.Size()); + + for (size_t index = 0; index < range.Size(); ++index) { + MakeUnversionedFromPositionIndependent(&unversioned[index], range[index]); + } + + return unversioned; +} + +//////////////////////////////////////////////////////////////////////////////// + +TBorrowingPIValueGuard<TUnversionedValue*>::TBorrowingPIValueGuard(TUnversionedValue* value) + : Value_(value) +{ + PIValue_ = reinterpret_cast<TPIValue*>(Value_); + MakePositionIndependentFromUnversioned(PIValue_, 
*Value_); +} + +TBorrowingPIValueGuard<TUnversionedValue*>::~TBorrowingPIValueGuard() +{ + MakeUnversionedFromPositionIndependent(Value_, *PIValue_); +} + +TPIValue* TBorrowingPIValueGuard<TUnversionedValue*>::TBorrowingPIValueGuard<TUnversionedValue*>::GetPIValue() +{ + return PIValue_; +} + +//////////////////////////////////////////////////////////////////////////////// + +TBorrowingPIValueGuard<TUnversionedValueRange>::TBorrowingPIValueGuard( + TUnversionedValueRange valueRange) +{ + if (valueRange.Empty()) { + return; + } + + ValueRange_ = TMutableUnversionedValueRange( + const_cast<TUnversionedValue*>(&valueRange.Front()), + valueRange.Size()); + + PIValueRange_ = TMutablePIValueRange( + reinterpret_cast<TPIValue*>(&ValueRange_.Front()), + ValueRange_.Size()); + + InplaceConvertToPI(ValueRange_); +} + +TBorrowingPIValueGuard<TUnversionedValueRange>::~TBorrowingPIValueGuard() +{ + InplaceConvertFromPI(PIValueRange_); +} + +TPIValue* TBorrowingPIValueGuard<TUnversionedValueRange>::Begin() +{ + if (PIValueRange_.Empty()) { + return nullptr; + } + + return &PIValueRange_.Front(); +} + +const TPIValue& TBorrowingPIValueGuard<TUnversionedValueRange>::operator[](int index) const +{ + return PIValueRange_[index]; +} + +size_t TBorrowingPIValueGuard<TUnversionedValueRange>::Size() +{ + return PIValueRange_.Size(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TBorrowingNonPIValueGuard<TPIValue*>::TBorrowingNonPIValueGuard(TPIValue* piValue) + : PIValue_(piValue) +{ + Value_ = reinterpret_cast<TUnversionedValue*>(PIValue_); + MakeUnversionedFromPositionIndependent(Value_, *PIValue_); +} + +TBorrowingNonPIValueGuard<TPIValue*>::~TBorrowingNonPIValueGuard() +{ + MakePositionIndependentFromUnversioned(PIValue_, *Value_); +} + +TUnversionedValue* TBorrowingNonPIValueGuard<TPIValue*>::GetValue() +{ + return Value_; +} + +//////////////////////////////////////////////////////////////////////////////// + 
+TBorrowingNonPIValueGuard<TPIValueRange>::TBorrowingNonPIValueGuard( + TPIValueRange valueRange) +{ + if (valueRange.Empty()) { + return; + } + + PIValueRange_ = TMutablePIValueRange( + const_cast<TPIValue*>(&valueRange.Front()), + valueRange.Size()); + + ValueRange_ = TMutableUnversionedValueRange( + reinterpret_cast<TUnversionedValue*>(&PIValueRange_.Front()), + PIValueRange_.Size()); + + InplaceConvertFromPI(PIValueRange_); +} + +TBorrowingNonPIValueGuard<TPIValueRange>::~TBorrowingNonPIValueGuard() +{ + InplaceConvertToPI(ValueRange_); +} + +TUnversionedValue* TBorrowingNonPIValueGuard<TPIValueRange>::Begin() +{ + if (ValueRange_.Empty()) { + return nullptr; + } + + return &ValueRange_.Front(); +} + +size_t TBorrowingNonPIValueGuard<TPIValueRange>::Size() +{ + return PIValueRange_.Size(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/position_independent_value_transfer.h b/yt/yt/library/query/engine_api/position_independent_value_transfer.h new file mode 100644 index 0000000000..190d345e17 --- /dev/null +++ b/yt/yt/library/query/engine_api/position_independent_value_transfer.h @@ -0,0 +1,132 @@ +#pragma once + +#include "position_independent_value.h" + +#include <yt/yt/client/table_client/unversioned_row.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +TMutablePIValueRange AllocatePIValueRange(TRowBuffer* buffer, int valueCount); + +void CapturePIValue(TRowBuffer* buffer, TPIValue* value); + +TMutablePIValueRange CapturePIValueRange( + TRowBuffer* buffer, + TPIValueRange values, + bool captureValues = true); +TMutablePIValueRange CapturePIValueRange( + TRowBuffer* buffer, + TUnversionedValueRange Values, + bool captureValues = true); + +TSharedRange<TRange<TPIValue>> CopyAndConvertToPI( + const TSharedRange<TUnversionedRow>& rows, + bool captureValues = true); 
+TSharedRange<TPIRowRange> CopyAndConvertToPI( + const TSharedRange<TRowRange>& range, + bool captureValues = true); + +TMutableUnversionedRow CopyAndConvertFromPI( + TRowBuffer* buffer, + TPIValueRange values, + bool captureValues = true); +std::vector<TUnversionedRow> CopyAndConvertFromPI( + TRowBuffer* buffer, + const std::vector<TPIValueRange>& rows, + bool captureValues = true); + +TMutablePIValueRange InplaceConvertToPI(TMutableUnversionedValueRange range); +TMutablePIValueRange InplaceConvertToPI(const TUnversionedRow& row); + +TMutableUnversionedValueRange InplaceConvertFromPI(TMutablePIValueRange range); + +//////////////////////////////////////////////////////////////////////////////// + +template <class TNonPI> +class TBorrowingPIValueGuard; + +template <class TNonPI> +TBorrowingPIValueGuard<TNonPI> BorrowFromNonPI(TNonPI value); + +template <class TNonPI> +class TBorrowingNonPIValueGuard; + +template <class TPI> +TBorrowingNonPIValueGuard<TPI> BorrowFromPI(TPI value); + +//////////////////////////////////////////////////////////////////////////////// + +template <> +class TBorrowingPIValueGuard<TUnversionedValue*> + : public TNonCopyable +{ +public: + explicit TBorrowingPIValueGuard(TUnversionedValue* value); + ~TBorrowingPIValueGuard(); + + TPIValue* GetPIValue(); + +private: + TUnversionedValue* Value_ = nullptr; + TPIValue* PIValue_ = nullptr; +}; + +template <> +class TBorrowingPIValueGuard<TUnversionedValueRange> + : public TNonCopyable +{ +public: + explicit TBorrowingPIValueGuard(TUnversionedValueRange valueRange); + ~TBorrowingPIValueGuard(); + + TPIValue* Begin(); + const TPIValue& operator[](int index) const; + size_t Size(); + +private: + TMutableUnversionedValueRange ValueRange_{}; + TMutablePIValueRange PIValueRange_{}; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template <> +class TBorrowingNonPIValueGuard<TPIValue*> + : public TNonCopyable +{ +public: + explicit 
TBorrowingNonPIValueGuard(TPIValue* piValue); + ~TBorrowingNonPIValueGuard(); + + TUnversionedValue* GetValue(); + +private: + TUnversionedValue* Value_ = nullptr; + TPIValue* PIValue_ = nullptr; +}; + +template <> +class TBorrowingNonPIValueGuard<TPIValueRange> + : public TNonCopyable +{ +public: + explicit TBorrowingNonPIValueGuard(TPIValueRange valueRange); + ~TBorrowingNonPIValueGuard(); + + TUnversionedValue* Begin(); + size_t Size(); + +private: + TMutableUnversionedValueRange ValueRange_{}; + TMutablePIValueRange PIValueRange_{}; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + +#define POSITION_INDEPENDENT_VALUE_TRANSFER_INL_H +#include "position_independent_value_transfer-inl.h" +#undef POSITION_INDEPENDENT_VALUE_TRANSFER_INL_H diff --git a/yt/yt/library/query/engine_api/public.h b/yt/yt/library/query/engine_api/public.h new file mode 100644 index 0000000000..2f7688d558 --- /dev/null +++ b/yt/yt/library/query/engine_api/public.h @@ -0,0 +1,35 @@ +#pragma once + +#include <yt/yt/library/query/base/public.h> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(IFunctionCodegen) +DECLARE_REFCOUNTED_STRUCT(IAggregateCodegen) + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(TFunctionProfilerMap) +using TConstFunctionProfilerMapPtr = TIntrusivePtr<const TFunctionProfilerMap>; + +DECLARE_REFCOUNTED_STRUCT(TAggregateProfilerMap) +using TConstAggregateProfilerMapPtr = TIntrusivePtr<const TAggregateProfilerMap>; + +DECLARE_REFCOUNTED_STRUCT(TRangeExtractorMap) +using TConstRangeExtractorMapPtr = TIntrusivePtr<const TRangeExtractorMap>; + +DECLARE_REFCOUNTED_STRUCT(TConstraintExtractorMap) +using TConstConstraintExtractorMapPtr = TIntrusivePtr<const TConstraintExtractorMap>; + 
+//////////////////////////////////////////////////////////////////////////////// + +const TConstFunctionProfilerMapPtr GetBuiltinFunctionProfilers(); +const TConstAggregateProfilerMapPtr GetBuiltinAggregateProfilers(); +const TConstRangeExtractorMapPtr GetBuiltinRangeExtractors(); +const TConstConstraintExtractorMapPtr GetBuiltinConstraintExtractors(); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/range_inferrer.cpp b/yt/yt/library/query/engine_api/range_inferrer.cpp new file mode 100644 index 0000000000..8506d743a2 --- /dev/null +++ b/yt/yt/library/query/engine_api/range_inferrer.cpp @@ -0,0 +1,229 @@ +#include "range_inferrer.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +TKeyTriePtr ExtractMultipleConstraints( + TConstExpressionPtr expr, + const TKeyColumns& keyColumns, + const TRowBufferPtr& rowBuffer, + const TConstRangeExtractorMapPtr& rangeExtractors) +{ + if (!expr) { + return TKeyTrie::Universal(); + } + + if (auto binaryOpExpr = expr->As<TBinaryOpExpression>()) { + auto opcode = binaryOpExpr->Opcode; + auto lhsExpr = binaryOpExpr->Lhs; + auto rhsExpr = binaryOpExpr->Rhs; + + if (opcode == EBinaryOp::And) { + return IntersectKeyTrie( + ExtractMultipleConstraints(lhsExpr, keyColumns, rowBuffer, rangeExtractors), + ExtractMultipleConstraints(rhsExpr, keyColumns, rowBuffer, rangeExtractors)); + } if (opcode == EBinaryOp::Or) { + return UniteKeyTrie( + ExtractMultipleConstraints(lhsExpr, keyColumns, rowBuffer, rangeExtractors), + ExtractMultipleConstraints(rhsExpr, keyColumns, rowBuffer, rangeExtractors)); + } else { + if (rhsExpr->As<TReferenceExpression>()) { + // Ensure that references are on the left. 
+ std::swap(lhsExpr, rhsExpr); + opcode = GetReversedBinaryOpcode(opcode); + } + + auto referenceExpr = lhsExpr->As<TReferenceExpression>(); + auto constantExpr = rhsExpr->As<TLiteralExpression>(); + + auto result = TKeyTrie::Universal(); + + if (referenceExpr && constantExpr) { + int keyPartIndex = ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName); + if (keyPartIndex >= 0) { + auto value = TValue(constantExpr->Value); + + result = New<TKeyTrie>(0); + + auto& bounds = result->Bounds; + switch (opcode) { + case EBinaryOp::Equal: + result->Offset = keyPartIndex; + result->Next.emplace_back(value, TKeyTrie::Universal()); + break; + case EBinaryOp::NotEqual: + result->Offset = keyPartIndex; + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Min), true); + bounds.emplace_back(value, false); + bounds.emplace_back(value, false); + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Max), true); + + break; + case EBinaryOp::Less: + result->Offset = keyPartIndex; + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Min), true); + bounds.emplace_back(value, false); + + break; + case EBinaryOp::LessOrEqual: + result->Offset = keyPartIndex; + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Min), true); + bounds.emplace_back(value, true); + + break; + case EBinaryOp::Greater: + result->Offset = keyPartIndex; + bounds.emplace_back(value, false); + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Max), true); + + break; + case EBinaryOp::GreaterOrEqual: + result->Offset = keyPartIndex; + bounds.emplace_back(value, true); + bounds.emplace_back(MakeUnversionedSentinelValue(EValueType::Max), true); + + break; + default: + break; + } + } + } + + return result; + } + } else if (auto functionExpr = expr->As<TFunctionExpression>()) { + auto found = rangeExtractors->find(functionExpr->FunctionName); + if (found == rangeExtractors->end()) { + return TKeyTrie::Universal(); + } + + auto rangeExtractor = found->second; + + 
return rangeExtractor( + functionExpr, + keyColumns, + rowBuffer); + } else if (auto inExpr = expr->As<TInExpression>()) { + int argsSize = inExpr->Arguments.size(); + + std::vector<int> keyMapping(keyColumns.size(), -1); + for (int index = 0; index < argsSize; ++index) { + auto referenceExpr = inExpr->Arguments[index]->As<TReferenceExpression>(); + if (referenceExpr) { + int keyPartIndex = ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName); + if (keyPartIndex >= 0 && keyMapping[keyPartIndex] == -1) { + keyMapping[keyPartIndex] = index; + } + } + } + + std::vector<TKeyTriePtr> keyTries; + for (int rowIndex = 0; rowIndex < std::ssize(inExpr->Values); ++rowIndex) { + auto literalTuple = inExpr->Values[rowIndex]; + + auto rowConstraint = TKeyTrie::Universal(); + for (int keyIndex = keyMapping.size() - 1; keyIndex >= 0; --keyIndex) { + auto index = keyMapping[keyIndex]; + if (index >= 0) { + auto valueConstraint = New<TKeyTrie>(keyIndex); + valueConstraint->Next.emplace_back(literalTuple[index], std::move(rowConstraint)); + rowConstraint = std::move(valueConstraint); + } + } + + keyTries.push_back(rowConstraint); + } + + return UniteKeyTrie(keyTries); + } else if (auto betweenExpr = expr->As<TBetweenExpression>()) { + int argsSize = betweenExpr->Arguments.size(); + + std::vector<int> keyMapping(keyColumns.size(), -1); + for (int index = 0; index < argsSize; ++index) { + auto referenceExpr = betweenExpr->Arguments[index]->As<TReferenceExpression>(); + if (referenceExpr) { + int keyPartIndex = ColumnNameToKeyPartIndex(keyColumns, referenceExpr->ColumnName); + if (keyPartIndex >= 0 && keyMapping[keyPartIndex] == -1) { + keyMapping[keyPartIndex] = index; + } + } + } + + std::vector<TKeyTriePtr> keyTries; + for (int rowIndex = 0; rowIndex < std::ssize(betweenExpr->Ranges); ++rowIndex) { + auto literalRange = betweenExpr->Ranges[rowIndex]; + + auto lower = literalRange.first; + auto upper = literalRange.second; + + size_t prefix = 0; + while (prefix < 
lower.GetCount() && prefix < upper.GetCount() && lower[prefix] == upper[prefix]) { + ++prefix; + } + + int rangeColumnIndex = -1; + auto rowConstraint = TKeyTrie::Universal(); + for (int keyIndex = keyMapping.size() - 1; keyIndex >= 0; --keyIndex) { + auto index = keyMapping[keyIndex]; + if (index >= 0 && index < static_cast<int>(prefix)) { + auto valueConstraint = New<TKeyTrie>(keyIndex); + valueConstraint->Next.emplace_back(lower[index], std::move(rowConstraint)); + rowConstraint = std::move(valueConstraint); + } + + if (index == static_cast<int>(prefix)) { + rangeColumnIndex = keyIndex; + } + } + + if (rangeColumnIndex != -1) { + auto rangeConstraint = New<TKeyTrie>(rangeColumnIndex); + auto& bounds = rangeConstraint->Bounds; + + bounds.emplace_back( + lower.GetCount() > prefix + ? lower[prefix] + : MakeUnversionedSentinelValue(EValueType::Min), + true); + + bounds.emplace_back( + upper.GetCount() > prefix + ? upper[prefix] + : MakeUnversionedSentinelValue(EValueType::Max), + true); + + rowConstraint = IntersectKeyTrie(rowConstraint, rangeConstraint); + } + + keyTries.push_back(rowConstraint); + } + + return UniteKeyTrie(keyTries); + } else if (auto literalExpr = expr->As<TLiteralExpression>()) { + TValue value = literalExpr->Value; + if (value.Type == EValueType::Boolean) { + return value.Data.Boolean ? TKeyTrie::Universal() : TKeyTrie::Empty(); + } + } + + return TKeyTrie::Universal(); +} + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK TRangeInferrer CreateRangeInferrer( + TConstExpressionPtr /*predicate*/, + const TTableSchemaPtr& /*schema*/, + const TKeyColumns& /*keyColumns*/, + const IColumnEvaluatorCachePtr& /*evaluatorCache*/, + const TConstRangeExtractorMapPtr& /*rangeExtractors*/, + const TQueryOptions& /*options*/) +{ + // Proper implementation resides in yt/yt/library/query/engine/range_inferrer.cpp. 
+ YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/range_inferrer.h b/yt/yt/library/query/engine_api/range_inferrer.h new file mode 100644 index 0000000000..3617f50d56 --- /dev/null +++ b/yt/yt/library/query/engine_api/range_inferrer.h @@ -0,0 +1,47 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/query/base/functions.h> +#include <yt/yt/library/query/base/key_trie.h> +#include <yt/yt/library/query/base/query.h> + +#include <functional> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +struct TRangeExtractorMap + : public TRefCounted + , public std::unordered_map<TString, TRangeExtractor> +{ }; + +DEFINE_REFCOUNTED_TYPE(TRangeExtractorMap) + +//////////////////////////////////////////////////////////////////////////////// + +//! Descends down to conjuncts and disjuncts and extracts all constraints. 
+TKeyTriePtr ExtractMultipleConstraints( + TConstExpressionPtr expr, + const TKeyColumns& keyColumns, + const TRowBufferPtr& rowBuffer, + const TConstRangeExtractorMapPtr& rangeExtractors = GetBuiltinRangeExtractors()); + +//////////////////////////////////////////////////////////////////////////////// + +using TRangeInferrer = std::function<std::vector<TMutableRowRange>( + const TRowRange& keyRange, + const TRowBufferPtr& rowBuffer)>; + +TRangeInferrer CreateRangeInferrer( + TConstExpressionPtr predicate, + const TTableSchemaPtr& schema, + const TKeyColumns& keyColumns, + const IColumnEvaluatorCachePtr& evaluatorCache, + const TConstRangeExtractorMapPtr& rangeExtractors, + const TQueryOptions& options); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/engine_api/ya.make b/yt/yt/library/query/engine_api/ya.make new file mode 100644 index 0000000000..303069037e --- /dev/null +++ b/yt/yt/library/query/engine_api/ya.make @@ -0,0 +1,35 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +PROTO_NAMESPACE(yt) + +SRCS( + append_function_implementation.cpp + column_evaluator.cpp + config.cpp + coordinator.cpp + evaluation_helpers.cpp + evaluator.cpp + builtin_function_profiler.cpp + range_inferrer.cpp + new_range_inferrer.cpp + position_independent_value.cpp + position_independent_value_transfer.cpp +) + +ADDINCL( + contrib/libs/sparsehash/src +) + +PEERDIR( + yt/yt/core + yt/yt/library/query/misc + yt/yt/library/query/proto + yt/yt/library/query/base + yt/yt/client + library/cpp/yt/memory + contrib/libs/sparsehash +) + +END() diff --git a/yt/yt/library/query/misc/function_context-inl.h b/yt/yt/library/query/misc/function_context-inl.h new file mode 100644 index 0000000000..c6567b2766 --- /dev/null +++ b/yt/yt/library/query/misc/function_context-inl.h @@ -0,0 +1,25 @@ +#ifndef FUNCTION_CONTEXT_INL_H_ +#error "Direct inclusion of this file is not allowed, 
include function_context.h" +// For the sake of sane code completion. +#include "function_context.h" +#endif + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +template <class T, class... Args> +T* TFunctionContext::CreateObject(Args&&... args) +{ + auto pointer = new T(std::forward<Args>(args)...); + auto deleter = [] (void* ptr) { + static_assert(sizeof(T) > 0, "Cannot delete incomplete type."); + delete static_cast<T*>(ptr); + }; + + return static_cast<T*>(CreateUntypedObject(pointer, deleter)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/misc/function_context.cpp b/yt/yt/library/query/misc/function_context.cpp new file mode 100644 index 0000000000..14a42a81b0 --- /dev/null +++ b/yt/yt/library/query/misc/function_context.cpp @@ -0,0 +1,74 @@ +#include "function_context.h" + +#include "objects_holder.h" + +#include <library/cpp/yt/assert/assert.h> + +#include <vector> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +class TFunctionContext::TImpl + : public TObjectsHolder +{ +public: + explicit TImpl(std::unique_ptr<bool[]> literalArgs) + : LiteralArgs_(std::move(literalArgs)) + { } + + void* GetPrivateData() const + { + return PrivateData_; + } + + void SetPrivateData(void* privateData) + { + PrivateData_ = privateData; + } + + bool IsLiteralArg(int argIndex) const + { + YT_ASSERT(argIndex >= 0); + return LiteralArgs_[argIndex]; + } + +private: + const std::unique_ptr<bool[]> LiteralArgs_; + + void* PrivateData_ = nullptr; +}; + +//////////////////////////////////////////////////////////////////////////////// + +TFunctionContext::TFunctionContext(std::unique_ptr<bool[]> literalArgs) + : Impl_(std::make_unique<TImpl>(std::move(literalArgs))) +{ } + +TFunctionContext::~TFunctionContext() = default; + +void* 
TFunctionContext::CreateUntypedObject(void* pointer, void(*deleter)(void*)) +{ + return Impl_->Register(pointer, deleter); +} + +void* TFunctionContext::GetPrivateData() const +{ + return Impl_->GetPrivateData(); +} + +void TFunctionContext::SetPrivateData(void* data) +{ + Impl_->SetPrivateData(data); +} + +bool TFunctionContext::IsLiteralArg(int argIndex) const +{ + return Impl_->IsLiteralArg(argIndex); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + diff --git a/yt/yt/library/query/misc/function_context.h b/yt/yt/library/query/misc/function_context.h new file mode 100644 index 0000000000..4af857dd0f --- /dev/null +++ b/yt/yt/library/query/misc/function_context.h @@ -0,0 +1,44 @@ +#pragma once + +#include <memory> + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +// XXX(babenko): "struct" is due to ABI bug. +struct TFunctionContext +{ +public: + explicit TFunctionContext(std::unique_ptr<bool[]> literalArgs); + ~TFunctionContext(); + + //! Creates typed function-local object. + //! Function-local objects are destroyed automatically when the function context is destroyed. + //! In case of any error, nullptr is returned. + template <class T, class... Args> + T* CreateObject(Args&&... args); + + //! Creates untyped function-local object. + //! Function-local objects are destroyed automatically when the function context is destroyed. + //! In case of any error, nullptr is returned. 
+ void* CreateUntypedObject(void* pointer, void(*deleter)(void*)); + + void* GetPrivateData() const; + void SetPrivateData(void* data); + + bool IsLiteralArg(int argIndex) const; + +private: + class TImpl; + const std::unique_ptr<TImpl> Impl_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient + +#define FUNCTION_CONTEXT_INL_H_ +#include "function_context-inl.h" +#undef FUNCTION_CONTEXT_INL_H_ + diff --git a/yt/yt/library/query/misc/objects_holder.cpp b/yt/yt/library/query/misc/objects_holder.cpp new file mode 100644 index 0000000000..afb78f0de7 --- /dev/null +++ b/yt/yt/library/query/misc/objects_holder.cpp @@ -0,0 +1,31 @@ +#include "objects_holder.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +void* TObjectsHolder::Register(void* pointer, void(*deleter)(void*)) +{ + try { + auto holder = std::unique_ptr<void, void(*)(void*)>(pointer, deleter); + Objects_.push_back(std::move(holder)); + return pointer; + } catch (...) 
{ + return nullptr; + } +} + +void TObjectsHolder::Clear() +{ + Objects_.clear(); +} + +void TObjectsHolder::Merge(TObjectsHolder&& other) +{ + std::move(other.Objects_.begin(), other.Objects_.end(), std::back_inserter(Objects_)); + other.Objects_.clear(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/query/misc/objects_holder.h b/yt/yt/library/query/misc/objects_holder.h new file mode 100644 index 0000000000..b43d2e4375 --- /dev/null +++ b/yt/yt/library/query/misc/objects_holder.h @@ -0,0 +1,46 @@ +#pragma once + +#include <memory> +#include <vector> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TObjectsHolder +{ +public: + void* Register(void* pointer, void(*deleter)(void*)); + + template <class T> + T* Register(T* pointer) + { + auto deleter = [] (void* ptr) { + static_assert(sizeof(T) > 0, "Cannot delete incomplete type."); + delete static_cast<T*>(ptr); + }; + + return static_cast<T*>(Register(pointer, deleter)); + } + + template <class T, class... TArgs> + T* New(TArgs&&... 
args) + { + return TObjectsHolder::Register(new T(std::forward<TArgs>(args)...)); + } + + void Clear(); + + void Merge(TObjectsHolder&& other); + + TObjectsHolder() = default; + TObjectsHolder(TObjectsHolder&&) = default; + TObjectsHolder(const TObjectsHolder&) = delete; + +private: + std::vector<std::unique_ptr<void, void(*)(void*)>> Objects_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/query/misc/ya.make b/yt/yt/library/query/misc/ya.make new file mode 100644 index 0000000000..535156bd79 --- /dev/null +++ b/yt/yt/library/query/misc/ya.make @@ -0,0 +1,16 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +PROTO_NAMESPACE(yt) + +SRCS( + objects_holder.cpp + function_context.cpp +) + +PEERDIR( + library/cpp/yt/assert +) + +END() diff --git a/yt/yt/library/query/proto/functions_cache.proto b/yt/yt/library/query/proto/functions_cache.proto new file mode 100644 index 0000000000..67b7c6756e --- /dev/null +++ b/yt/yt/library/query/proto/functions_cache.proto @@ -0,0 +1,21 @@ +package NYT.NQueryClient.NProto; + +import "yt_proto/yt/client/chunk_client/proto/chunk_spec.proto"; + +//////////////////////////////////////////////////////////////////////////////// + +message TExternalFunctionImpl +{ + required bool is_aggregate = 1; + required string name = 2; + required string symbol_name = 3; + required int32 calling_convention = 5; + repeated NChunkClient.NProto.TChunkSpec chunk_specs = 6; + + required string repeated_arg_type = 7; + required int32 repeated_arg_index = 8; + optional bool use_function_context = 9 [default = false]; +}; + +//////////////////////////////////////////////////////////////////////////////// + diff --git a/yt/yt/library/query/proto/query.proto b/yt/yt/library/query/proto/query.proto new file mode 100644 index 0000000000..6a22eccb82 --- /dev/null +++ b/yt/yt/library/query/proto/query.proto @@ -0,0 +1,241 @@ +package NYT.NQueryClient.NProto; + 
+import "yt_proto/yt/core/misc/proto/guid.proto"; +import "yt_proto/yt/client/misc/proto/workload.proto"; +import "yt_proto/yt/client/table_chunk_format/proto/chunk_meta.proto"; + +//////////////////////////////////////////////////////////////////////////////// + +message TColumnDescriptor +{ + required string name = 1; + required uint32 index = 2; +} + +message TExpression +{ + required int32 kind = 1; + optional uint32 type = 2; + optional NTableClient.NProto.TLogicalType logical_type = 5; + + // required int32 location_begin = 3; (deprecated) + // required int32 location_end = 4; (deprecated) + + extensions 100 to max; +} + +message TLiteralExpression +{ + extend TExpression + { + optional TLiteralExpression literal_expression = 103; + } + + optional int64 int64_value = 1; + optional uint64 uint64_value = 2; + optional double double_value = 3; + optional bytes string_value = 4; + optional bool boolean_value = 5; +} + +message TReferenceExpression +{ + extend TExpression + { + optional TReferenceExpression reference_expression = 104; + } + required string column_name = 1; +} + +message TFunctionExpression +{ + extend TExpression + { + optional TFunctionExpression function_expression = 105; + } + required string function_name = 1; + repeated TExpression arguments = 2; + +} + +message TUnaryOpExpression +{ + extend TExpression + { + optional TUnaryOpExpression unary_op_expression = 106; + } + required int32 opcode = 1; // EUnaryOp + required TExpression operand = 2; +} + +message TBinaryOpExpression +{ + extend TExpression + { + optional TBinaryOpExpression binary_op_expression = 107; + } + required int32 opcode = 1; // EBinaryOp + required TExpression lhs = 2; + required TExpression rhs = 3; +} + +message TInExpression +{ + extend TExpression + { + optional TInExpression in_expression = 108; + } + repeated TExpression arguments = 1; + required bytes values = 2; +} + +message TBetweenExpression +{ + extend TExpression + { + optional TBetweenExpression 
between_expression = 110; + } + repeated TExpression arguments = 1; + required bytes ranges = 2; +} + +message TTransformExpression +{ + extend TExpression + { + optional TTransformExpression transform_expression = 109; + } + repeated TExpression arguments = 1; + required bytes values = 2; + optional TExpression default_expression = 3; +} + +message TNamedItem +{ + required TExpression expression = 1; + required string name = 2; +} + +message TAggregateItem +{ + required TExpression expression = 1; + required string name = 2; + required string aggregate_function_name = 3; + optional uint32 state_type = 4; + optional uint32 result_type = 5; + // COMPAT(sabdenov): Legacy clients may omit this field. + repeated TExpression arguments = 6; +} + +message TSelfEquation +{ + required TExpression expression = 1; + required bool evaluated = 2; +} + +message TJoinClause +{ + repeated TExpression foreign_equations = 1; + repeated TSelfEquation self_equations = 2; + required bool can_use_source_ranges = 3; + + required NYT.NTableClient.NProto.TTableSchemaExt original_schema = 5; + repeated TColumnDescriptor schema_mapping = 6; + repeated string self_joined_columns = 7; + repeated string foreign_joined_columns = 8; + + required NYT.NProto.TGuid foreign_object_id = 9; + // COMPAT(babenko): legacy clients may omit this field. 
+ optional NYT.NProto.TGuid foreign_cell_id = 15; + + required bool is_left = 10; + + optional uint64 common_key_prefix = 11 [default = 0]; + + optional TExpression predicate = 12; + optional uint64 foreign_key_prefix = 13 [default = 0]; + optional uint64 common_key_prefix_new = 14 [default = 0]; +} + +message TGroupClause +{ + repeated TNamedItem group_items = 1; + repeated TAggregateItem aggregate_items = 2; + required uint32 totals_mode = 6; + optional uint32 common_prefix_with_primary_key = 7 [default = 0]; +} + +message TOrderItem +{ + required TExpression expression = 1; + required bool descending = 2; +} + +message TOrderClause +{ + repeated TOrderItem order_items = 1; +} + +message TProjectClause +{ + repeated TNamedItem projections = 1; +} + +message TQuery +{ + required NYT.NProto.TGuid id = 1; + required int64 input_row_limit = 2; + required int64 output_row_limit = 3; + + optional int64 offset = 16 [default = 0]; + required int64 limit = 4; + + required NYT.NTableClient.NProto.TTableSchemaExt original_schema = 5; + repeated TColumnDescriptor schema_mapping = 6; + + repeated TJoinClause join_clauses = 7; + optional TExpression where_clause = 8; + optional TGroupClause group_clause = 9; + optional TExpression having_clause = 12; + optional TOrderClause order_clause = 11; + optional TProjectClause project_clause = 10; + + optional bool use_disjoint_group_by = 13 [default = false]; + optional bool infer_ranges = 14 [default = true]; + required bool is_final = 15; +} + +message TQueryOptions +{ + required uint64 timestamp = 1; + optional uint64 retention_timestamp = 14 [default = 0]; + required bool verbose_logging = 2; + required int64 max_subqueries = 3; + required bool enable_code_cache = 4; + optional NYT.NProto.TWorkloadDescriptor workload_descriptor = 5; + reserved 6; // deprecated + optional bool allow_full_scan = 7 [default = true]; + optional NYT.NProto.TGuid read_session_id = 8; + optional uint64 deadline = 9; + optional uint64 
memory_limit_per_node = 10; + optional string execution_pool = 11; + optional bool suppress_access_tracking = 12; + optional uint64 range_expansion_limit = 13; + optional bool new_range_inference = 15; +} + +message TDataSource +{ + required NYT.NProto.TGuid object_id = 1; + // COMPAT(babenko): legacy clients may omit this field. + optional NYT.NProto.TGuid cell_id = 7; + + required uint64 mount_revision = 2; + required bytes ranges = 3; + optional bool lookup_supported = 4 [default = true]; + + optional bytes keys = 5; + optional uint64 key_width = 6 [default = 0]; +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/library/query/proto/query_service.proto b/yt/yt/library/query/proto/query_service.proto new file mode 100644 index 0000000000..963b59cca6 --- /dev/null +++ b/yt/yt/library/query/proto/query_service.proto @@ -0,0 +1,274 @@ +package NYT.NQueryClient.NProto; +import "yt/library/query/proto/query.proto"; +import "yt/library/query/proto/functions_cache.proto"; +import "yt_proto/yt/client/misc/proto/workload.proto"; +import "yt_proto/yt/client/node_tracker_client/proto/node_directory.proto"; +import "yt_proto/yt/client/table_chunk_format/proto/wire_protocol.proto"; +import "yt_proto/yt/client/query_client/proto/query_statistics.proto"; +import "yt_proto/yt/client/chaos_client/proto/replication_card.proto"; +import "yt_proto/yt/client/chunk_client/proto/read_limit.proto"; +import "yt_proto/yt/client/chunk_client/proto/chunk_spec.proto"; +import "yt_proto/yt/core/rpc/proto/rpc.proto"; +import "yt_proto/yt/core/misc/proto/guid.proto"; +import "yt_proto/yt/core/misc/proto/error.proto"; + +//////////////////////////////////////////////////////////////////////////////// + +message TReqExecuteExt +{ + extend NRpc.NProto.TRequestHeader + { + optional TReqExecuteExt req_execute_ext = 200; + } + + optional string execution_pool = 1 [default = "default"]; + optional string execution_tag = 2 [default = "default"]; +} 
+ +message TReqExecute +{ + required TQuery query = 1; + repeated TExternalFunctionImpl external_functions = 5; + required NYT.NNodeTrackerClient.NProto.TNodeDirectory node_directory = 6; + required TQueryOptions options = 2; + repeated TDataSource data_sources = 3; + // TODO(kiselyovp) move codec to header + required int32 response_codec = 4; // ECodec + + reserved 7; +} + +message TRspExecute +{ + required TQueryStatistics query_statistics = 1; + // Attachments contain wire-encoded data. +} + +//////////////////////////////////////////////////////////////////////////////// + +message TReqMultireadExt +{ + extend NRpc.NProto.TRequestHeader + { + optional TReqMultireadExt req_multiread_ext = 201; + } + + optional int32 in_memory_mode = 1 [default = 0]; // EInMemoryMode, EInMemoryMode::None by default +} + +message TReqMultiread +{ + // TODO(kiselyovp) move codecs to header + required int32 request_codec = 1; // ECodec + required int32 response_codec = 2; // ECodec + optional bytes retention_config = 3; + required uint64 timestamp = 4; + optional uint64 retention_timestamp = 10; + repeated NYT.NProto.TGuid tablet_ids = 5; + repeated NYT.NProto.TGuid cell_ids = 9; + repeated int64 mount_revisions = 6; + optional bool enable_partial_result = 7 [default=false]; + optional bool use_lookup_cache = 8 [default=false]; + + // Attachment contains wire-encoded data. +} + +message TRspMultiread +{ + // Attachment contains wire-encoded data. 
+} + +//////////////////////////////////////////////////////////////////////////////// + +message TReqReadDynamicStore +{ + required NYT.NProto.TGuid store_id = 2; + required NYT.NProto.TGuid tablet_id = 1; + required NYT.NProto.TGuid cell_id = 12; + optional uint64 timestamp = 3; + optional bytes lower_bound = 4; + optional bytes upper_bound = 5; + required NYT.NProto.TGuid read_session_id = 6; + optional NYT.NTableClient.NProto.TColumnFilter column_filter = 7; + optional int64 start_row_index = 8; + optional int64 end_row_index = 9; + optional int64 max_rows_per_read = 10; + + // Fail each attachment request with certain probability. + // Used for testing only. + optional float failure_probability = 11; +} + +message TRspReadDynamicStore +{ + // Data is transferred via RPC streaming. +} + +//////////////////////////////////////////////////////////////////////////////// + +message TReqPullRows +{ + required NYT.NProto.TGuid upstream_replica_id = 1; // NChaosClient::TReplicaId + + required int32 request_codec = 2; // ECodec + required int32 response_codec = 3; // ECodec + required uint64 mount_revision = 4; + required int64 max_rows_per_read = 5; + + required NYT.NProto.TGuid tablet_id = 6; + required NYT.NProto.TGuid cell_id = 7; + required NChaosClient.NProto.TReplicationProgress start_replication_progress = 8; + required uint64 upper_timestamp = 9; + optional int64 start_replication_row_index = 10; +} + +message TRspPullRows +{ + required int64 row_count = 1; + required int64 data_weight = 2; + optional int64 end_replication_row_index = 3; + required NChaosClient.NProto.TReplicationProgress end_replication_progress = 4; + + // Attachment contains wire-encoded data. 
+} + +//////////////////////////////////////////////////////////////////////////////// + +message TReqGetTabletInfo +{ + repeated NYT.NProto.TGuid tablet_ids = 1; + repeated NYT.NProto.TGuid cell_ids = 2; + optional bool request_errors = 3; +} + +message TReplicaInfo +{ + required NYT.NProto.TGuid replica_id = 1; + required uint64 last_replication_timestamp = 2; + required int32 mode = 3; // ETableReplicaMode + required int64 current_replication_row_index = 4; + required int64 committed_replication_row_index = 7; + optional NYT.NProto.TError replication_error = 5; + optional int32 status = 6; // ETableReplicaStatus +} + +message TTabletInfo +{ + required NYT.NProto.TGuid tablet_id = 1; + repeated TReplicaInfo replicas = 2; + required int64 total_row_count = 3; + required int64 trimmed_row_count = 4; + optional int64 delayed_lockless_row_count = 8; + optional uint64 barrier_timestamp = 5; + optional uint64 last_write_timestamp = 6; + repeated NYT.NProto.TError tablet_errors = 7; +} + +message TRspGetTabletInfo +{ + repeated TTabletInfo tablets = 2; +} + +//////////////////////////////////////////////////////////////////////////////// + +message TReqFetchTabletStores +{ + message TSubrequest + { + required NYT.NProto.TGuid tablet_id = 1; + required NYT.NProto.TGuid cell_id = 9; + optional uint64 mount_revision = 2; + repeated NYT.NChunkClient.NProto.TReadRange ranges = 3; + repeated int32 range_indices = 4; + required int32 table_index = 5; + optional int32 tablet_index = 6; + optional bool fetch_samples = 7 [default = false]; + // Compressed data size. + optional int32 data_size_between_samples = 8; + } + + repeated TSubrequest subrequests = 1; + + optional bool fetch_all_meta_extensions = 2 [default = false]; + repeated int32 extension_tags = 3; + + // Do not send dynamic stores even if @enable_dynamic_store_read is set. 
+ optional bool omit_dynamic_stores = 4 [default = false]; +} + +message TRspFetchTabletStores +{ + message TSubresponse + { + repeated NChunkClient.NProto.TChunkSpec stores = 1; + // COMPAT(babenko): drop this later. + optional bool tablet_missing = 2 [default = false]; + optional NYT.NProto.TError error = 3; + } + + // |subresponses| contain exactly one subresponse for each subrequest + // from |subrequests| in the same order. + repeated TSubresponse subresponses = 1; + + // Attachments contain wire-encoded samples keys. For each subrequest with + // |fetch_samples| = true there are len(|ranges|) attachments with samples. +} + +//////////////////////////////////////////////////////////////////////////////// + +message TReqFetchTableRows +{ + message TOptions + { + optional NYT.NProto.TWorkloadDescriptor workload_descriptor = 1; + } + + required NYT.NProto.TGuid tablet_id = 1; + required NYT.NProto.TGuid cell_id = 2; + optional uint64 mount_revision = 3; + + // The fields below *must* be set. + // They are marked optional for potential future extensions. + optional int32 tablet_index = 4; + optional int64 row_index = 5; + optional int64 max_row_count = 6; + optional int64 max_data_weight = 7; + + optional TOptions options = 8; +} + +message TRspFetchTableRows +{ + // Attachment contains wire-encoded rows. +} + +//////////////////////////////////////////////////////////////////////////////// + +message TReqGetOrderedTabletSafeTrimRowCount +{ + message TSubrequest + { + required NYT.NProto.TGuid tablet_id = 1; + required NYT.NProto.TGuid cell_id = 2; + optional uint64 mount_revision = 3; + + required uint64 timestamp = 4; + } + + repeated TSubrequest subrequests = 1; +} + +message TRspGetOrderedTabletSafeTrimRowCount +{ + message TSubresponse + { + optional int64 safe_trim_row_count = 1; + + optional NYT.NProto.TError error = 2; + } + + // |subresponses| contain exactly one subresponse for each subrequest from |subrequests| in the same order. 
+ repeated TSubresponse subresponses = 1; +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/library/query/proto/ya.make b/yt/yt/library/query/proto/ya.make new file mode 100644 index 0000000000..eae1d1a26a --- /dev/null +++ b/yt/yt/library/query/proto/ya.make @@ -0,0 +1,17 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +PROTO_NAMESPACE(yt) + +SRCS( + query.proto + query_service.proto + functions_cache.proto +) + +PEERDIR( + yt/yt/client +) + +END() diff --git a/yt/yt/library/query/row_comparer_api/row_comparer_generator.cpp b/yt/yt/library/query/row_comparer_api/row_comparer_generator.cpp new file mode 100644 index 0000000000..9a62e34c99 --- /dev/null +++ b/yt/yt/library/query/row_comparer_api/row_comparer_generator.cpp @@ -0,0 +1,21 @@ +#include "row_comparer_generator.h" + +namespace NYT::NQueryClient { + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK TCGKeyComparers GenerateComparers(TRange<EValueType> /*keyColumnTypes*/) +{ + // Proper implementation resides in yt/yt/library/query/row_comparer/row_comparer_generator.cpp. + YT_ABORT(); +} + +Y_WEAK IRowComparerProviderPtr CreateRowComparerProvider(TSlruCacheConfigPtr /*config*/) +{ + // Proper implementation resides in yt/yt/library/query/row_comparer/row_comparer_generator.cpp. 
+ YT_ABORT(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/row_comparer_api/row_comparer_generator.h b/yt/yt/library/query/row_comparer_api/row_comparer_generator.h new file mode 100644 index 0000000000..bf718e27b6 --- /dev/null +++ b/yt/yt/library/query/row_comparer_api/row_comparer_generator.h @@ -0,0 +1,50 @@ +#pragma once + +#include <yt/yt/core/actions/callback.h> + +#include <yt/yt/client/table_client/schema.h> +#include <yt/yt/client/table_client/unversioned_row.h> + +#include <yt/yt/client/tablet_client/dynamic_value.h> + +//////////////////////////////////////////////////////////////////////////////// + +namespace NYT::NQueryClient { + +using NTableClient::EValueType; +using NTableClient::TUnversionedValue; +using NTableClient::TUnversionedRow; +using NTabletClient::TDynamicValueData; + +//////////////////////////////////////////////////////////////////////////////// + +using TDDComparerSignature = int(ui32, const TDynamicValueData*, ui32, const TDynamicValueData*); +using TDUComparerSignature = int(ui32, const TDynamicValueData*, const TUnversionedValue*, int); +using TUUComparerSignature = int(const TUnversionedValue*, const TUnversionedValue*, i32); + +struct TCGKeyComparers +{ + TCallback<TDDComparerSignature> DDComparer; + TCallback<TDUComparerSignature> DUComparer; + TCallback<TUUComparerSignature> UUComparer; +}; + +//////////////////////////////////////////////////////////////////////////////// + +TCGKeyComparers GenerateComparers(TRange<EValueType> keyColumnTypes); + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(IRowComparerProvider) + +struct IRowComparerProvider + : public virtual TRefCounted +{ + virtual TCGKeyComparers Get(NTableClient::TKeyColumnTypes keyColumnTypes) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(IRowComparerProvider) + +IRowComparerProviderPtr 
CreateRowComparerProvider(TSlruCacheConfigPtr config); + +} // namespace NYT::NQueryClient diff --git a/yt/yt/library/query/row_comparer_api/ya.make b/yt/yt/library/query/row_comparer_api/ya.make new file mode 100644 index 0000000000..7ab7dcfd3b --- /dev/null +++ b/yt/yt/library/query/row_comparer_api/ya.make @@ -0,0 +1,14 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + row_comparer_generator.cpp +) + +PEERDIR( + yt/yt/core + yt/yt/client +) + +END() diff --git a/yt/yt/library/random/bernoulli_sampler.cpp b/yt/yt/library/random/bernoulli_sampler.cpp new file mode 100644 index 0000000000..63dd9c8d2f --- /dev/null +++ b/yt/yt/library/random/bernoulli_sampler.cpp @@ -0,0 +1,71 @@ +#include "bernoulli_sampler.h" + +#include <yt/yt/core/misc/serialize.h> + +#include <library/cpp/yt/farmhash/farm_hash.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +TBernoulliSampler::TBernoulliSampler( + std::optional<double> samplingRate, + std::optional<ui64> seed) +{ + if (samplingRate) { + SamplingRate_ = samplingRate; + Seed_ = seed; + Distribution_ = std::bernoulli_distribution(*SamplingRate_); + if (seed) { + Generator_ = std::mt19937(*seed); + } + } +} + +bool TBernoulliSampler::Sample() +{ + if (!SamplingRate_) { + return true; + } + + return Distribution_(Generator_); +} + +bool TBernoulliSampler::Sample(ui64 salt) +{ + if (!SamplingRate_) { + return true; + } + + std::minstd_rand0 generator(FarmFingerprint(salt ^ Seed_.value_or(0))); + return Distribution_(generator); +} + +void TBernoulliSampler::Persist(const TStreamPersistenceContext& context) +{ + using NYT::Persist; + + Persist(context, SamplingRate_); + Persist(context, Seed_); + // TODO(max42): Understand which type properties should make this possible + // and fix TPodSerializer instead of doing this utter garbage. 
+ #define SERIALIZE_AS_POD(field) do { \ + std::vector<char> bytes; \ + if (context.IsLoad()) { \ + Persist(context, bytes); \ + YT_VERIFY(sizeof(field) == bytes.size()); \ + memcpy(&field, bytes.data(), sizeof(field)); \ + } else { \ + bytes.resize(sizeof(field)); \ + memcpy(bytes.data(), &field, sizeof(field)); \ + Persist(context, bytes); \ + } \ + } while (0) + SERIALIZE_AS_POD(Generator_); + SERIALIZE_AS_POD(Distribution_); + #undef SERIALIZE_AS_POD +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/random/bernoulli_sampler.h b/yt/yt/library/random/bernoulli_sampler.h new file mode 100644 index 0000000000..f256aa20ff --- /dev/null +++ b/yt/yt/library/random/bernoulli_sampler.h @@ -0,0 +1,37 @@ +#pragma once + +#include <yt/yt/core/misc/public.h> + +#include <random> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +//! A simple helper that is used in sampling routines. +//! It is deterministic and persistable (as POD). +class TBernoulliSampler +{ +public: + explicit TBernoulliSampler( + std::optional<double> samplingRate = std::nullopt, + std::optional<ui64> seed = std::nullopt); + + bool Sample(); + + //! Result of this sampling depends on `salt' and `seed' only + //! and does not depend on previous sample calls. 
+ bool Sample(ui64 salt); + + void Persist(const TStreamPersistenceContext& context); + +private: + std::optional<double> SamplingRate_; + std::optional<ui64> Seed_; + std::mt19937 Generator_; + std::bernoulli_distribution Distribution_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/random/ya.make b/yt/yt/library/random/ya.make new file mode 100644 index 0000000000..1e4b3ffff0 --- /dev/null +++ b/yt/yt/library/random/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + bernoulli_sampler.cpp +) + +PEERDIR( + yt/yt/core +) + +END() diff --git a/yt/yt/library/tvm/service/config.cpp b/yt/yt/library/tvm/service/config.cpp new file mode 100644 index 0000000000..f1671e8650 --- /dev/null +++ b/yt/yt/library/tvm/service/config.cpp @@ -0,0 +1,63 @@ +#include "config.h" + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +void TTvmServiceConfig::Register(TRegistrar registrar) +{ + registrar.Parameter("use_tvm_tool", &TThis::UseTvmTool) + .Default(false); + registrar.Parameter("client_self_id", &TThis::ClientSelfId) + .Default(0); + registrar.Parameter("client_disk_cache_dir", &TThis::ClientDiskCacheDir) + .Optional(); + registrar.Parameter("tvm_host", &TThis::TvmHost) + .Optional(); + registrar.Parameter("tvm_port", &TThis::TvmPort) + .Optional(); + registrar.Parameter("client_enable_user_ticket_checking", &TThis::ClientEnableUserTicketChecking) + .Default(false); + registrar.Parameter("client_blackbox_env", &TThis::ClientBlackboxEnv) + .Default("ProdYateam"); + registrar.Parameter("client_enable_service_ticket_fetching", &TThis::ClientEnableServiceTicketFetching) + .Default(false); + registrar.Parameter("client_self_secret", &TThis::ClientSelfSecret) + .Optional(); + registrar.Parameter("client_self_secret_path", &TThis::ClientSelfSecretPath) + .Optional(); + 
registrar.Parameter("client_self_secret_env", &TThis::ClientSelfSecretEnv) + .Optional(); + registrar.Parameter("client_dst_map", &TThis::ClientDstMap) + .Optional(); + registrar.Parameter("client_enable_service_ticket_checking", &TThis::ClientEnableServiceTicketChecking) + .Default(false); + + registrar.Parameter("enable_ticket_parse_cache", &TThis::EnableTicketParseCache) + .Default(false); + registrar.Parameter("ticket_checking_cache_timeout", &TThis::TicketCheckingCacheTimeout) + .Default(TDuration::Minutes(1)); + + registrar.Parameter("tvm_tool_self_alias", &TThis::TvmToolSelfAlias) + .Optional(); + registrar.Parameter("tvm_tool_port", &TThis::TvmToolPort) + .Optional(); + registrar.Parameter("tvm_tool_auth_token", &TThis::TvmToolAuthToken) + .Optional(); + + registrar.Parameter("enable_mock", &TThis::EnableMock) + .Default(false); + registrar.Parameter("require_mock_secret", &TThis::RequireMockSecret) + .Default(true); + + registrar.Postprocessor([] (TThis* config) { + if (config->ClientSelfSecretEnv && config->ClientSelfSecretPath) { + THROW_ERROR_EXCEPTION("Options \"client_self_secret_env\", \"client_self_secret_path\" " + "cannot be used together"); + } + }); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/tvm/service/config.h b/yt/yt/library/tvm/service/config.h new file mode 100644 index 0000000000..9f3c8b5af6 --- /dev/null +++ b/yt/yt/library/tvm/service/config.h @@ -0,0 +1,69 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/ytree/yson_struct.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +class TTvmServiceConfig + : public virtual NYTree::TYsonStruct +{ +public: + bool UseTvmTool; + + // TvmClient settings + TTvmId ClientSelfId = 0; + std::optional<TString> ClientDiskCacheDir; + + std::optional<TString> TvmHost; + std::optional<ui16> TvmPort; + + bool 
ClientEnableUserTicketChecking = false; + TString ClientBlackboxEnv; + + bool ClientEnableServiceTicketFetching = false; + + //! Do not use this option as the plaintext value of secret may be exposed via service orchid or somehow else. + std::optional<TString> ClientSelfSecret; + + //! Name of env variable with TVM secret. Used if ClientSelfSecret is unset. + std::optional<TString> ClientSelfSecretEnv; + + //! Path to TVM secret. Used if ClientSelfSecret is unset. + std::optional<TString> ClientSelfSecretPath; + + THashMap<TString, ui32> ClientDstMap; + + bool ClientEnableServiceTicketChecking = false; + + //! If true, then checked tickets are cached, allowing us to speed up checking. + bool EnableTicketParseCache = false; + TDuration TicketCheckingCacheTimeout; + + TString TvmToolSelfAlias; + //! If not specified, get port from env variable `DEPLOY_TVM_TOOL_URL`. + int TvmToolPort = 0; + //! Do not use this option in production. + //! If not specified, get token from env variable `TVMTOOL_LOCAL_AUTHTOKEN`. + std::optional<TString> TvmToolAuthToken; + + //! For testing only. If enabled, then a mock instead of a real TVM service will be used. + bool EnableMock = false; + + //! If EnableMock and RequireMockSecret is true, then ensures that ClientSelfSecret is equal to + //! "SecretPrefix-" + ToString(ClientSelfId). 
+ bool RequireMockSecret = true; + + REGISTER_YSON_STRUCT(TTvmServiceConfig); + + static void Register(TRegistrar registrar); +}; + +DEFINE_REFCOUNTED_TYPE(TTvmServiceConfig) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/tvm/service/public.h b/yt/yt/library/tvm/service/public.h new file mode 100644 index 0000000000..89b3c4e60d --- /dev/null +++ b/yt/yt/library/tvm/service/public.h @@ -0,0 +1,33 @@ +#pragma once + +#include <library/cpp/yt/memory/ref_counted.h> + +#include <util/generic/hash_set.h> +#include <util/generic/string.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_CLASS(TTvmServiceConfig) +DECLARE_REFCOUNTED_STRUCT(ITvmService) +DECLARE_REFCOUNTED_STRUCT(IDynamicTvmService) + +//////////////////////////////////////////////////////////////////////////////// + +struct TParsedTicket +{ + ui64 DefaultUid; + THashSet<TString> Scopes; +}; + +using TTvmId = ui64; + +struct TParsedServiceTicket +{ + TTvmId TvmId; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/tvm/service/tvm_service.h b/yt/yt/library/tvm/service/tvm_service.h new file mode 100644 index 0000000000..7302053e62 --- /dev/null +++ b/yt/yt/library/tvm/service/tvm_service.h @@ -0,0 +1,77 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/tvm/public.h> + +#include <yt/yt/library/profiling/sensor.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +struct ITvmService + : public virtual TRefCounted +{ + //! Our TVM id. + virtual TTvmId GetSelfTvmId() = 0; + + //! Get TVM service ticket from us to serviceAlias. Service mapping must be in config. + //! Throws on failure. + virtual TString GetServiceTicket(const TString& serviceAlias) = 0; + + //! 
Get TVM service ticket from us to serviceId. Service ID must be known (either during + //! construction or explicitly added in dynamic service). + //! Throws on failure. + virtual TString GetServiceTicket(TTvmId serviceId) = 0; + + //! Decode user ticket contents. Throws on failure. + virtual TParsedTicket ParseUserTicket(const TString& ticket) = 0; + + //! Decode service ticket contents. Throws on failure. + virtual TParsedServiceTicket ParseServiceTicket(const TString& ticket) = 0; +}; + +struct IDynamicTvmService + : public virtual ITvmService +{ +public: + //! Add destination service IDs to fetch. It is possible to add the same ID multiple + //! times, though it will be added only once really. + virtual void AddDestinationServiceIds(const std::vector<TTvmId>& serviceIds) = 0; +}; + +DEFINE_REFCOUNTED_TYPE(ITvmService) +DEFINE_REFCOUNTED_TYPE(IDynamicTvmService) + +//////////////////////////////////////////////////////////////////////////////// + +ITvmServicePtr CreateTvmService( + TTvmServiceConfigPtr config, + NProfiling::TProfiler profiler = {}); + +IDynamicTvmServicePtr CreateDynamicTvmService( + TTvmServiceConfigPtr config, + NProfiling::TProfiler profiler = {}); + +//////////////////////////////////////////////////////////////////////////////// + +IServiceTicketAuthPtr CreateServiceTicketAuth( + ITvmServicePtr tvmService, + TTvmId dstServiceId); + +IServiceTicketAuthPtr CreateServiceTicketAuth( + ITvmServicePtr tvmService, + TString dstServiceAlias); + +//////////////////////////////////////////////////////////////////////////////// + +TStringBuf RemoveTicketSignature(TStringBuf ticketBody); + +//////////////////////////////////////////////////////////////////////////////// + +bool IsDummyTvmServiceImplementation(); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/tvm/service/tvm_service_common.cpp b/yt/yt/library/tvm/service/tvm_service_common.cpp new file mode 
100644 index 0000000000..f031da6bfb --- /dev/null +++ b/yt/yt/library/tvm/service/tvm_service_common.cpp @@ -0,0 +1,55 @@ +#include "tvm_service.h" + +#include <library/cpp/yt/memory/new.h> + +#include <yt/yt/library/tvm/tvm_base.h> + +namespace NYT::NAuth { + +//////////////////////////////////////////////////////////////////////////////// + +template <typename TId> +class TServiceTicketAuth + : public IServiceTicketAuth +{ +public: + TServiceTicketAuth( + ITvmServicePtr tvmService, + TId destServiceId) + : TvmService_(std::move(tvmService)) + , DstServiceId_(std::move(destServiceId)) + { } + + TString IssueServiceTicket() override + { + return TvmService_->GetServiceTicket(DstServiceId_); + } + +private: + const ITvmServicePtr TvmService_; + const TId DstServiceId_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +IServiceTicketAuthPtr CreateServiceTicketAuth( + ITvmServicePtr tvmService, + TTvmId dstServiceId) +{ + YT_VERIFY(tvmService); + + return New<TServiceTicketAuth<TTvmId>>(std::move(tvmService), dstServiceId); +} + +IServiceTicketAuthPtr CreateServiceTicketAuth( + ITvmServicePtr tvmService, + TString dstServiceAlias) +{ + YT_VERIFY(tvmService); + + return New<TServiceTicketAuth<TString>>(std::move(tvmService), std::move(dstServiceAlias)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/tvm/service/tvm_service_dummy.cpp b/yt/yt/library/tvm/service/tvm_service_dummy.cpp new file mode 100644 index 0000000000..f592d3a198 --- /dev/null +++ b/yt/yt/library/tvm/service/tvm_service_dummy.cpp @@ -0,0 +1,40 @@ +#include "tvm_service.h" +#include "config.h" + +namespace NYT::NAuth { + +using namespace NProfiling; + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK ITvmServicePtr CreateTvmService( + TTvmServiceConfigPtr /*config*/, + TProfiler /*profiler*/) +{ + THROW_ERROR_EXCEPTION("Not 
implemented"); +} + +Y_WEAK IDynamicTvmServicePtr CreateDynamicTvmService( + TTvmServiceConfigPtr /*config*/, + TProfiler /*profiler*/) +{ + THROW_ERROR_EXCEPTION("Not implemented"); +} + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK TStringBuf RemoveTicketSignature(TStringBuf /*ticketBody*/) +{ + THROW_ERROR_EXCEPTION("Not implemented"); +} + +//////////////////////////////////////////////////////////////////////////////// + +Y_WEAK bool IsDummyTvmServiceImplementation() +{ + return true; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NAuth diff --git a/yt/yt/library/tvm/service/ya.make b/yt/yt/library/tvm/service/ya.make new file mode 100644 index 0000000000..e65db3983f --- /dev/null +++ b/yt/yt/library/tvm/service/ya.make @@ -0,0 +1,34 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + config.cpp + tvm_service_common.cpp + tvm_service_dummy.cpp +) + +PEERDIR( + library/cpp/yt/memory + library/cpp/yt/logging + yt/yt/core +) + +IF(NOT OPENSOURCE) + SRCS( + GLOBAL tvm_service_yandex.cpp + ) + + PEERDIR( + library/cpp/tvmauth + library/cpp/tvmauth/client + library/cpp/tvmauth/client/misc/api/dynamic_dst + yt/yt/library/tvm + ) +ENDIF() + +END() + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/vector_hdrf/fair_share_update.cpp b/yt/yt/library/vector_hdrf/fair_share_update.cpp new file mode 100644 index 0000000000..17c5ab14b6 --- /dev/null +++ b/yt/yt/library/vector_hdrf/fair_share_update.cpp @@ -0,0 +1,1524 @@ +#include "fair_share_update.h" +#include "resource_helpers.h" +#include "private.h" + +#include <yt/yt/core/ytree/fluent.h> + +// TODO(ignat): move finally to library +#include <yt/yt/core/misc/finally.h> + +#include <yt/yt/library/numeric/binary_search.h> + +#include <yt/yt/library/vector_hdrf/piecewise_linear_function_helpers.h> + +namespace NYT::NVectorHdrf { + +using namespace NProfiling; + 
+//////////////////////////////////////////////////////////////////////////////// + +TString ToString(const TDetailedFairShare& detailedFairShare) +{ + return ToStringViaBuilder(detailedFairShare); +} + +void FormatValue(TStringBuilderBase* builder, const TDetailedFairShare& detailedFairShare, TStringBuf /* format */) +{ + builder->AppendFormat( + "{StrongGuarantee: %.6g, IntegralGuarantee: %.6g, WeightProportional: %.6g}", + detailedFairShare.StrongGuarantee, + detailedFairShare.IntegralGuarantee, + detailedFairShare.WeightProportional); +} + +//////////////////////////////////////////////////////////////////////////////// + +TResourceVector TSchedulableAttributes::GetGuaranteeShare() const +{ + return StrongGuaranteeShare + ProposedIntegralShare; +} + +void TSchedulableAttributes::SetFairShare(const TResourceVector& fairShare) +{ + FairShare.Total = fairShare; + FairShare.StrongGuarantee = TResourceVector::Min(fairShare, StrongGuaranteeShare); + FairShare.IntegralGuarantee = TResourceVector::Min(fairShare - FairShare.StrongGuarantee, ProposedIntegralShare); + FairShare.WeightProportional = fairShare - FairShare.StrongGuarantee - FairShare.IntegralGuarantee; +} + +//////////////////////////////////////////////////////////////////////////////// + +TResourceVector AdjustProposedIntegralShare( + const TResourceVector& limitsShare, + const TResourceVector& strongGuaranteeShare, + TResourceVector proposedIntegralShare) +{ + auto guaranteeShare = strongGuaranteeShare + proposedIntegralShare; + if (!Dominates(limitsShare, guaranteeShare)) { + YT_VERIFY(Dominates(limitsShare + TResourceVector::SmallEpsilon(), guaranteeShare)); + YT_VERIFY(Dominates(limitsShare, strongGuaranteeShare)); + + proposedIntegralShare = limitsShare - strongGuaranteeShare; + for (auto resource : TEnumTraits<EJobResourceType>::GetDomainValues()) { + constexpr int MaxAdjustmentIterationCount = 32; + + // NB(eshcherbin): Always should be no more than a single iteration, but to remove my paranoia I've 
bounded iteration count. + int iterationCount = 0; + while (limitsShare[resource] < strongGuaranteeShare[resource] + proposedIntegralShare[resource] && + iterationCount < MaxAdjustmentIterationCount) + { + proposedIntegralShare[resource] = std::nextafter(proposedIntegralShare[resource], 0.0); + ++iterationCount; + } + } + } + + return proposedIntegralShare; +} + +//////////////////////////////////////////////////////////////////////////////// + +void TElement::DetermineEffectiveStrongGuaranteeResources(TFairShareUpdateContext* /* context */) +{ } + +bool TElement::IsOperation() const +{ + return false; +} + +bool TElement::IsRoot() const +{ + return false; +} + +TPool* TElement::AsPool() +{ + return dynamic_cast<TPool*>(this); +} + +TOperationElement* TElement::AsOperation() +{ + return dynamic_cast<TOperationElement*>(this); +} + +void TElement::AdjustStrongGuarantees(const TFairShareUpdateContext* /* context */) +{ } + +void TElement::InitIntegralPoolLists(TFairShareUpdateContext* /* context */) +{ } + +void TElement::UpdateAttributes(const TFairShareUpdateContext* context) +{ + Attributes().LimitsShare = ComputeLimitsShare(context); + YT_VERIFY(Dominates(TResourceVector::Ones(), Attributes().LimitsShare)); + YT_VERIFY(Dominates(Attributes().LimitsShare, TResourceVector::Zero())); + + Attributes().StrongGuaranteeShare = TResourceVector::FromJobResources(Attributes().EffectiveStrongGuaranteeResources, context->TotalResourceLimits); + + // NB: We need to ensure that |FairShareByFitFactor_(0.0)| is less than or equal to |LimitsShare| so that there exists a feasible fit factor and |MaxFitFactorBySuggestion_| is well defined. + // To achieve this we limit |StrongGuarantee| with |LimitsShare| here, and later adjust the sum of children's |StrongGuarantee| to fit into the parent's |StrongGuarantee|. + // This way children can't ask more than parent's |LimitsShare| when given a zero suggestion. 
+ Attributes().StrongGuaranteeShare = TResourceVector::Min(Attributes().StrongGuaranteeShare, Attributes().LimitsShare); + + if (GetResourceUsageAtUpdate() == TJobResources()) { + Attributes().DominantResource = GetDominantResource(GetResourceDemand(), context->TotalResourceLimits); + } else { + Attributes().DominantResource = GetDominantResource(GetResourceUsageAtUpdate(), context->TotalResourceLimits); + } + + Attributes().UsageShare = TResourceVector::FromJobResources(GetResourceUsageAtUpdate(), context->TotalResourceLimits); + Attributes().DemandShare = TResourceVector::FromJobResources(GetResourceDemand(), context->TotalResourceLimits); + YT_VERIFY(Dominates(Attributes().DemandShare, Attributes().UsageShare)); +} + +void TElement::UpdateCumulativeAttributes(TFairShareUpdateContext* context) +{ + UpdateAttributes(context); +} + +void TElement::CheckFairShareFeasibility() const +{ + const auto& demandShare = Attributes().DemandShare; + const auto& fairShare = Attributes().FairShare.Total; + bool isFairShareSignificantlyGreaterThanDemandShare = + !Dominates(demandShare + TResourceVector::SmallEpsilon(), fairShare); + if (isFairShareSignificantlyGreaterThanDemandShare) { + std::vector<EJobResourceType> significantlyGreaterResources; + for (auto resource : TEnumTraits<EJobResourceType>::GetDomainValues()) { + if (demandShare[resource] + RatioComputationPrecision <= fairShare[resource]) { + significantlyGreaterResources.push_back(resource); + } + } + + const auto& Logger = GetLogger(); + YT_LOG_WARNING( + "Fair share is significantly greater than demand share " + "(FairShare: %v, DemandShare: %v, SignificantlyGreaterResources: %v)", + fairShare, + demandShare, + significantlyGreaterResources); + } +} + +TResourceVector TElement::ComputeLimitsShare(const TFairShareUpdateContext* context) const +{ + return TResourceVector::FromJobResources(Min(GetResourceLimits(), context->TotalResourceLimits), context->TotalResourceLimits); +} + +void 
TElement::ResetFairShareFunctions() +{ + AreFairShareFunctionsPrepared_ = false; +} + +void TElement::PrepareFairShareFunctions(TFairShareUpdateContext* context) +{ + if (AreFairShareFunctionsPrepared_) { + return; + } + + { + TWallTimer timer; + PrepareFairShareByFitFactor(context); + context->PrepareFairShareByFitFactorTotalTime += timer.GetElapsedCpuTime(); + } + YT_VERIFY(FairShareByFitFactor_.has_value()); + NDetail::VerifyNondecreasing(*FairShareByFitFactor_, GetLogger()); + YT_VERIFY(FairShareByFitFactor_->IsTrimmed()); + + { + TWallTimer timer; + PrepareMaxFitFactorBySuggestion(context); + context->PrepareMaxFitFactorBySuggestionTotalTime += timer.GetElapsedCpuTime(); + } + YT_VERIFY(MaxFitFactorBySuggestion_.has_value()); + YT_VERIFY(MaxFitFactorBySuggestion_->LeftFunctionBound() == 0.0); + YT_VERIFY(MaxFitFactorBySuggestion_->RightFunctionBound() == 1.0); + NDetail::VerifyNondecreasing(*MaxFitFactorBySuggestion_, GetLogger()); + YT_VERIFY(MaxFitFactorBySuggestion_->IsTrimmed()); + + { + TWallTimer timer; + FairShareBySuggestion_ = FairShareByFitFactor_->Compose(*MaxFitFactorBySuggestion_); + context->ComposeTotalTime += timer.GetElapsedCpuTime(); + } + YT_VERIFY(FairShareBySuggestion_.has_value()); + YT_VERIFY(FairShareBySuggestion_->LeftFunctionBound() == 0.0); + YT_VERIFY(FairShareBySuggestion_->RightFunctionBound() == 1.0); + NDetail::VerifyNondecreasing(*FairShareBySuggestion_, GetLogger()); + YT_VERIFY(FairShareBySuggestion_->IsTrimmed()); + + { + TWallTimer timer; + *FairShareBySuggestion_ = NDetail::CompressFunction(*FairShareBySuggestion_, NDetail::CompressFunctionEpsilon); + context->CompressFunctionTotalTime += timer.GetElapsedCpuTime(); + } + NDetail::VerifyNondecreasing(*FairShareBySuggestion_, GetLogger()); + + AreFairShareFunctionsPrepared_ = true; +} + +void TElement::PrepareMaxFitFactorBySuggestion(TFairShareUpdateContext* context) +{ + YT_VERIFY(FairShareByFitFactor_); + + std::vector<TScalarPiecewiseLinearFunction> mffForComponents; // 
Mff stands for "MaxFitFactor". + + for (int r = 0; r < ResourceCount; r++) { + // Fsbff stands for "FairShareByFitFactor". + auto fsbffComponent = NDetail::ExtractComponent(r, *FairShareByFitFactor_); + YT_VERIFY(fsbffComponent.IsTrimmed()); + + double limit = Attributes().LimitsShare[r]; + // NB(eshcherbin): We definitely cannot use a precise inequality here. See YT-13864. + YT_VERIFY(fsbffComponent.LeftFunctionValue() < limit + RatioComputationPrecision); + limit = std::min(std::max(limit, fsbffComponent.LeftFunctionValue()), fsbffComponent.RightFunctionValue()); + + double guarantee = Attributes().GetGuaranteeShare()[r]; + guarantee = std::min(std::max(guarantee, fsbffComponent.LeftFunctionValue()), limit); + + auto mffForComponent = std::move(fsbffComponent) + .Transpose() + .Narrow(guarantee, limit) + .TrimLeft() + .Shift(/* deltaArgument */ -guarantee) + .ExtendRight(/* newRightBound */ 1.0) + .Trim(); + mffForComponents.push_back(std::move(mffForComponent)); + } + + { + TWallTimer timer; + MaxFitFactorBySuggestion_ = PointwiseMin(mffForComponents); + context->PointwiseMinTotalTime += timer.GetElapsedCpuTime(); + } +} + +TResourceVector TElement::GetVectorSuggestion(double suggestion) const +{ + auto vectorSuggestion = TResourceVector::FromDouble(suggestion) + Attributes().StrongGuaranteeShare; + vectorSuggestion = TResourceVector::Min(vectorSuggestion, Attributes().LimitsShare); + return vectorSuggestion; +} + +void TElement::DistributeFreeVolume() +{ } + +TResourceVector TElement::GetTotalTruncatedFairShare() const +{ + return TotalTruncatedFairShare_; +} + +//////////////////////////////////////////////////////////////////////////////// + +// Sets each child's |EffectiveStrongGuaranteeResources| from its strong guarantee resources config +// (with |{}| as the default value) and accumulates the children's total. For a non-root element +// whose own effective guarantee does not dominate that total, a warning is logged. Then +// DetermineImplicitEffectiveStrongGuaranteeResources is called and the procedure recurses into every child. +void TCompositeElement::DetermineEffectiveStrongGuaranteeResources(TFairShareUpdateContext* context) +{ + TJobResources totalExplicitChildrenGuaranteeResources; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + + auto& childEffectiveGuaranteeResources = 
child->Attributes().EffectiveStrongGuaranteeResources; + childEffectiveGuaranteeResources = ToJobResources( + *child->GetStrongGuaranteeResourcesConfig(), + /* defaultValue */ {}); + totalExplicitChildrenGuaranteeResources += childEffectiveGuaranteeResources; + } + + const auto& effectiveStrongGuaranteeResources = Attributes().EffectiveStrongGuaranteeResources; + if (!IsRoot() && !Dominates(effectiveStrongGuaranteeResources, totalExplicitChildrenGuaranteeResources)) { + const auto& Logger = GetLogger(); + // NB: This should never happen because we validate the guarantees at master. + YT_LOG_WARNING( + "Total children's explicit strong guarantees exceeds the effective strong guarantee at pool " + "(EffectiveStrongGuarantees: %v, TotalExplicitChildrenGuarantees: %v)", + effectiveStrongGuaranteeResources, + totalExplicitChildrenGuaranteeResources); + } + + DetermineImplicitEffectiveStrongGuaranteeResources(totalExplicitChildrenGuaranteeResources, context); + + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + GetChild(childIndex)->DetermineEffectiveStrongGuaranteeResources(context); + } +} + +void TCompositeElement::DetermineImplicitEffectiveStrongGuaranteeResources( + const TJobResources& totalExplicitChildrenGuaranteeResources, + TFairShareUpdateContext* context) +{ + const auto& effectiveStrongGuaranteeResources = Attributes().EffectiveStrongGuaranteeResources; + auto residualGuaranteeResources = Max(effectiveStrongGuaranteeResources - totalExplicitChildrenGuaranteeResources, TJobResources{}); + auto mainResourceType = context->MainResource; + auto parentMainResourceGuarantee = GetResource(effectiveStrongGuaranteeResources, mainResourceType); + auto doDetermineImplicitGuarantees = [&] (const auto TJobResourcesConfig::* resourceDataMember, EJobResourceType resourceType) { + if (resourceType == mainResourceType) { + return; + } + + std::vector<std::optional<double>> implicitGuarantees; + implicitGuarantees.resize(GetChildCount()); + + auto 
residualGuarantee = GetResource(residualGuaranteeResources, resourceType); + auto parentResourceGuarantee = GetResource(effectiveStrongGuaranteeResources, resourceType); + double totalImplicitGuarantee = 0.0; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + if (child->GetStrongGuaranteeResourcesConfig()->*resourceDataMember) { + continue; + } + + auto childMainResourceGuarantee = GetResource(child->Attributes().EffectiveStrongGuaranteeResources, mainResourceType); + double mainResourceRatio = parentMainResourceGuarantee > 0 + ? childMainResourceGuarantee / parentMainResourceGuarantee + : 0.0; + + auto& childImplicitGuarantee = implicitGuarantees[childIndex]; + childImplicitGuarantee = mainResourceRatio * parentResourceGuarantee; + totalImplicitGuarantee += *childImplicitGuarantee; + } + + // NB: It is possible to overcommit guarantees at the first level of the tree, so we don't want to do + // additional checks and rescaling. Instead, we handle this later when we adjust |StrongGuaranteeShare|. 
+ if (!IsRoot() && totalImplicitGuarantee > residualGuarantee) { + auto scalingFactor = residualGuarantee / totalImplicitGuarantee; + for (auto& childImplicitGuarantee : implicitGuarantees) { + if (childImplicitGuarantee) { + *childImplicitGuarantee *= scalingFactor; + } + } + } + + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + if (const auto& childImplicitGuarantee = implicitGuarantees[childIndex]) { + SetResource(child->Attributes().EffectiveStrongGuaranteeResources, resourceType, *childImplicitGuarantee); + } + } + }; + + TJobResourcesConfig::ForEachResource(doDetermineImplicitGuarantees); +} + +void TCompositeElement::InitIntegralPoolLists(TFairShareUpdateContext* context) +{ + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + GetChild(childIndex)->InitIntegralPoolLists(context); + } +} + +void TCompositeElement::UpdateCumulativeAttributes(TFairShareUpdateContext* context) +{ + Attributes().BurstRatio = GetSpecifiedBurstRatio(); + Attributes().TotalBurstRatio = Attributes().BurstRatio; + Attributes().ResourceFlowRatio = GetSpecifiedResourceFlowRatio(); + Attributes().TotalResourceFlowRatio = Attributes().ResourceFlowRatio; + + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + child->UpdateCumulativeAttributes(context); + + Attributes().TotalResourceFlowRatio += child->Attributes().TotalResourceFlowRatio; + Attributes().TotalBurstRatio += child->Attributes().TotalBurstRatio; + } + + TElement::UpdateCumulativeAttributes(context); + + if (GetMode() == ESchedulingMode::Fifo) { + PrepareFifoPool(); + } +} + +void TCompositeElement::PrepareFifoPool() +{ + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + YT_VERIFY(GetChild(childIndex)->IsOperation()); + } + + SortedChildren_.clear(); + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + 
SortedChildren_.push_back(GetChild(childIndex)); + } + + std::sort( + begin(SortedChildren_), + end(SortedChildren_), + std::bind( + &TCompositeElement::HasHigherPriorityInFifoMode, + this, + std::placeholders::_1, + std::placeholders::_2)); + + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + SortedChildren_[childIndex]->Attributes().FifoIndex = childIndex; + } +} + +void TCompositeElement::AdjustStrongGuarantees(const TFairShareUpdateContext* context) +{ + const auto& Logger = GetLogger(); + + TResourceVector totalPoolChildrenStrongGuaranteeShare; + TResourceVector totalChildrenStrongGuaranteeShare; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + const auto* child = GetChild(childIndex); + totalChildrenStrongGuaranteeShare += child->Attributes().StrongGuaranteeShare; + + if (!child->IsOperation()) { + totalPoolChildrenStrongGuaranteeShare += child->Attributes().StrongGuaranteeShare; + } + } + + if (!Dominates(Attributes().StrongGuaranteeShare, totalPoolChildrenStrongGuaranteeShare)) { + // Drop strong guarantee shares of operations, adjust strong guarantee shares of pools. + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + if (child->IsOperation()) { + child->Attributes().StrongGuaranteeShare = TResourceVector::Zero(); + } + } + + // Use binary search instead of division to avoid problems with precision. 
+ ComputeByFitting( + /* getter */ [&] (double fitFactor, const TElement* child) -> TResourceVector { + return child->Attributes().StrongGuaranteeShare * fitFactor; + }, + /* setter */ [&] (TElement* child, const TResourceVector& value) { + YT_LOG_DEBUG("Adjusting strong guarantee shares (ChildId: %v, OldStrongGuaranteeShare: %v, NewStrongGuaranteeShare: %v)", + child->GetId(), + child->Attributes().StrongGuaranteeShare, + value); + child->Attributes().StrongGuaranteeShare = value; + }, + /* maxSum */ Attributes().StrongGuaranteeShare); + } else if (!Dominates(Attributes().StrongGuaranteeShare, totalChildrenStrongGuaranteeShare)) { + // Adjust strong guarantee shares of operations, preserve strong guarantee shares of pools. + ComputeByFitting( + /* getter */ [&] (double fitFactor, const TElement* child) -> TResourceVector { + if (child->IsOperation()) { + return child->Attributes().StrongGuaranteeShare * fitFactor; + } else { + return child->Attributes().StrongGuaranteeShare; + } + }, + /* setter */ [&] (TElement* child, const TResourceVector& value) { + YT_LOG_DEBUG("Adjusting string guarantee shares (ChildId: %v, OldStrongGuaranteeShare: %v, NewStrongGuaranteeShare: %v)", + child->GetId(), + child->Attributes().StrongGuaranteeShare, + value); + child->Attributes().StrongGuaranteeShare = value; + }, + /* maxSum */ Attributes().StrongGuaranteeShare); + } + + if (IsRoot()) { + Attributes().PromisedFairShare = TResourceVector::FromJobResources(context->TotalResourceLimits, context->TotalResourceLimits); + Attributes().EstimatedGuaranteeShare = Attributes().StrongGuaranteeShare; + } + + auto computeGuaranteeFairShare = [&] (TResourceVector TSchedulableAttributes::* estimatedGuaranteeFairShare) { + double weightSum = 0.0; + auto undistributedEstimatedGuaranteeFairShare = Attributes().*estimatedGuaranteeFairShare; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + weightSum += child->GetWeight(); + + // NB: 
Sum of total strong guarantee share and total resource flow can be greater than total resource limits. This results in a scheduler alert. + // However, no additional adjustment is done so we need to handle this case here as well. + child->Attributes().*estimatedGuaranteeFairShare = TResourceVector::Min( + child->Attributes().StrongGuaranteeShare + TResourceVector::FromDouble(child->Attributes().TotalResourceFlowRatio), + undistributedEstimatedGuaranteeFairShare); + undistributedEstimatedGuaranteeFairShare -= child->Attributes().*estimatedGuaranteeFairShare; + } + + for (auto resourceType : TEnumTraits<EJobResourceType>::GetDomainValues()) { + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + (child->Attributes().*estimatedGuaranteeFairShare)[resourceType] += undistributedEstimatedGuaranteeFairShare[resourceType] * child->GetWeight() / weightSum; + } + } + }; + + computeGuaranteeFairShare(/*estimatedGuaranteeFairShare*/ &TSchedulableAttributes::PromisedFairShare); + computeGuaranteeFairShare(/*estimatedGuaranteeFairShare*/ &TSchedulableAttributes::EstimatedGuaranteeShare); + + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + GetChild(childIndex)->AdjustStrongGuarantees(context); + } +} + +template <class TValue, class TGetter, class TSetter> +TValue TCompositeElement::ComputeByFitting( + const TGetter& getter, + const TSetter& setter, + TValue maxSum, + bool strictMode) +{ + auto checkSum = [&] (double fitFactor) -> bool { + TValue sum = {}; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + const auto* child = GetChild(childIndex); + sum += getter(fitFactor, child); + } + + if constexpr (std::is_same_v<TValue, TResourceVector>) { + return Dominates(maxSum, sum); + } else { + return maxSum >= sum; + } + }; + + double fitFactor; + if (!strictMode && !checkSum(0.0)) { + // Even left bound doesn't satisfy predicate. 
+ fitFactor = 0.0; + } else { + // Run binary search to compute fit factor. + fitFactor = FloatingPointInverseLowerBound(0.0, 1.0, checkSum); + } + + TValue resultSum = {}; + + // Compute actual values from fit factor. + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + TValue value = getter(fitFactor, child); + resultSum += value; + setter(child, value); + } + + return resultSum; +} + +void TCompositeElement::PrepareFairShareFunctions(TFairShareUpdateContext* context) +{ + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + child->PrepareFairShareFunctions(context); + } + + TElement::PrepareFairShareFunctions(context); +} + +void TCompositeElement::PrepareFairShareByFitFactor(TFairShareUpdateContext* context) +{ + switch (GetMode()) { + case ESchedulingMode::Fifo: + PrepareFairShareByFitFactorFifo(context); + break; + + case ESchedulingMode::FairShare: + PrepareFairShareByFitFactorNormal(context); + break; + + default: + YT_ABORT(); + } +} + +// Fit factor for a FIFO pool is defined as the number of satisfied children plus the suggestion +// of the first child that is not satisfied, if any. +// A child is said to be satisfied when it is suggested the whole cluster (|suggestion == 1.0|). +// Note that this doesn't necessarily mean that the child's demand is satisfied. +// For an empty FIFO pool fit factor is not well defined. +// +// The unambiguity of the definition of the fit factor follows the fact that the suggestion of +// an unsatisfied child is, by definition, less than 1. +// +// Note that we assume all children have no guaranteed resources, so for any child: +// |child->FairShareBySuggestion_(0.0) == TResourceVector::Zero()|, and 0.0 is not a discontinuity +// point of |child->FairShareBySuggestion_|. 
+void TCompositeElement::PrepareFairShareByFitFactorFifo(TFairShareUpdateContext* context) +{ + TWallTimer timer; + auto finally = Finally([&] { + context->PrepareFairShareByFitFactorFifoTotalTime += timer.GetElapsedCpuTime(); + }); + + if (GetChildCount() == 0) { + FairShareByFitFactor_ = TVectorPiecewiseLinearFunction::Constant(0.0, 1.0, TResourceVector::Zero()); + return; + } + + double rightFunctionBound = GetChildCount(); + FairShareByFitFactor_ = TVectorPiecewiseLinearFunction::Constant(0.0, rightFunctionBound, TResourceVector::Zero()); + + double currentRightBound = 0.0; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + const auto* child = SortedChildren_[childIndex]; + const auto& childFSBS = *child->FairShareBySuggestion_; + + // NB(eshcherbin): Children of FIFO pools don't have guaranteed resources. See the function comment. + YT_VERIFY(childFSBS.IsTrimmedLeft() && childFSBS.IsTrimmedRight()); + YT_VERIFY(childFSBS.LeftFunctionValue() == TResourceVector::Zero()); + + // TODO(antonkikh): This can be implemented much more efficiently by concatenating functions instead of adding. 
+ *FairShareByFitFactor_ += childFSBS + .Shift(/* deltaArgument */ currentRightBound) + .Extend(/* newLeftBound */ 0.0, /* newRightBound */ rightFunctionBound); + currentRightBound += 1.0; + } + + YT_VERIFY(currentRightBound == rightFunctionBound); +} + +void TCompositeElement::PrepareFairShareByFitFactorNormal(TFairShareUpdateContext* context) +{ + TWallTimer timer; + auto finally = Finally([&] { + context->PrepareFairShareByFitFactorNormalTotalTime += timer.GetElapsedCpuTime(); + }); + + if (GetChildCount() == 0) { + FairShareByFitFactor_ = TVectorPiecewiseLinearFunction::Constant(0.0, 1.0, TResourceVector::Zero()); + return; + } + + std::vector<TVectorPiecewiseLinearFunction> childrenFunctions; + double minWeight = GetMinChildWeight(); + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + const auto* child = GetChild(childIndex); + const auto& childFSBS = *child->FairShareBySuggestion_; + + auto childFunction = childFSBS + .ScaleArgument(child->GetWeight() / minWeight) + .ExtendRight(/* newRightBound */ 1.0); + + childrenFunctions.push_back(std::move(childFunction)); + } + + FairShareByFitFactor_ = TVectorPiecewiseLinearFunction::Sum(childrenFunctions); +} + +double TCompositeElement::GetMinChildWeight() const +{ + double minWeight = std::numeric_limits<double>::max(); + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + const auto* child = GetChild(childIndex); + if (child->GetWeight() > RatioComputationPrecision) { + minWeight = std::min(minWeight, child->GetWeight()); + } + } + return minWeight; +} + +// Returns a vector of suggestions for children from |SortedEnabledChildren_| based on the given fit factor. 
+TCompositeElement::TChildSuggestions TCompositeElement::GetChildSuggestionsFifo(double fitFactor) +{ + YT_VERIFY(fitFactor <= SortedChildren_.size()); + + int satisfiedChildCount = static_cast<int>(fitFactor); + double unsatisfiedChildSuggestion = fitFactor - satisfiedChildCount; + + TChildSuggestions childSuggestions(SortedChildren_.size(), 0.0); + for (int i = 0; i < satisfiedChildCount; i++) { + childSuggestions[i] = 1.0; + } + + if (unsatisfiedChildSuggestion != 0.0) { + childSuggestions[satisfiedChildCount] = unsatisfiedChildSuggestion; + } + + return childSuggestions; +} + +// Returns a vector of suggestions for children from |EnabledChildren_| based on the given fit factor. +TCompositeElement::TChildSuggestions TCompositeElement::GetChildSuggestionsNormal(double fitFactor) +{ + const double minWeight = GetMinChildWeight(); + + TChildSuggestions childSuggestions; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + const auto* child = GetChild(childIndex); + childSuggestions.push_back(std::min(1.0, fitFactor * (child->GetWeight() / minWeight))); + } + + return childSuggestions; +} + +void TCompositeElement::ComputeAndSetFairShare(double suggestion, TFairShareUpdateContext* context) +{ + const auto& Logger = GetLogger(); + + if (GetChildCount() == 0) { + Attributes().SetFairShare(TResourceVector::Zero()); + return; + } + + auto suggestedFairShare = FairShareBySuggestion_->ValueAt(suggestion); + + // Find the right fit factor to use when computing suggestions for children. + + // NB(eshcherbin): Vector of suggestions returned by |getEnabledChildSuggestions| must be consistent + // with |children|, i.e. i-th suggestion is meant to be given to i-th enabled child. + // This implicit correspondence between children and suggestions is done for optimization purposes. + std::vector<TElement*> children; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + children.push_back( + GetMode() == ESchedulingMode::Fifo + ? 
SortedChildren_[childIndex] + : GetChild(childIndex)); + } + + auto getEnabledChildSuggestions = (GetMode() == ESchedulingMode::Fifo) + ? std::bind(&TCompositeElement::GetChildSuggestionsFifo, this, std::placeholders::_1) + : std::bind(&TCompositeElement::GetChildSuggestionsNormal, this, std::placeholders::_1); + + auto getChildrenSuggestedFairShare = [&] (double fitFactor) { + auto childSuggestions = getEnabledChildSuggestions(fitFactor); + YT_VERIFY(childSuggestions.size() == children.size()); + + TResourceVector childrenSuggestedFairShare; + for (int childIndex = 0; childIndex < std::ssize(children); ++childIndex) { + const auto& child = children[childIndex]; + auto childSuggestion = childSuggestions[childIndex]; + childrenSuggestedFairShare += child->FairShareBySuggestion_->ValueAt(childSuggestion); + } + + return childrenSuggestedFairShare; + }; + auto checkFitFactor = [&] (double fitFactor) { + // Check that we can safely use the given fit factor to compute suggestions for children. + return Dominates(suggestedFairShare + TResourceVector::SmallEpsilon(), getChildrenSuggestedFairShare(fitFactor)); + }; + + // Usually MFFBS(suggestion) is the right fit factor to use for child suggestions. + auto fitFactor = MaxFitFactorBySuggestion_->ValueAt(suggestion); + if (!checkFitFactor(fitFactor)) { + YT_ASSERT(checkFitFactor(0.0)); + + // However, sometimes we need to tweak MFFBS(suggestion) in order not to suggest too much to children. + // NB(eshcherbin): Possible to optimize this by using galloping, as the target fit factor + // should be very, very close to our first estimate. + fitFactor = FloatingPointInverseLowerBound( + /* lo */ 0.0, + /* hi */ fitFactor, + /* predicate */ checkFitFactor); + } + + // Propagate suggestions to children and collect the total used fair share. 
+ + auto childSuggestions = getEnabledChildSuggestions(fitFactor); + YT_VERIFY(childSuggestions.size() == children.size()); + + TResourceVector childrenUsedFairShare; + for (int childIndex = 0; childIndex < std::ssize(children); ++childIndex) { + const auto& child = children[childIndex]; + auto childSuggestion = childSuggestions[childIndex]; + child->ComputeAndSetFairShare(childSuggestion, context); + childrenUsedFairShare += child->Attributes().FairShare.Total; + } + + // Validate children total fair share. + bool suggestedShareNearlyDominatesChildrenUsedShare = + Dominates(suggestedFairShare + TResourceVector::SmallEpsilon(), childrenUsedFairShare); + bool usedShareNearSuggestedShare = + TResourceVector::Near(childrenUsedFairShare, suggestedFairShare, 1e-4 * MaxComponent(childrenUsedFairShare)); + + YT_LOG_WARNING_UNLESS(usedShareNearSuggestedShare && suggestedShareNearlyDominatesChildrenUsedShare, + "Fair share significantly differs from predicted in pool (" + "Mode: %v, " + "Suggestion: %.20v, " + "VectorSuggestion: %.20v, " + "SuggestedFairShare: %.20v, " + "ChildrenUsedFairShare: %.20v, " + "Difference: %.20v, " + "FitFactor: %.20v, " + "FSBFFPredicted: %.20v, " + "ChildrenSuggestedFairShare: %.20v, " + "ChildrenCount: %v)", + GetMode(), + suggestion, + GetVectorSuggestion(suggestion), + suggestedFairShare, + childrenUsedFairShare, + suggestedFairShare - childrenUsedFairShare, + fitFactor, + FairShareByFitFactor_->ValueAt(fitFactor), + getChildrenSuggestedFairShare(fitFactor), + GetChildCount()); + + YT_VERIFY(suggestedShareNearlyDominatesChildrenUsedShare); + + // Set fair share. 
+ + Attributes().SetFairShare(suggestedFairShare); + CheckFairShareFeasibility(); +} + +void TCompositeElement::TruncateFairShareInFifoPools() +{ + THashSet<TElement*> truncatedChildren; + if (GetMode() == ESchedulingMode::Fifo && IsFairShareTruncationInFifoPoolEnabled()) { + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto *childOperation = SortedChildren_[childIndex]->AsOperation(); + + YT_VERIFY(childOperation); + + const auto& childAttributes = childOperation->Attributes(); + auto childFairShare = childAttributes.FairShare.Total; + if (childFairShare == TResourceVector::Zero()) { + continue; + } + + // NB(eshcherbin, YT-15061): This truncation is only used in GPU-trees to enable preemption of jobs of gang operations + // which fair share is less than demand. + bool isChildFullySatisfied = Dominates(childFairShare + TResourceVector::Epsilon(), childAttributes.DemandShare); + bool shouldTruncate = !isChildFullySatisfied && childOperation->IsGang(); + if (shouldTruncate) { + const auto& Logger = GetLogger(); + + TotalTruncatedFairShare_ += childFairShare; + childOperation->Attributes().SetFairShare(TResourceVector::Zero()); + truncatedChildren.insert(childOperation); + + YT_LOG_DEBUG("Truncated operation fair share in FIFO pool (OperationId: %v, TruncatedFairShare: %v)", + childOperation->GetId(), + childFairShare); + } + } + } + + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + auto* child = GetChild(childIndex); + if (!truncatedChildren.contains(child)) { + child->TruncateFairShareInFifoPools(); + TotalTruncatedFairShare_ += child->GetTotalTruncatedFairShare(); + } + } + + // TODO(eshcherbin): Should we use epsilon here? 
+ if (TotalTruncatedFairShare_ != TResourceVector::Zero()) { + auto fairShare = TResourceVector::Max(Attributes().FairShare.Total - TotalTruncatedFairShare_, TResourceVector::Zero()); + Attributes().SetFairShare(fairShare); + } +} + +void TCompositeElement::UpdateOverflowAndAcceptableVolumesRecursively() +{ + const auto& Logger = GetLogger(); + auto& attributes = Attributes(); + + auto thisPool = AsPool(); + if (thisPool && thisPool->GetIntegralGuaranteeType() != EIntegralGuaranteeType::None) { + return; + } + + TResourceVolume childrenAcceptableVolume; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + if (auto* childPool = GetChild(childIndex)->AsPool()) { + childPool->UpdateOverflowAndAcceptableVolumesRecursively(); + attributes.ChildrenVolumeOverflow += childPool->Attributes().VolumeOverflow; + childrenAcceptableVolume += childPool->Attributes().AcceptableVolume; + } + } + + bool canAcceptFreeVolume = CanAcceptFreeVolume(); + + TResourceVolume::ForEachResource([&] (EJobResourceType /*resourceType*/, auto TResourceVolume::* resourceDataMember) { + auto diff = attributes.ChildrenVolumeOverflow.*resourceDataMember - childrenAcceptableVolume.*resourceDataMember; + if (diff > 0) { + attributes.VolumeOverflow.*resourceDataMember = diff; + attributes.AcceptableVolume.*resourceDataMember = 0; + } else { + attributes.VolumeOverflow.*resourceDataMember = 0; + attributes.AcceptableVolume.*resourceDataMember = canAcceptFreeVolume ? 
// Pushes free resource volume down the tree.
//
// The volume to distribute is |AcceptedFreeVolume| received from the parent plus
// |ChildrenVolumeOverflow| collected from the children. If this element is a pool with
// an integral guarantee type other than None, it adds the accepted volume to its own
// accumulated volume and returns without recursing. Otherwise the volume is split among
// children per resource, and the function recurses into every child.
void TCompositeElement::DistributeFreeVolume()
{
    const auto& Logger = GetLogger();
    auto& attributes = Attributes();

    TResourceVolume freeVolume = attributes.AcceptedFreeVolume;

    auto* thisPool = AsPool();
    if (thisPool && thisPool->GetIntegralGuaranteeType() != EIntegralGuaranteeType::None) {
        // Integral pool: consume the accepted volume here; nothing is passed further down.
        if (!freeVolume.IsZero()) {
            thisPool->IntegralResourcesState().AccumulatedVolume += freeVolume;
            YT_LOG_DEBUG("Pool has accepted free volume (FreeVolume: %v)", freeVolume);
        }
        return;
    }

    // Skip the distribution entirely when there is neither volume from the parent
    // nor overflow from the children.
    if (ShouldDistributeFreeVolumeAmongChildren() && !(freeVolume.IsZero() && attributes.ChildrenVolumeOverflow.IsZero())) {
        YT_LOG_DEBUG(
            "Distributing free volume among children (FreeVolumeFromParent: %v, ChildrenVolumeOverflow: %v)",
            freeVolume,
            attributes.ChildrenVolumeOverflow);

        freeVolume += attributes.ChildrenVolumeOverflow;

        // Per-child bookkeeping for one resource; only children that can accept volume are listed.
        struct TChildAttributes {
            int Index;
            double Weight;
            TSchedulableAttributes* Attributes;
            double AcceptableVolumeToWeightRatio;
        };

        // Each resource component of the volume is distributed independently.
        TResourceVolume::ForEachResource([&] (EJobResourceType /*resourceType*/, auto TResourceVolume::* resourceDataMember) {
            if (freeVolume.*resourceDataMember == 0) {
                return;
            }
            std::vector<TChildAttributes> hungryChildren;
            auto weightSum = 0.0;
            for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) {
                auto& childAttributes = GetChild(childIndex)->Attributes();
                // A child participates only if it can accept volume of this resource
                // and has a non-negligible total resource flow.
                if (childAttributes.AcceptableVolume.*resourceDataMember > RatioComputationPrecision &&
                    childAttributes.TotalResourceFlowRatio > RatioComputationPrecision)
                {
                    // Resource flow is taken as weight.
                    auto weight = childAttributes.TotalResourceFlowRatio;
                    hungryChildren.push_back(TChildAttributes{
                        .Index = childIndex,
                        .Weight = weight,
                        .Attributes = &childAttributes,
                        .AcceptableVolumeToWeightRatio = static_cast<double>(childAttributes.AcceptableVolume.*resourceDataMember) / weight,
                    });
                    weightSum += weight;
                }
            }

            // Children will be saturated in ascending order of |AcceptableVolumeToWeightRatio|.
            std::sort(
                hungryChildren.begin(),
                hungryChildren.end(),
                [] (const TChildAttributes& lhs, const TChildAttributes& rhs) {
                    return lhs.AcceptableVolumeToWeightRatio < rhs.AcceptableVolumeToWeightRatio;
                });

            auto it = hungryChildren.begin();
            // First we provide free volume to the pools that cannot fully consume the suggested volume.
            // Each such child takes exactly its acceptable volume; the rest of the volume and of the
            // weight sum is re-split among the children that follow.
            for (; it != hungryChildren.end(); ++it) {
                const auto suggestedFreeVolume = static_cast<double>(freeVolume.*resourceDataMember) * (it->Weight / weightSum);
                const auto acceptableVolume = it->Attributes->AcceptableVolume.*resourceDataMember;
                if (suggestedFreeVolume < acceptableVolume) {
                    break;
                }
                it->Attributes->AcceptedFreeVolume.*resourceDataMember = acceptableVolume;
                freeVolume.*resourceDataMember -= acceptableVolume;
                weightSum -= it->Weight;
            }

            // Then we provide free volume to remaining pools that will fully consume the suggested volume.
            // |freeVolume| and |weightSum| are no longer modified, so the split is purely weight-proportional;
            // the value is cast back to the resource component's own type.
            for (; it != hungryChildren.end(); ++it) {
                auto suggestedFreeVolume = static_cast<double>(freeVolume.*resourceDataMember) * (it->Weight / weightSum);
                it->Attributes->AcceptedFreeVolume.*resourceDataMember = static_cast<std::remove_reference_t<decltype(freeVolume.*resourceDataMember)>>(suggestedFreeVolume);
            }
        });
    }

    // Recurse unconditionally so that children pick up whatever |AcceptedFreeVolume| they hold.
    for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) {
        GetChild(childIndex)->DistributeFreeVolume();
    }
}
upperLimit = Max(oldVolume, poolCapacity); + + attributes.VolumeOverflow = Max(integralResourcesState.AccumulatedVolume - upperLimit, TResourceVolume()); + if (CanAcceptFreeVolume()) { + attributes.AcceptableVolume = Max(poolCapacity - integralResourcesState.AccumulatedVolume, zero); + } + + integralResourcesState.AccumulatedVolume = Min(integralResourcesState.AccumulatedVolume, upperLimit); + + YT_LOG_DEBUG( + "Accumulated resource volume updated " + "(ResourceFlowRatio: %v, PeriodSinceLastUpdateInSeconds: %v, TotalResourceLimits: %v, LastIntegralShareRatio: %v, " + "PoolCapacity: %v, OldVolume: %v, UpdatedVolume: %v, VolumeOverflow: %v, AcceptableVolume: %v)", + attributes.ResourceFlowRatio, + periodSinceLastUpdate.SecondsFloat(), + context->TotalResourceLimits, + integralResourcesState.LastShareRatio, + poolCapacity, + oldVolume, + integralResourcesState.AccumulatedVolume, + attributes.VolumeOverflow, + attributes.AcceptableVolume); +} + +//////////////////////////////////////////////////////////////////////////////// + +void TRootElement::DetermineEffectiveStrongGuaranteeResources(TFairShareUpdateContext* context) +{ + Attributes().EffectiveStrongGuaranteeResources = context->TotalResourceLimits; + + TCompositeElement::DetermineEffectiveStrongGuaranteeResources(context); +} + +bool TRootElement::IsRoot() const +{ + return true; +} + +void TRootElement::UpdateCumulativeAttributes(TFairShareUpdateContext* context) +{ + TCompositeElement::UpdateCumulativeAttributes(context); + + Attributes().StrongGuaranteeShare = TResourceVector::Zero(); + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + const auto* child = GetChild(childIndex); + Attributes().StrongGuaranteeShare += child->Attributes().StrongGuaranteeShare; + } +} + +void TRootElement::TruncateFairShareInFifoPools() +{ + const auto& Logger = GetLogger(); + + TCompositeElement::TruncateFairShareInFifoPools(); + + YT_LOG_DEBUG_UNLESS(TotalTruncatedFairShare_ == TResourceVector::Zero(), + 
"Truncated fair share in FIFO pools (NewFairShare: %v, TotalTruncatedFairShare: %v)", + Attributes().FairShare.Total, + TotalTruncatedFairShare_); +} + +void TRootElement::ValidateAndAdjustSpecifiedGuarantees(TFairShareUpdateContext* context) +{ + auto totalResourceFlow = context->TotalResourceLimits * Attributes().TotalResourceFlowRatio; + auto totalBurstResources = context->TotalResourceLimits * Attributes().TotalBurstRatio; + TJobResources totalStrongGuaranteeResources; + for (int childIndex = 0; childIndex < GetChildCount(); ++childIndex) { + const auto* child = GetChild(childIndex); + totalStrongGuaranteeResources += child->Attributes().EffectiveStrongGuaranteeResources; + } + + if (!Dominates(context->TotalResourceLimits, totalStrongGuaranteeResources + totalResourceFlow)) { + context->Errors.push_back(TError(EErrorCode::PoolTreeGuaranteesOvercommit, "Strong guarantees and resource flows exceed total cluster resources") + << TErrorAttribute("total_strong_guarantee_resources", totalStrongGuaranteeResources) + << TErrorAttribute("total_resource_flow", totalResourceFlow) + << TErrorAttribute("total_cluster_resources", context->TotalResourceLimits)); + } + + if (!Dominates(context->TotalResourceLimits, totalStrongGuaranteeResources + totalBurstResources)) { + context->Errors.push_back(TError(EErrorCode::PoolTreeGuaranteesOvercommit, "Strong guarantees and burst guarantees exceed total cluster resources") + << TErrorAttribute("total_strong_guarantee_resources", totalStrongGuaranteeResources) + << TErrorAttribute("total_burst_resources", totalBurstResources) + << TErrorAttribute("total_cluster_resources", context->TotalResourceLimits)); + + auto checkSum = [&] (double fitFactor) -> bool { + auto sum = Attributes().StrongGuaranteeShare * fitFactor; + for (const auto& pool : context->BurstPools) { + sum += TResourceVector::FromDouble(pool->Attributes().BurstRatio) * fitFactor; + } + return Dominates(TResourceVector::Ones(), sum); + }; + + double fitFactor = 
FloatingPointInverseLowerBound(0.0, 1.0, checkSum); + + // NB(eshcherbin): Note that we validate the sum of EffectiveStrongGuaranteeResources but adjust StrongGuaranteeShare. + // During validation we need to check the absolute values to handle corner cases correctly and always show the alert. See: YT-14758. + // During adjustment we need to assure the invariants required for vector fair share computation. + Attributes().StrongGuaranteeShare = Attributes().StrongGuaranteeShare * fitFactor; + for (const auto& pool : context->BurstPools) { + pool->Attributes().BurstRatio *= fitFactor; + } + } + + AdjustStrongGuarantees(context); +} + +//////////////////////////////////////////////////////////////////////////////// + +bool TOperationElement::IsOperation() const +{ + return true; +} + +void TOperationElement::PrepareFairShareByFitFactor(TFairShareUpdateContext* context) +{ + TWallTimer timer; + auto finally = Finally([&] { + context->PrepareFairShareByFitFactorOperationsTotalTime += timer.GetElapsedCpuTime(); + }); + + TVectorPiecewiseLinearFunction::TBuilder builder; + + // First we try to satisfy the current usage by giving equal fair share for each resource. + // More precisely, for fit factor 0 <= f <= 1, fair share for resource r will be equal to min(usage[r], f * maxUsage). 
+ double maxUsage = MaxComponent(Attributes().UsageShare); + if (maxUsage == 0.0) { + builder.PushSegment({0.0, TResourceVector::Zero()}, {1.0, TResourceVector::Zero()}); + } else { + TCompactVector<double, ResourceCount> sortedUsage(Attributes().UsageShare.begin(), Attributes().UsageShare.end()); + std::sort(sortedUsage.begin(), sortedUsage.end()); + + builder.AddPoint({0.0, TResourceVector::Zero()}); + double previousUsageFitFactor = 0.0; + for (auto usage : sortedUsage) { + double currentUsageFitFactor = usage / maxUsage; + if (currentUsageFitFactor > previousUsageFitFactor) { + builder.AddPoint({ + currentUsageFitFactor, + TResourceVector::Min(TResourceVector::FromDouble(usage), Attributes().UsageShare)}); + previousUsageFitFactor = currentUsageFitFactor; + } + } + YT_VERIFY(previousUsageFitFactor == 1.0); + } + + // After that we just give fair share proportionally to the remaining demand. + builder.PushSegment({{1.0, Attributes().UsageShare}, {2.0, Attributes().DemandShare}}); + + FairShareByFitFactor_ = builder.Finish(); +} + +void TOperationElement::ComputeAndSetFairShare(double suggestion, TFairShareUpdateContext* /*context*/) +{ + auto fairShare = FairShareBySuggestion_->ValueAt(suggestion); + Attributes().SetFairShare(fairShare); + CheckFairShareFeasibility(); + + if (AreDetailedLogsEnabled()) { + const auto& Logger = GetLogger(); + + const auto fsbsSegment = FairShareBySuggestion_->SegmentAt(suggestion); + const auto fitFactor = MaxFitFactorBySuggestion_->ValueAt(suggestion); + const auto fsbffSegment = FairShareByFitFactor_->SegmentAt(fitFactor); + + YT_LOG_DEBUG( + "Updated operation fair share (" + "Suggestion: %.10g, " + "UsedFairShare: %.10g, " + "FSBSSegmentArguments: {%.10g, %.10g}, " + "FSBSSegmentValues: {%.10g, %.10g}, " + "FitFactor: %.10g, " + "FSBFFSegmentArguments: {%.10g, %.10g}, " + "FSBFFSegmentValues: {%.10g, %.10g})", + suggestion, + fairShare, + fsbsSegment.LeftBound(), fsbsSegment.RightBound(), + fsbsSegment.LeftValue(), 
fsbsSegment.RightValue(), + fitFactor, + fsbffSegment.LeftBound(), fsbffSegment.RightBound(), + fsbffSegment.LeftValue(), fsbffSegment.RightValue()); + } +} + +void TOperationElement::TruncateFairShareInFifoPools() +{ } + +TResourceVector TOperationElement::ComputeLimitsShare(const TFairShareUpdateContext* context) const +{ + return TResourceVector::Min(TElement::ComputeLimitsShare(context), GetBestAllocationShare()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TFairShareUpdateContext::TFairShareUpdateContext( + const TJobResources totalResourceLimits, + const EJobResourceType mainResource, + const TDuration integralPoolCapacitySaturationPeriod, + const TDuration integralSmoothPeriod, + const TInstant now, + const std::optional<TInstant> previousUpdateTime) + : TotalResourceLimits(totalResourceLimits) + , MainResource(mainResource) + , IntegralPoolCapacitySaturationPeriod(integralPoolCapacitySaturationPeriod) + , IntegralSmoothPeriod(integralSmoothPeriod) + , Now(now) + , PreviousUpdateTime(previousUpdateTime) +{ } + +//////////////////////////////////////////////////////////////////////////////// + +TFairShareUpdateExecutor::TFairShareUpdateExecutor( + const TRootElementPtr& rootElement, + TFairShareUpdateContext* context) + : RootElement_(rootElement) + , Context_(context) +{ } + +/// Steps of fair share update: +/// +/// 1. Initialize burst and relaxed pool lists. This is a single pass through the tree. +/// +/// 2. Update attributes needed for calculation of fair share (LimitsShare, DemandShare, UsageShare, StrongGuaranteeShare and others); +/// +/// 3. Consume and refill accumulated resource volume of integral pools. +/// The amount of resources consumed by a pool is based on its integral guarantee share since the last fair share update. +/// Refilling is based on the resource flow ratio which was calculated in the previous step. +/// +/// 4. 
Validate that the sum of burst and strong guarantee shares meets the total resources and that the strong guarantee share of every pool meets the limits share of that pool. +/// Shrink the guarantees in case of limits violations. +/// +/// 5. Calculate integral shares for burst pools. +/// We temporarily increase the pool's resource guarantees by burst guarantees, and calculate how many resources the pool would consume within these extended guarantees. +/// Then we subtract the pool's strong guarantee share from the consumed resources to estimate the integral shares. +/// Descendants of burst pools have their fair share functions built in this step. +/// +/// 6. Estimate the amount of available resources after satisfying strong and burst guarantees of all pools. +/// +/// 7. Distribute available resources among the relaxed pools using binary search. +/// We build fair share functions for descendants of relaxed pools in this step. +/// +/// 8. Build fair share functions and compute final fair shares of all pools. +/// The weight proportional component emerges here. 
+void TFairShareUpdateExecutor::Run()
+{
+    // Performs one complete fair share update of the tree rooted at |RootElement_|.
+    // All intermediate state and the profiling counters reported at the end live in |Context_|.
+    const auto& Logger = FairShareLogger;
+
+    // Measures the whole update; reported as TotalTime in the final log record.
+    TWallTimer timer;
+
+    // Preparation passes over the tree: guarantees, integral pool lists, cumulative attributes.
+    RootElement_->DetermineEffectiveStrongGuaranteeResources(Context_);
+    RootElement_->InitIntegralPoolLists(Context_);
+    RootElement_->UpdateCumulativeAttributes(Context_);
+    ConsumeAndRefillIntegralPools();
+    RootElement_->ValidateAndAdjustSpecifiedGuarantees(Context_);
+
+    // Integral shares: burst pools are processed before relaxed pools.
+    UpdateBurstPoolIntegralShares();
+    UpdateRelaxedPoolIntegralShares();
+
+    // Final pass: build the fair share functions and evaluate them at suggestion 1.0 for the root.
+    RootElement_->PrepareFairShareFunctions(Context_);
+    RootElement_->ComputeAndSetFairShare(/*suggestion*/ 1.0, Context_);
+    RootElement_->TruncateFairShareInFifoPools();
+
+    UpdateRootFairShare();
+
+    auto totalDuration = timer.GetElapsedCpuTime();
+
+    // Every value below is a duration converted to microseconds.
+    // NB: each entry follows the "Key: %v, " layout; the Compose entry used to carry a stray '.' after %v.
+    YT_LOG_DEBUG(
+        "Finished updating fair share ("
+        "TotalTime: %v, "
+        "PrepareFairShareByFitFactor/TotalTime: %v, "
+        "PrepareFairShareByFitFactor/Operations/TotalTime: %v, "
+        "PrepareFairShareByFitFactor/Fifo/TotalTime: %v, "
+        "PrepareFairShareByFitFactor/Normal/TotalTime: %v, "
+        "PrepareMaxFitFactorBySuggestion/TotalTime: %v, "
+        "PrepareMaxFitFactorBySuggestion/PointwiseMin/TotalTime: %v, "
+        "Compose/TotalTime: %v, "
+        "CompressFunction/TotalTime: %v)",
+        CpuDurationToDuration(totalDuration).MicroSeconds(),
+        CpuDurationToDuration(Context_->PrepareFairShareByFitFactorTotalTime).MicroSeconds(),
+        CpuDurationToDuration(Context_->PrepareFairShareByFitFactorOperationsTotalTime).MicroSeconds(),
+        CpuDurationToDuration(Context_->PrepareFairShareByFitFactorFifoTotalTime).MicroSeconds(),
+        CpuDurationToDuration(Context_->PrepareFairShareByFitFactorNormalTotalTime).MicroSeconds(),
+        CpuDurationToDuration(Context_->PrepareMaxFitFactorBySuggestionTotalTime).MicroSeconds(),
+        CpuDurationToDuration(Context_->PointwiseMinTotalTime).MicroSeconds(),
+        CpuDurationToDuration(Context_->ComposeTotalTime).MicroSeconds(),
+        CpuDurationToDuration(Context_->CompressFunctionTotalTime).MicroSeconds());
+}
+
+void TFairShareUpdateExecutor::UpdateBurstPoolIntegralShares()
+{
+    const auto& 
Logger = FairShareLogger; + + for (auto& burstPool : Context_->BurstPools) { + auto integralRatio = std::min(burstPool->Attributes().BurstRatio, GetIntegralShareRatioByVolume(burstPool)); + auto proposedIntegralShare = TResourceVector::Min( + TResourceVector::FromDouble(integralRatio), + GetHierarchicalAvailableLimitsShare(burstPool)); + YT_VERIFY(Dominates(proposedIntegralShare, TResourceVector::Zero())); + + proposedIntegralShare = AdjustProposedIntegralShare( + burstPool->Attributes().LimitsShare, + burstPool->Attributes().StrongGuaranteeShare, + proposedIntegralShare); + + burstPool->Attributes().ProposedIntegralShare = proposedIntegralShare; + burstPool->PrepareFairShareFunctions(Context_); + burstPool->Attributes().ProposedIntegralShare = TResourceVector::Zero(); + + auto fairShareWithinGuarantees = burstPool->FairShareBySuggestion_->ValueAt(0.0); + auto integralShare = TResourceVector::Max(fairShareWithinGuarantees - burstPool->Attributes().StrongGuaranteeShare, TResourceVector::Zero()); + IncreaseHierarchicalIntegralShare(burstPool, integralShare); + burstPool->ResetFairShareFunctions(); + burstPool->IntegralResourcesState().LastShareRatio = MaxComponent(integralShare); + + YT_LOG_DEBUG( + "Provided integral share for burst pool " + "(Pool: %v, ShareRatioByVolume: %v, ProposedIntegralShare: %v, FSWithingGuarantees: %v, IntegralShare: %v)", + burstPool->GetId(), + GetIntegralShareRatioByVolume(burstPool), + proposedIntegralShare, + fairShareWithinGuarantees, + integralShare); + } +} + +void TFairShareUpdateExecutor::UpdateRelaxedPoolIntegralShares() +{ + const auto& Logger = FairShareLogger; + + if (Context_->RelaxedPools.empty()) { + return; + } + + auto availableShare = TResourceVector::Ones(); + for (int childIndex = 0; childIndex < RootElement_->GetChildCount(); ++childIndex) { + const auto* child = RootElement_->GetChild(childIndex); + auto usedShare = TResourceVector::Min(child->Attributes().GetGuaranteeShare(), child->Attributes().DemandShare); + 
availableShare -= usedShare; + } + + std::vector<TPool*> relaxedPools; + std::vector<double> weights; + std::vector<TResourceVector> originalLimits; + for (auto& relaxedPool : Context_->RelaxedPools) { + double integralShareRatio = GetIntegralShareRatioByVolume(relaxedPool); + if (integralShareRatio == 0) { + continue; + } + relaxedPools.push_back(relaxedPool); + weights.push_back(integralShareRatio); + originalLimits.push_back(relaxedPool->Attributes().LimitsShare); + + // It is incorporated version of this method below. + // relaxedPool->ApplyLimitsForRelaxedPool(); + { + auto relaxedPoolLimit = TResourceVector::Min( + TResourceVector::FromDouble(integralShareRatio), + relaxedPool->GetIntegralShareLimitForRelaxedPool()); + relaxedPoolLimit += relaxedPool->Attributes().StrongGuaranteeShare; + relaxedPool->Attributes().LimitsShare = TResourceVector::Min(relaxedPool->Attributes().LimitsShare, relaxedPoolLimit); + } + + relaxedPool->PrepareFairShareFunctions(Context_); + } + + if (relaxedPools.empty()) { + return; + } + + double minWeight = *std::min_element(weights.begin(), weights.end()); + YT_VERIFY(minWeight > 0); + for (auto& weight : weights) { + weight = weight / minWeight; + } + + auto checkFitFactor = [&] (double fitFactor) { + TResourceVector fairShareResult; + for (int index = 0; index < std::ssize(relaxedPools); ++index) { + auto suggestion = std::min(1.0, fitFactor * weights[index]); + auto fairShare = relaxedPools[index]->FairShareBySuggestion_->ValueAt(suggestion); + fairShareResult += TResourceVector::Max(fairShare - relaxedPools[index]->Attributes().StrongGuaranteeShare, TResourceVector::Zero()); + } + + return Dominates(availableShare, fairShareResult); + }; + + auto fitFactor = FloatingPointInverseLowerBound( + /* lo */ 0.0, + /* hi */ 1.0, + /* predicate */ checkFitFactor); + + for (int index = 0; index < std::ssize(relaxedPools); ++index) { + auto weight = weights[index]; + const auto& relaxedPool = relaxedPools[index]; + auto suggestion = 
std::min(1.0, fitFactor * weight); + auto fairShareWithinGuarantees = relaxedPool->FairShareBySuggestion_->ValueAt(suggestion); + + auto integralShare = TResourceVector::Max(fairShareWithinGuarantees - relaxedPool->Attributes().StrongGuaranteeShare, TResourceVector::Zero()); + + relaxedPool->Attributes().LimitsShare = originalLimits[index]; + + auto limitedIntegralShare = TResourceVector::Min( + integralShare, + GetHierarchicalAvailableLimitsShare(relaxedPool)); + YT_VERIFY(Dominates(limitedIntegralShare, TResourceVector::Zero())); + IncreaseHierarchicalIntegralShare(relaxedPool, limitedIntegralShare); + relaxedPool->ResetFairShareFunctions(); + relaxedPool->IntegralResourcesState().LastShareRatio = MaxComponent(limitedIntegralShare); + + YT_LOG_DEBUG("Provided integral share for relaxed pool " + "(Pool: %v, ShareRatioByVolume: %v, Suggestion: %v, FSWithingGuarantees: %v, IntegralShare: %v, LimitedIntegralShare: %v)", + relaxedPool->GetId(), + GetIntegralShareRatioByVolume(relaxedPool), + suggestion, + fairShareWithinGuarantees, + integralShare, + limitedIntegralShare); + } +} + +void TFairShareUpdateExecutor::ConsumeAndRefillIntegralPools() +{ + for (auto* pool : Context_->BurstPools) { + pool->UpdateAccumulatedResourceVolume(Context_); + } + for (auto* pool : Context_->RelaxedPools) { + pool->UpdateAccumulatedResourceVolume(Context_); + } + + RootElement_->UpdateOverflowAndAcceptableVolumesRecursively(); + RootElement_->DistributeFreeVolume(); +} + +void TFairShareUpdateExecutor::UpdateRootFairShare() +{ + // Make fair share at root equal to sum of children. 
+ TResourceVector totalUsedStrongGuaranteeShare; + TResourceVector totalFairShare; + for (int childIndex = 0; childIndex < RootElement_->GetChildCount(); ++childIndex) { + const auto* child = RootElement_->GetChild(childIndex); + totalUsedStrongGuaranteeShare += child->Attributes().FairShare.StrongGuarantee; + totalFairShare += child->Attributes().FairShare.Total; + } + + // NB(eshcherbin): In order to compute the detailed fair share components correctly, + // we need to set |Attributes_.StrongGuaranteeShare| to the actual used strong guarantee share before calling |SetFairShare|. + // However, afterwards it seems more natural to restore the previous value, which shows + // the total configured strong guarantee shares in the tree. + { + auto staticStrongGuaranteeShare = RootElement_->Attributes().StrongGuaranteeShare; + RootElement_->Attributes().StrongGuaranteeShare = totalUsedStrongGuaranteeShare; + RootElement_->Attributes().SetFairShare(totalFairShare); + RootElement_->Attributes().StrongGuaranteeShare = staticStrongGuaranteeShare; + } +} + +double TFairShareUpdateExecutor::GetIntegralShareRatioByVolume(const TPool* pool) const +{ + const auto& accumulatedVolume = pool->IntegralResourcesState().AccumulatedVolume; + return accumulatedVolume.GetMinResourceRatio(Context_->TotalResourceLimits) / + Context_->IntegralSmoothPeriod.SecondsFloat(); +} + +TResourceVector TFairShareUpdateExecutor::GetHierarchicalAvailableLimitsShare(const TElement* element) const +{ + auto* current = element; + auto resultLimitsShare = TResourceVector::Ones(); + while (!current->IsRoot()) { + const auto& limitsShare = current->Attributes().LimitsShare; + const auto& effectiveGuaranteeShare = TResourceVector::Min( + current->Attributes().GetGuaranteeShare(), + current->Attributes().DemandShare); + + resultLimitsShare = TResourceVector::Min(resultLimitsShare, limitsShare - effectiveGuaranteeShare); + YT_VERIFY(Dominates(resultLimitsShare, TResourceVector::Zero())); + + current = 
current->GetParentElement(); + } + + return resultLimitsShare; +} + +void TFairShareUpdateExecutor::IncreaseHierarchicalIntegralShare(TElement* element, const TResourceVector& delta) +{ + auto* current = element; + while (current) { + // We allow guarantee share overcommit at root, because some part of strong guarantees can be reused as a relaxed integral share. + auto increasedProposedIntegralShare = current->Attributes().ProposedIntegralShare + delta; + if (!current->IsRoot()) { + increasedProposedIntegralShare = AdjustProposedIntegralShare( + current->Attributes().LimitsShare, + current->Attributes().StrongGuaranteeShare, + increasedProposedIntegralShare); + } + + current->Attributes().ProposedIntegralShare = increasedProposedIntegralShare; + current = current->GetParentElement(); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf diff --git a/yt/yt/library/vector_hdrf/fair_share_update.h b/yt/yt/library/vector_hdrf/fair_share_update.h new file mode 100644 index 0000000000..5b6f1be491 --- /dev/null +++ b/yt/yt/library/vector_hdrf/fair_share_update.h @@ -0,0 +1,373 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/core/logging/log.h> + +#include <yt/yt/core/profiling/timing.h> + +#include <yt/yt/library/vector_hdrf/job_resources.h> +#include <yt/yt/library/vector_hdrf/public.h> +#include <yt/yt/library/vector_hdrf/resource_vector.h> +#include <yt/yt/library/vector_hdrf/resource_volume.h> + +namespace NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +class TFairShareUpdateExecutor; +struct TFairShareUpdateContext; + +class TElement; +class TCompositeElement; +class TPool; +class TRootElement; +class TOperationElement; + +//////////////////////////////////////////////////////////////////////////////// + +struct TDetailedFairShare +{ + TResourceVector StrongGuarantee = {}; + TResourceVector IntegralGuarantee = {}; + 
TResourceVector WeightProportional = {}; + TResourceVector Total = {}; +}; + +TString ToString(const TDetailedFairShare& detailedFairShare); + +void FormatValue(TStringBuilderBase* builder, const TDetailedFairShare& detailedFairShare, TStringBuf /* format */); + +//////////////////////////////////////////////////////////////////////////////// + +struct TIntegralResourcesState +{ + TResourceVolume AccumulatedVolume; + double LastShareRatio = 0.0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TSchedulableAttributes +{ + EJobResourceType DominantResource = EJobResourceType::Cpu; + + TDetailedFairShare FairShare; + TResourceVector UsageShare; + TResourceVector DemandShare; + TResourceVector LimitsShare; + TResourceVector StrongGuaranteeShare; + TResourceVector ProposedIntegralShare; + TResourceVector PromisedFairShare; + TResourceVector EstimatedGuaranteeShare; + + TResourceVolume VolumeOverflow; + TResourceVolume AcceptableVolume; + TResourceVolume AcceptedFreeVolume; + TResourceVolume ChildrenVolumeOverflow; + + TJobResources EffectiveStrongGuaranteeResources; + + double BurstRatio = 0.0; + double TotalBurstRatio = 0.0; + double ResourceFlowRatio = 0.0; + double TotalResourceFlowRatio = 0.0; + + std::optional<int> FifoIndex; + + TResourceVector GetGuaranteeShare() const; + + void SetFairShare(const TResourceVector& fairShare); +}; + +//////////////////////////////////////////////////////////////////////////////// + +//! Adjusts |proposedIntegralShare| so that the total guarantee share does not exceed limits share. +//! If |strongGuaranteeShare| + |proposedIntegralShare| <= |limitShare|, returns |proposedIntegralShare|. +//! Otherwise (due to a precision error), slightly decreases components of |proposedIntegralShare| until the inequality holds +//! and returns the resulting vector. 
+TResourceVector AdjustProposedIntegralShare( + const TResourceVector& limitsShare, + const TResourceVector& strongGuaranteeShare, + TResourceVector proposedIntegralShare); + +//////////////////////////////////////////////////////////////////////////////// + +class TElement + : public virtual TRefCounted +{ +public: + virtual const TJobResources& GetResourceDemand() const = 0; + virtual const TJobResources& GetResourceUsageAtUpdate() const = 0; + // New method - should incapsulate ResourceLimits_ calculation logic and BestAllocation logic for operations. + virtual const TJobResources& GetResourceLimits() const = 0; + + virtual const TJobResourcesConfig* GetStrongGuaranteeResourcesConfig() const = 0; + virtual double GetWeight() const = 0; + + virtual TSchedulableAttributes& Attributes() = 0; + virtual const TSchedulableAttributes& Attributes() const = 0; + + virtual TElement* GetParentElement() const = 0; + + virtual bool IsRoot() const; + virtual bool IsOperation() const; + TPool* AsPool(); + TOperationElement* AsOperation(); + + virtual TString GetId() const = 0; + + virtual const NLogging::TLogger& GetLogger() const = 0; + virtual bool AreDetailedLogsEnabled() const = 0; + + // It is public for testing purposes. 
+ void ResetFairShareFunctions(); + +private: + bool AreFairShareFunctionsPrepared_ = false; + std::optional<TVectorPiecewiseLinearFunction> FairShareByFitFactor_; + std::optional<TVectorPiecewiseLinearFunction> FairShareBySuggestion_; + std::optional<TScalarPiecewiseLinearFunction> MaxFitFactorBySuggestion_; + + TResourceVector TotalTruncatedFairShare_; + + virtual void PrepareFairShareFunctions(TFairShareUpdateContext* context); + virtual void PrepareFairShareByFitFactor(TFairShareUpdateContext* context) = 0; + void PrepareMaxFitFactorBySuggestion(TFairShareUpdateContext* context); + + virtual void DetermineEffectiveStrongGuaranteeResources(TFairShareUpdateContext* context); + virtual void UpdateCumulativeAttributes(TFairShareUpdateContext* context); + virtual void ComputeAndSetFairShare(double suggestion, TFairShareUpdateContext* context) = 0; + virtual void TruncateFairShareInFifoPools() = 0; + + void CheckFairShareFeasibility() const; + + virtual TResourceVector ComputeLimitsShare(const TFairShareUpdateContext* context) const; + void UpdateAttributes(const TFairShareUpdateContext* context); + + TResourceVector GetVectorSuggestion(double suggestion) const; + + virtual void AdjustStrongGuarantees(const TFairShareUpdateContext* context); + virtual void InitIntegralPoolLists(TFairShareUpdateContext* context); + virtual void DistributeFreeVolume(); + + TResourceVector GetTotalTruncatedFairShare() const; + + friend class TCompositeElement; + friend class TPool; + friend class TRootElement; + friend class TOperationElement; + friend class TFairShareUpdateExecutor; +}; + +DECLARE_REFCOUNTED_CLASS(TElement) +DEFINE_REFCOUNTED_TYPE(TElement) + +//////////////////////////////////////////////////////////////////////////////// + +class TCompositeElement + : public virtual TElement +{ +public: + virtual TElement* GetChild(int index) = 0; + virtual const TElement* GetChild(int index) const = 0; + virtual int GetChildCount() const = 0; + + virtual ESchedulingMode GetMode() 
const = 0; + virtual bool HasHigherPriorityInFifoMode(const TElement* lhs, const TElement* rhs) const = 0; + + virtual double GetSpecifiedBurstRatio() const = 0; + virtual double GetSpecifiedResourceFlowRatio() const = 0; + + virtual bool IsFairShareTruncationInFifoPoolEnabled() const = 0; + virtual bool CanAcceptFreeVolume() const = 0; + virtual bool ShouldDistributeFreeVolumeAmongChildren() const = 0; + +private: + using TChildSuggestions = std::vector<double>; + + std::vector<TElement*> SortedChildren_; + + void PrepareFairShareFunctions(TFairShareUpdateContext* context) override; + void PrepareFairShareByFitFactor(TFairShareUpdateContext* context) override; + void PrepareFairShareByFitFactorFifo(TFairShareUpdateContext* context); + void PrepareFairShareByFitFactorNormal(TFairShareUpdateContext* context); + + void AdjustStrongGuarantees(const TFairShareUpdateContext* context) override; + void InitIntegralPoolLists(TFairShareUpdateContext* context) override; + void DetermineEffectiveStrongGuaranteeResources(TFairShareUpdateContext* context) override; + void DetermineImplicitEffectiveStrongGuaranteeResources( + const TJobResources& totalExplicitChildrenGuaranteeResources, + TFairShareUpdateContext* context); + void UpdateCumulativeAttributes(TFairShareUpdateContext* context) override; + void UpdateOverflowAndAcceptableVolumesRecursively(); + void DistributeFreeVolume() override; + void ComputeAndSetFairShare(double suggestion, TFairShareUpdateContext* context) override; + void TruncateFairShareInFifoPools() override; + + void PrepareFifoPool(); + + double GetMinChildWeight() const; + + /// strict_mode = true means that a caller guarantees that the sum predicate is true at least for fit factor = 0.0. + /// strict_mode = false means that if the sum predicate is false for any fit factor, we fit children to the least possible sum + /// (i. e. 
use fit factor = 0.0) + template <class TValue, class TGetter, class TSetter> + TValue ComputeByFitting( + const TGetter& getter, + const TSetter& setter, + TValue maxSum, + bool strictMode = true); + + TChildSuggestions GetChildSuggestionsFifo(double fitFactor); + TChildSuggestions GetChildSuggestionsNormal(double fitFactor); + + friend class TPool; + friend class TRootElement; + friend class TFairShareUpdateExecutor; +}; + +DECLARE_REFCOUNTED_CLASS(CompositeElement) +DEFINE_REFCOUNTED_TYPE(TCompositeElement) + +//////////////////////////////////////////////////////////////////////////////// + +class TPool + : public virtual TCompositeElement +{ +public: + // NB: it is combination of options on pool and on tree. + virtual TResourceVector GetIntegralShareLimitForRelaxedPool() const = 0; + + virtual const TIntegralResourcesState& IntegralResourcesState() const = 0; + virtual TIntegralResourcesState& IntegralResourcesState() = 0; + + virtual EIntegralGuaranteeType GetIntegralGuaranteeType() const = 0; + +private: + void InitIntegralPoolLists(TFairShareUpdateContext* context) override; + + void UpdateAccumulatedResourceVolume(TFairShareUpdateContext* context); + + friend class TFairShareUpdateExecutor; +}; + +DECLARE_REFCOUNTED_CLASS(TPool) +DEFINE_REFCOUNTED_TYPE(TPool) + +//////////////////////////////////////////////////////////////////////////////// + +class TRootElement + : public virtual TCompositeElement +{ +public: + bool IsRoot() const override; + +private: + void DetermineEffectiveStrongGuaranteeResources(TFairShareUpdateContext* context) override; + void UpdateCumulativeAttributes(TFairShareUpdateContext* context) override; + void TruncateFairShareInFifoPools() override; + + void ValidateAndAdjustSpecifiedGuarantees(TFairShareUpdateContext* context); + + friend class TElement; + friend class TCompositeElement; + friend class TFairShareUpdateExecutor; +}; + +DECLARE_REFCOUNTED_CLASS(TRootElement) +DEFINE_REFCOUNTED_TYPE(TRootElement) + 
+//////////////////////////////////////////////////////////////////////////////// + +class TOperationElement + : public virtual TElement +{ +public: + bool IsOperation() const override; + + virtual TResourceVector GetBestAllocationShare() const = 0; + + virtual bool IsGang() const = 0; + +private: + void PrepareFairShareByFitFactor(TFairShareUpdateContext* context) override; + + void ComputeAndSetFairShare(double suggestion, TFairShareUpdateContext* context) override; + void TruncateFairShareInFifoPools() override; + TResourceVector ComputeLimitsShare(const TFairShareUpdateContext* context) const override; + + friend class TFairShareUpdateExecutor; +}; + +DECLARE_REFCOUNTED_CLASS(TOperationElement) +DEFINE_REFCOUNTED_TYPE(TOperationElement) + +//////////////////////////////////////////////////////////////////////////////// + +struct TFairShareUpdateContext +{ + // TODO(eshcherbin): Create a separate fair share update config instead of passing all options in context. + TFairShareUpdateContext( + const TJobResources totalResourceLimits, + const EJobResourceType mainResource, + const TDuration integralPoolCapacitySaturationPeriod, + const TDuration integralSmoothPeriod, + const TInstant now, + const std::optional<TInstant> previousUpdateTime); + + const TJobResources TotalResourceLimits; + + const EJobResourceType MainResource; + const TDuration IntegralPoolCapacitySaturationPeriod; + const TDuration IntegralSmoothPeriod; + + const TInstant Now; + const std::optional<TInstant> PreviousUpdateTime; + + std::vector<TError> Errors; + + NProfiling::TCpuDuration PrepareFairShareByFitFactorTotalTime = {}; + NProfiling::TCpuDuration PrepareFairShareByFitFactorOperationsTotalTime = {}; + NProfiling::TCpuDuration PrepareFairShareByFitFactorFifoTotalTime = {}; + NProfiling::TCpuDuration PrepareFairShareByFitFactorNormalTotalTime = {}; + NProfiling::TCpuDuration PrepareMaxFitFactorBySuggestionTotalTime = {}; + NProfiling::TCpuDuration PointwiseMinTotalTime = {}; + 
NProfiling::TCpuDuration ComposeTotalTime = {}; + NProfiling::TCpuDuration CompressFunctionTotalTime = {}; + + std::vector<TPool*> RelaxedPools; + std::vector<TPool*> BurstPools; +}; + +//////////////////////////////////////////////////////////////////////////////// + +class TFairShareUpdateExecutor +{ +public: + TFairShareUpdateExecutor( + const TRootElementPtr& rootElement, + // TODO(ignat): split context on input and output parts. + TFairShareUpdateContext* context); + + void Run(); + +private: + const TRootElementPtr RootElement_; + + void ConsumeAndRefillIntegralPools(); + void UpdateBurstPoolIntegralShares(); + void UpdateRelaxedPoolIntegralShares(); + void UpdateRootFairShare(); + + double GetIntegralShareRatioByVolume(const TPool* pool) const; + TResourceVector GetHierarchicalAvailableLimitsShare(const TElement* element) const; + void IncreaseHierarchicalIntegralShare(TElement* element, const TResourceVector& delta); + + TFairShareUpdateContext* Context_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf + diff --git a/yt/yt/library/vector_hdrf/job_resources.cpp b/yt/yt/library/vector_hdrf/job_resources.cpp new file mode 100644 index 0000000000..2bfaa8c464 --- /dev/null +++ b/yt/yt/library/vector_hdrf/job_resources.cpp @@ -0,0 +1,286 @@ +#include "job_resources.h" + +namespace NYT::NVectorHdrf { + +using std::round; + +//////////////////////////////////////////////////////////////////////////////// + +TJobResources TJobResources::Infinite() +{ + TJobResources result; +#define XX(name, Name) result.Set##Name(std::numeric_limits<decltype(result.Get##Name())>::max() / 4); + ITERATE_JOB_RESOURCES(XX) +#undef XX + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +EJobResourceType GetDominantResource( + const TJobResources& demand, + const TJobResources& limits) +{ + auto maxType = EJobResourceType::Cpu; + double maxRatio = 0.0; + auto 
update = [&] (auto a, auto b, EJobResourceType type) { + if (static_cast<double>(b) > 0.0) { + double ratio = static_cast<double>(a) / static_cast<double>(b); + if (ratio > maxRatio) { + maxRatio = ratio; + maxType = type; + } + } + }; + #define XX(name, Name) update(demand.Get##Name(), limits.Get##Name(), EJobResourceType::Name); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return maxType; +} + +double GetDominantResourceUsage( + const TJobResources& usage, + const TJobResources& limits) +{ + double maxRatio = 0.0; + auto update = [&] (auto a, auto b) { + if (static_cast<double>(b) > 0.0) { + double ratio = static_cast<double>(a) / static_cast<double>(b); + if (ratio > maxRatio) { + maxRatio = ratio; + } + } + }; + #define XX(name, Name) update(usage.Get##Name(), limits.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return maxRatio; +} + +double GetResource(const TJobResources& resources, EJobResourceType type) +{ + switch (type) { + #define XX(name, Name) \ + case EJobResourceType::Name: \ + return static_cast<double>(resources.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + default: + Y_FAIL(); + } +} + +void SetResource(TJobResources& resources, EJobResourceType type, double value) +{ + switch (type) { + #define XX(name, Name) \ + case EJobResourceType::Name: \ + resources.Set##Name(value); \ + break; + ITERATE_JOB_RESOURCES(XX) + #undef XX + default: + Y_FAIL(); + } +} + +double GetMinResourceRatio( + const TJobResources& nominator, + const TJobResources& denominator) +{ + double result = std::numeric_limits<double>::max(); + bool updated = false; + auto update = [&] (auto a, auto b) { + if (static_cast<double>(b) > 0.0) { + result = std::min(result, static_cast<double>(a) / static_cast<double>(b)); + updated = true; + } + }; + #define XX(name, Name) update(nominator.Get##Name(), denominator.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return updated ? 
result : 0.0; +} + +double GetMaxResourceRatio( + const TJobResources& nominator, + const TJobResources& denominator) +{ + double result = 0.0; + auto update = [&] (auto a, auto b) { + if (static_cast<double>(b) > 0.0) { + result = std::max(result, static_cast<double>(a) / static_cast<double>(b)); + } + }; + #define XX(name, Name) update(nominator.Get##Name(), denominator.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +TJobResources operator + (const TJobResources& lhs, const TJobResources& rhs) +{ + TJobResources result; + #define XX(name, Name) result.Set##Name(lhs.Get##Name() + rhs.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +TJobResources& operator += (TJobResources& lhs, const TJobResources& rhs) +{ + #define XX(name, Name) lhs.Set##Name(lhs.Get##Name() + rhs.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return lhs; +} + +TJobResources operator - (const TJobResources& lhs, const TJobResources& rhs) +{ + TJobResources result; + #define XX(name, Name) result.Set##Name(lhs.Get##Name() - rhs.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +TJobResources& operator -= (TJobResources& lhs, const TJobResources& rhs) +{ + #define XX(name, Name) lhs.Set##Name(lhs.Get##Name() - rhs.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return lhs; +} + +TJobResources operator * (const TJobResources& lhs, i64 rhs) +{ + TJobResources result; + #define XX(name, Name) result.Set##Name(lhs.Get##Name() * rhs); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +TJobResources operator * (const TJobResources& lhs, double rhs) +{ + TJobResources result; + #define XX(name, Name) result.Set##Name(static_cast<decltype(lhs.Get##Name())>(round(lhs.Get##Name() * rhs))); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +TJobResources& operator *= (TJobResources& lhs, i64 rhs) +{ + #define XX(name, Name) lhs.Set##Name(lhs.Get##Name() * rhs); + 
ITERATE_JOB_RESOURCES(XX) + #undef XX + return lhs; +} + +TJobResources& operator *= (TJobResources& lhs, double rhs) +{ + #define XX(name, Name) lhs.Set##Name(static_cast<decltype(lhs.Get##Name())>(round(lhs.Get##Name() * rhs))); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return lhs; +} + +TJobResources operator - (const TJobResources& resources) +{ + TJobResources result; + #define XX(name, Name) result.Set##Name(-resources.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +bool operator == (const TJobResources& lhs, const TJobResources& rhs) +{ + return + #define XX(name, Name) lhs.Get##Name() == rhs.Get##Name() && + ITERATE_JOB_RESOURCES(XX) + #undef XX + true; +} + +bool operator != (const TJobResources& lhs, const TJobResources& rhs) +{ + return !(lhs == rhs); +} + +bool Dominates(const TJobResources& lhs, const TJobResources& rhs) +{ + return + #define XX(name, Name) lhs.Get##Name() >= rhs.Get##Name() && + ITERATE_JOB_RESOURCES(XX) + #undef XX + true; +} + +bool StrictlyDominates(const TJobResources& lhs, const TJobResources& rhs) +{ + return + #define XX(name, Name) lhs.Get##Name() > rhs.Get##Name() && + ITERATE_JOB_RESOURCES(XX) + #undef XX + true; +} + +TJobResources Max(const TJobResources& lhs, const TJobResources& rhs) +{ + TJobResources result; + #define XX(name, Name) result.Set##Name(std::max(lhs.Get##Name(), rhs.Get##Name())); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +TJobResources Min(const TJobResources& lhs, const TJobResources& rhs) +{ + TJobResources result; + #define XX(name, Name) result.Set##Name(std::min(lhs.Get##Name(), rhs.Get##Name())); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +bool TJobResourcesConfig::IsNonTrivial() +{ + bool isNonTrivial = false; + ForEachResource([this, &isNonTrivial] (auto TJobResourcesConfig::* resourceDataMember, EJobResourceType /*resourceType*/) { + isNonTrivial 
= isNonTrivial || (this->*resourceDataMember).has_value(); + }); + return isNonTrivial; +} + +bool TJobResourcesConfig::IsEqualTo(const TJobResourcesConfig& other) +{ + bool result = true; + ForEachResource([this, &result, &other] (auto TJobResourcesConfig::* resourceDataMember, EJobResourceType /*resourceType*/) { + result = result && (this->*resourceDataMember == other.*resourceDataMember); + }); + return result; +} + +TJobResourcesConfig& TJobResourcesConfig::operator+=(const TJobResourcesConfig& addend) +{ + ForEachResource([this, &addend] (auto TJobResourcesConfig::* resourceDataMember, EJobResourceType /*resourceType*/) { + if (!(addend.*resourceDataMember).has_value()) { + return; + } + if ((this->*resourceDataMember).has_value()) { + *(this->*resourceDataMember) += *(addend.*resourceDataMember); + } else { + this->*resourceDataMember = addend.*resourceDataMember; + } + }); + return *this; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf + diff --git a/yt/yt/library/vector_hdrf/job_resources.h b/yt/yt/library/vector_hdrf/job_resources.h new file mode 100644 index 0000000000..c33418225d --- /dev/null +++ b/yt/yt/library/vector_hdrf/job_resources.h @@ -0,0 +1,139 @@ +#pragma once + +#include <yt/yt/library/numeric/fixed_point_number.h> + +// TODO(ignat): migrate to enum class +#include <library/cpp/yt/misc/enum.h> +#include <library/cpp/yt/misc/property.h> + +#include <optional> + +namespace NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +// Uses precision of 2 decimal digits. +using TCpuResource = TFixedPointNumber<i64, 2>; + +//////////////////////////////////////////////////////////////////////////////// + +// Implementation detail. 
+class TEmptyJobResourcesBase +{ }; + +class TJobResources + : public TEmptyJobResourcesBase +{ +public: + DEFINE_BYVAL_RW_PROPERTY(i64, UserSlots); + DEFINE_BYVAL_RW_PROPERTY(TCpuResource, Cpu); + DEFINE_BYVAL_RW_PROPERTY(int, Gpu); + DEFINE_BYVAL_RW_PROPERTY(i64, Memory); + DEFINE_BYVAL_RW_PROPERTY(i64, Network); + +public: + inline void SetCpu(double cpu) + { + Cpu_ = TCpuResource(cpu); + } + + TJobResources() = default; + TJobResources(const TJobResources&) = default; + TJobResources& operator=(const TJobResources& other) = default; + + static TJobResources Infinite(); +}; + +#define ITERATE_JOB_RESOURCES(XX) \ + XX(user_slots, UserSlots) \ + XX(cpu, Cpu) \ + XX(gpu, Gpu) \ + XX(user_memory, Memory) \ + XX(network, Network) + +// NB(antonkikh): Resource types must be numbered from 0 to N - 1. +DEFINE_ENUM(EJobResourceType, + (UserSlots) + (Cpu) + (Gpu) + (Memory) + (Network) +); + +EJobResourceType GetDominantResource( + const TJobResources& demand, + const TJobResources& limits); + +double GetDominantResourceUsage( + const TJobResources& usage, + const TJobResources& limits); + +double GetResource( + const TJobResources& resources, + EJobResourceType type); + +void SetResource( + TJobResources& resources, + EJobResourceType type, + double value); + +double GetMinResourceRatio( + const TJobResources& nominator, + const TJobResources& denominator); + +double GetMaxResourceRatio( + const TJobResources& nominator, + const TJobResources& denominator); + +TJobResources operator + (const TJobResources& lhs, const TJobResources& rhs); +TJobResources& operator += (TJobResources& lhs, const TJobResources& rhs); + +TJobResources operator - (const TJobResources& lhs, const TJobResources& rhs); +TJobResources& operator -= (TJobResources& lhs, const TJobResources& rhs); + +TJobResources operator * (const TJobResources& lhs, i64 rhs); +TJobResources operator * (const TJobResources& lhs, double rhs); +TJobResources& operator *= (TJobResources& lhs, i64 rhs); +TJobResources& 
operator *= (TJobResources& lhs, double rhs); + +TJobResources operator - (const TJobResources& resources); + +bool operator == (const TJobResources& lhs, const TJobResources& rhs); +bool operator != (const TJobResources& lhs, const TJobResources& rhs); + +bool Dominates(const TJobResources& lhs, const TJobResources& rhs); +bool StrictlyDominates(const TJobResources& lhs, const TJobResources& rhs); + +TJobResources Max(const TJobResources& lhs, const TJobResources& rhs); +TJobResources Min(const TJobResources& lhs, const TJobResources& rhs); + +//////////////////////////////////////////////////////////////////////////////// + +class TJobResourcesConfig +{ +public: + std::optional<int> UserSlots; + std::optional<double> Cpu; + std::optional<int> Network; + std::optional<i64> Memory; + std::optional<int> Gpu; + + template <class T> + static void ForEachResource(T processResource) + { + processResource(&TJobResourcesConfig::UserSlots, EJobResourceType::UserSlots); + processResource(&TJobResourcesConfig::Cpu, EJobResourceType::Cpu); + processResource(&TJobResourcesConfig::Network, EJobResourceType::Network); + processResource(&TJobResourcesConfig::Memory, EJobResourceType::Memory); + processResource(&TJobResourcesConfig::Gpu, EJobResourceType::Gpu); + } + + bool IsNonTrivial(); + bool IsEqualTo(const TJobResourcesConfig& other); + + TJobResourcesConfig& operator+=(const TJobResourcesConfig& addend); +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf diff --git a/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers-inl.h b/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers-inl.h new file mode 100644 index 0000000000..cbcff7d32f --- /dev/null +++ b/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers-inl.h @@ -0,0 +1,52 @@ +#ifndef PIECEWISE_LINEAR_FUNCTION_HELPERS_H_ +#error "Direct inclusion of this file is not allowed, include piecewise_linear_function_helpers.h" +// For the 
sake of sane code completion. +#include "piecewise_linear_function_helpers.h" +#endif + +namespace NYT::NVectorHdrf::NDetail { + +//////////////////////////////////////////////////////////////////////////////// + +template <class TPiecewiseFunction> +void VerifyNondecreasing(const TPiecewiseFunction& vecFunc, const NLogging::TLogger& Logger) +{ + using TValue = typename TPiecewiseFunction::TValueType; + + auto dominates = [&] (const TValue& lhs, const TValue& rhs) -> bool { + if constexpr (std::is_same_v<TValue, double>) { + return lhs >= rhs; + } else { + return Dominates(lhs, rhs); + } + }; + + for (const auto& segment : vecFunc.Segments()) { + if (dominates(segment.RightValue(), segment.LeftValue())) { + continue; + } + + YT_LOG_ERROR( + "The vector function is decreasing at segment {%.16lf, %.16lf} (BoundValues: {%.16lf, %.16lf}, %s)", + segment.LeftBound(), + segment.RightBound(), + segment.LeftValue(), + segment.RightValue()); + + Y_VERIFY_DEBUG(false); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +template <class TSegment> +TSegment ConnectSegments(const TSegment& firstSegment, const TSegment& secondSegment) { + return TSegment( + {firstSegment.LeftBound(), firstSegment.LeftValue()}, + {secondSegment.RightBound(), secondSegment.RightValue()} + ); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf::NDetail diff --git a/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers.cpp b/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers.cpp new file mode 100644 index 0000000000..5315607c0c --- /dev/null +++ b/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers.cpp @@ -0,0 +1,136 @@ +#include "piecewise_linear_function_helpers.h" + +namespace NYT::NVectorHdrf::NDetail { + +//////////////////////////////////////////////////////////////////////////////// + +TScalarPiecewiseLinearFunction ExtractComponent(int resourceIndex, const 
TVectorPiecewiseLinearFunction& vecFunc) +{ + TScalarPiecewiseLinearFunction::TBuilder builder; + + for (const auto& segment : vecFunc.Segments()) { + builder.PushSegment(decltype(builder)::TSegment( + {segment.LeftBound(), segment.LeftValue()[resourceIndex]}, + {segment.RightBound(), segment.RightValue()[resourceIndex]})); + } + + return builder.Finish(); +} + +TScalarPiecewiseSegment ExtractComponent(int resourceIndex, const TVectorPiecewiseSegment& vecSegment) +{ + return TScalarPiecewiseSegment{ + {vecSegment.LeftBound(), vecSegment.LeftValue()[resourceIndex]}, + {vecSegment.RightBound(), vecSegment.RightValue()[resourceIndex]} + }; +} + +TUnpackedVectorPiecewiseSegment UnpackVectorSegment(const TVectorPiecewiseSegment& vecSegment) +{ + TUnpackedVectorPiecewiseSegment result; + + for (int r = 0; r < ResourceCount; ++r) { + result.push_back(ExtractComponent(r, vecSegment)); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +int CompareSegments(const TScalarPiecewiseSegment& firstSegment, const TScalarPiecewiseSegment& secondSegment) +{ + double firstDeltaBound = firstSegment.RightBound() - firstSegment.LeftBound(); + double firstDeltaValue = firstSegment.RightValue() - firstSegment.LeftValue(); + double secondDeltaBound = secondSegment.RightBound() - secondSegment.LeftBound(); + double secondDeltaValue = secondSegment.RightValue() - secondSegment.LeftValue(); + double crossProduct = firstDeltaBound * secondDeltaValue - secondDeltaBound * firstDeltaValue; + + if (crossProduct > 0) { + return 1; + } + if (crossProduct < 0) { + return -1; + } + return 0; +} + +TUnpackedVectorPiecewiseSegmentBounds GetBounds(const TUnpackedVectorPiecewiseSegment& segments, double epsilon) +{ + TUnpackedVectorPiecewiseSegment topBounds; + TUnpackedVectorPiecewiseSegment bottomBounds; + + for (const auto& segment : segments) { + topBounds.push_back({ + {segment.LeftBound(), segment.LeftValue()}, + {segment.RightBound(), 
segment.RightValue() + epsilon}}); + bottomBounds.push_back({ + {segment.LeftBound(), segment.LeftValue()}, + {segment.RightBound(), segment.RightValue() - epsilon}}); + } + + return {topBounds, bottomBounds}; +} + +TVectorPiecewiseLinearFunction CompressFunction(const TVectorPiecewiseLinearFunction& vecFunc, double epsilon) +{ + const auto& functionSegments = vecFunc.Segments(); + Y_VERIFY(!functionSegments.empty()); + + TVectorPiecewiseLinearFunction::TBuilder builder; + + // For an interval of function's segments, |accumulatedSegment| is the segment that connects + // the start of the leftmost segment of the interval with the end of the rightmost segment. + auto accumulatedSegment = functionSegments.front(); + // We say that a segment is "feasible" if it is below the top bound and above the bottom bound for each resource. + auto accumulatedBounds = GetBounds(UnpackVectorSegment(functionSegments.front()), epsilon); + + bool isFirst = true; + for (const auto& currentSegment : functionSegments) { + if (isFirst) { + isFirst = false; + continue; + } + + auto newAccumulatedSegment = ConnectSegments(accumulatedSegment, currentSegment); + auto unpackedNewAccumulatedSegment = UnpackVectorSegment(newAccumulatedSegment); + auto currentBounds = GetBounds(unpackedNewAccumulatedSegment, epsilon); + + bool canExtendAccumulatedInterval = true; + for (int r = 0; r < ResourceCount; ++r) { + // If the accumulated top bound for resource |r| is above the top bound of current segment, then update the top bound. + if (CompareSegments(currentBounds.Top[r], accumulatedBounds.Top[r]) > 0) { + accumulatedBounds.Top[r] = currentBounds.Top[r]; + } + // If the accumulated bottom bound for resource |r| is below the bottom bound of current segment, then update the bottom bound. 
+ if (CompareSegments(currentBounds.Bottom[r], accumulatedBounds.Bottom[r]) < 0) { + accumulatedBounds.Bottom[r] = currentBounds.Bottom[r]; + } + // If |accumulatedSegment| is infeasible, we cannot extend the interval of merged segments any further. + if (CompareSegments(accumulatedBounds.Bottom[r], unpackedNewAccumulatedSegment[r]) <= 0 + || CompareSegments(accumulatedBounds.Top[r], unpackedNewAccumulatedSegment[r]) >= 0) + { + canExtendAccumulatedInterval = false; + break; + } + } + + // If we can greedily extend the interval, do so, otherwise merge the accumulated interval and start a new one. + if (canExtendAccumulatedInterval) { + accumulatedSegment = newAccumulatedSegment; + } else { + builder.PushSegment(accumulatedSegment); + accumulatedSegment = currentSegment; + accumulatedBounds = GetBounds(UnpackVectorSegment(currentSegment), epsilon); + } + } + + // Finally, merge the last accumulated interval of segments. + builder.PushSegment(accumulatedSegment); + + return builder.Finish(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf::NDetail diff --git a/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers.h b/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers.h new file mode 100644 index 0000000000..c2cbe0886a --- /dev/null +++ b/yt/yt/library/vector_hdrf/piecewise_linear_function_helpers.h @@ -0,0 +1,71 @@ +#pragma once + +#include "public.h" + +#include <yt/yt/library/vector_hdrf/resource_vector.h> + +#include <library/cpp/yt/logging/logger.h> + +namespace NYT::NVectorHdrf::NDetail { + +//////////////////////////////////////////////////////////////////////////////// + +static const double CompressFunctionEpsilon = 1e-15; + +//////////////////////////////////////////////////////////////////////////////// + +NVectorHdrf::TScalarPiecewiseLinearFunction ExtractComponent(int resourceIndex, const NVectorHdrf::TVectorPiecewiseLinearFunction& vecFunc); + 
+//! Returns the scalar segment built from the |resourceIndex|-th component of both end values of |vecSegment|;
+//! the bounds are taken over unchanged.
+NVectorHdrf::TScalarPiecewiseSegment ExtractComponent(int resourceIndex, const NVectorHdrf::TVectorPiecewiseSegment& vecSegment);
+
+//! Transposed representation of a vector-valued segment, where its individual components are stored as separate scalar-valued segments.
+using TUnpackedVectorPiecewiseSegment = std::vector<NVectorHdrf::TScalarPiecewiseSegment>;
+//! Unpacks |vecSegment| into one scalar segment per resource, in resource index order.
+TUnpackedVectorPiecewiseSegment UnpackVectorSegment(const NVectorHdrf::TVectorPiecewiseSegment& vecSegment);
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Reports (error log entry plus debug-only verification) every segment of |vecFunc|
+//! whose right value does not dominate its left value.
+template <class TPiecewiseFunction>
+void VerifyNondecreasing(const TPiecewiseFunction& vecFunc, const NLogging::TLogger& Logger);
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! Given two vectors U and V, their orientation is the sign of their 2D cross product.
+//! Orientation of two segments is defined as the orientation of their corresponding vectors.
+//! Currently we only have to deal with monotonic functions and all segments are pointing to the upper-right,
+//! so if the orientation of vectors U and V is positive, then it means that V is "above" U (the same for negative/below).
+//!
+//! Returns the orientation of two segments (U is |firstSegment|, V is |secondSegment|):
+//! 1 if the cross product is positive, -1 if it is negative and 0 if it is exactly zero.
+int CompareSegments(const NVectorHdrf::TScalarPiecewiseSegment& firstSegment, const NVectorHdrf::TScalarPiecewiseSegment& secondSegment);
+
+//! Returns the segment that connects the start of |firstSegment| with the end of |secondSegment|.
+template <class TSegment>
+TSegment ConnectSegments(const TSegment& firstSegment, const TSegment& secondSegment);
+
+//! Given a scalar segment, if we shift its right value by +|epsilon| and -|epsilon|,
+//! then the resulting segments are called the top and bottom bounds for this segment.
+//! This struct holds bounds for individual components of a vector-valued segment.
+struct TUnpackedVectorPiecewiseSegmentBounds +{ + TUnpackedVectorPiecewiseSegment Top; + TUnpackedVectorPiecewiseSegment Bottom; +}; +TUnpackedVectorPiecewiseSegmentBounds GetBounds(const TUnpackedVectorPiecewiseSegment& segments, double epsilon); + +//! Transforms the function so that: +//! (1) the resulting function differs from the original by less than |epsilon| pointwise, and +//! (2) it has as few segments as possible. +//! +//! Unfortunately, we couldn't think of an efficient algorithm that solves this problem exactly, +//! so here we implemented a greedy algorithm, that gives a good approximation (we think). +//! Details: https://wiki.yandex-team.ru/yt/internal/hdrfv-function-compression/. +NVectorHdrf::TVectorPiecewiseLinearFunction CompressFunction( + const NVectorHdrf::TVectorPiecewiseLinearFunction& vecFunc, + double epsilon); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf::NDetail + +#define PIECEWISE_LINEAR_FUNCTION_HELPERS_H_ +#include "piecewise_linear_function_helpers-inl.h" +#undef PIECEWISE_LINEAR_FUNCTION_HELPERS_H_ diff --git a/yt/yt/library/vector_hdrf/private.h b/yt/yt/library/vector_hdrf/private.h new file mode 100644 index 0000000000..1494363fa3 --- /dev/null +++ b/yt/yt/library/vector_hdrf/private.h @@ -0,0 +1,11 @@ +#include <library/cpp/yt/logging/logger.h> + +namespace NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +inline const NLogging::TLogger FairShareLogger{"FairShare"}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf diff --git a/yt/yt/library/vector_hdrf/public.h b/yt/yt/library/vector_hdrf/public.h new file mode 100644 index 0000000000..687bfc6ec8 --- /dev/null +++ b/yt/yt/library/vector_hdrf/public.h @@ -0,0 +1,30 @@ +#pragma once + +// TODO(ignat): migrate to enum class +#include <library/cpp/yt/misc/enum.h> + +#include 
<yt/yt/core/misc/error_code.h> + +namespace NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +DEFINE_ENUM(ESchedulingMode, + (Fifo) + (FairShare) +); + +DEFINE_ENUM(EIntegralGuaranteeType, + (None) + (Burst) + (Relaxed) +); + +YT_DEFINE_ERROR_ENUM( + ((PoolTreeGuaranteesOvercommit) (29000)) +); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf + diff --git a/yt/yt/library/vector_hdrf/resource_helpers.cpp b/yt/yt/library/vector_hdrf/resource_helpers.cpp new file mode 100644 index 0000000000..f05629e7f7 --- /dev/null +++ b/yt/yt/library/vector_hdrf/resource_helpers.cpp @@ -0,0 +1,157 @@ +#include "resource_helpers.h" + +#include <yt/yt/core/ytree/fluent.h> + +#include <yt/yt/library/numeric/serialize/fixed_point_number.h> + +namespace NYT::NVectorHdrf { + +using namespace NYson; +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +TJobResources ToJobResources(const TJobResourcesConfig& config, TJobResources defaultValue) +{ + if (config.UserSlots) { + defaultValue.SetUserSlots(*config.UserSlots); + } + if (config.Cpu) { + defaultValue.SetCpu(*config.Cpu); + } + if (config.Network) { + defaultValue.SetNetwork(*config.Network); + } + if (config.Memory) { + defaultValue.SetMemory(*config.Memory); + } + if (config.Gpu) { + defaultValue.SetGpu(*config.Gpu); + } + return defaultValue; +} + +//////////////////////////////////////////////////////////////////////////////// + +void Serialize(const TJobResources& resources, IYsonConsumer* consumer) +{ + BuildYsonFluently(consumer) + .BeginMap() + #define XX(name, Name) .Item(#name).Value(resources.Get##Name()) + ITERATE_JOB_RESOURCES(XX) + #undef XX + .EndMap(); +} + +void Deserialize(TJobResources& resources, INodePtr node) +{ + auto mapNode = node->AsMap(); + #define XX(name, Name) \ + if (auto child = mapNode->FindChild(#name)) { \ + auto value 
= resources.Get##Name(); \ + Deserialize(value, child); \ + resources.Set##Name(value); \ + } + ITERATE_JOB_RESOURCES(XX) + #undef XX +} + +void FormatValue(TStringBuilderBase* builder, const TJobResources& resources, TStringBuf /* format */) +{ + builder->AppendFormat( + "{UserSlots: %v, Cpu: %v, Gpu: %v, Memory: %vMB, Network: %v}", + resources.GetUserSlots(), + resources.GetCpu(), + resources.GetGpu(), + resources.GetMemory() / 1_MB, + resources.GetNetwork()); +} + +//////////////////////////////////////////////////////////////////////////////// + +void Serialize(const TResourceVolume& volume, NYson::IYsonConsumer* consumer) +{ + NYTree::BuildYsonFluently(consumer) + .BeginMap() + #define XX(name, Name) .Item(#name).Value(volume.Get##Name()) + ITERATE_JOB_RESOURCES(XX) + #undef XX + .EndMap(); +} + +void Deserialize(TResourceVolume& volume, INodePtr node) +{ + auto mapNode = node->AsMap(); + #define XX(name, Name) \ + if (auto child = mapNode->FindChild(#name)) { \ + auto value = volume.Get##Name(); \ + Deserialize(value, child); \ + volume.Set##Name(value); \ + } + ITERATE_JOB_RESOURCES(XX) + #undef XX +} + +void Deserialize(TResourceVolume& volume, TYsonPullParserCursor* cursor) +{ + Deserialize(volume, ExtractTo<INodePtr>(cursor)); +} + +void Serialize(const TResourceVector& resourceVector, IYsonConsumer* consumer) +{ + auto fluent = NYTree::BuildYsonFluently(consumer).BeginMap(); + for (int index = 0; index < ResourceCount; ++index) { + fluent + .Item(FormatEnum(TResourceVector::GetResourceTypeById(index))) + .Value(resourceVector[index]); + } + fluent.EndMap(); +} + +void FormatValue(TStringBuilderBase* builder, const TResourceVolume& volume, TStringBuf /* format */) +{ + builder->AppendFormat( + "{UserSlots: %.2f, Cpu: %v, Gpu: %.2f, Memory: %.2fMBs, Network: %.2f}", + volume.GetUserSlots(), + volume.GetCpu(), + volume.GetGpu(), + volume.GetMemory() / 1_MB, + volume.GetNetwork()); +} + +TString ToString(const TResourceVolume& volume) +{ + return 
ToStringViaBuilder(volume); +} + +void FormatValue(TStringBuilderBase* builder, const TResourceVector& resourceVector, TStringBuf format) +{ + auto getResourceSuffix = [] (EJobResourceType resourceType) { + const auto& resourceNames = TEnumTraits<EJobResourceType>::GetDomainNames(); + switch (resourceType) { + case EJobResourceType::UserSlots: + // S is for Slots. + return 'S'; + + default: + return resourceNames[ToUnderlying(resourceType)][0]; + } + }; + + builder->AppendChar('['); + bool isFirst = true; + for (auto resourceType : TEnumTraits<EJobResourceType>::GetDomainValues()) { + if (!isFirst) { + builder->AppendChar(' '); + } + isFirst = false; + + FormatValue(builder, resourceVector[resourceType], format); + builder->AppendChar(getResourceSuffix(resourceType)); + } + builder->AppendChar(']'); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf diff --git a/yt/yt/library/vector_hdrf/resource_helpers.h b/yt/yt/library/vector_hdrf/resource_helpers.h new file mode 100644 index 0000000000..4b62aa3681 --- /dev/null +++ b/yt/yt/library/vector_hdrf/resource_helpers.h @@ -0,0 +1,43 @@ +#pragma once + +#include "public.h" +#include "resource_vector.h" +#include "resource_volume.h" + +#include <yt/yt/core/yson/consumer.h> + +#include <yt/yt/core/ytree/node.h> + +#include <yt/yt/core/misc/string_builder.h> + +namespace NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +TJobResources ToJobResources(const TJobResourcesConfig& config, TJobResources defaultValue); + +//////////////////////////////////////////////////////////////////////////////// + +void Serialize(const TJobResources& resources, NYson::IYsonConsumer* consumer); +void Deserialize(TJobResources& resources, NYTree::INodePtr node); + +void FormatValue(TStringBuilderBase* builder, const TJobResources& resources, TStringBuf /* format */); + 
+//////////////////////////////////////////////////////////////////////////////// + +void Serialize(const TResourceVolume& volume, NYson::IYsonConsumer* consumer); +void Deserialize(TResourceVolume& volume, NYTree::INodePtr node); +void Deserialize(TResourceVolume& volume, NYson::TYsonPullParserCursor* cursor); + +void Serialize(const TResourceVector& resourceVector, NYson::IYsonConsumer* consumer); + +void FormatValue(TStringBuilderBase* builder, const TResourceVolume& volume, TStringBuf /* format */); +TString ToString(const TResourceVolume& volume); + +void FormatValue(TStringBuilderBase* builder, const TResourceVector& resourceVector, TStringBuf format); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf + + diff --git a/yt/yt/library/vector_hdrf/resource_vector.cpp b/yt/yt/library/vector_hdrf/resource_vector.cpp new file mode 100644 index 0000000000..8661e0926f --- /dev/null +++ b/yt/yt/library/vector_hdrf/resource_vector.cpp @@ -0,0 +1,35 @@ +#include "resource_vector.h" + +namespace NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +TResourceVector TResourceVector::FromJobResources( + const TJobResources& resources, + const TJobResources& totalLimits, + double zeroDivByZero, + double oneDivByZero) +{ + auto computeResult = [&] (auto resourceValue, auto resourceLimit, double& result) { + if (static_cast<double>(resourceLimit) == 0.0) { + if (static_cast<double>(resourceValue) == 0.0) { + result = zeroDivByZero; + } else { + result = oneDivByZero; + } + } else { + result = static_cast<double>(resourceValue) / static_cast<double>(resourceLimit); + } + }; + + TResourceVector resultVector; + #define XX(name, Name) computeResult(resources.Get##Name(), totalLimits.Get##Name(), resultVector[EJobResourceType::Name]); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return resultVector; +} + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf + diff --git a/yt/yt/library/vector_hdrf/resource_vector.h b/yt/yt/library/vector_hdrf/resource_vector.h new file mode 100644 index 0000000000..4362486e15 --- /dev/null +++ b/yt/yt/library/vector_hdrf/resource_vector.h @@ -0,0 +1,113 @@ +#pragma once + +#include <yt/yt/library/vector_hdrf/job_resources.h> + +#include <yt/yt/library/numeric/binary_search.h> +#include <yt/yt/library/numeric/double_array.h> +#include <yt/yt/library/numeric/piecewise_linear_function.h> + +#include <util/generic/cast.h> + +#include <cmath> + +namespace NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +static constexpr double RatioComputationPrecision = 1e-9; +static constexpr double RatioComparisonPrecision = 1e-4; +static constexpr double InfiniteResourceAmount = 1e10; + +//////////////////////////////////////////////////////////////////////////////// + +inline constexpr int GetResourceCount() noexcept +{ + int res = 0; + #define XX(name, Name) do { res += 1; } while(false); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return res; +} + +static constexpr int ResourceCount = GetResourceCount(); +static_assert(TEnumTraits<EJobResourceType>::GetDomainSize() == ResourceCount); + +class TResourceVector + : public TDoubleArrayBase<ResourceCount, TResourceVector> +{ +private: + using TBase = TDoubleArrayBase<ResourceCount, TResourceVector>; + +public: + using TBase::TDoubleArrayBase; + using TBase::operator[]; + + Y_FORCE_INLINE double& operator[](EJobResourceType resourceType) + { + static_assert(TEnumTraits<EJobResourceType>::GetDomainSize() == ResourceCount); + return (*this)[GetIdByResourceType(resourceType)]; + } + + Y_FORCE_INLINE const double& operator[](EJobResourceType resourceType) const + { + static_assert(TEnumTraits<EJobResourceType>::GetDomainSize() == ResourceCount); + return 
(*this)[GetIdByResourceType(resourceType)]; + } + + static TResourceVector FromJobResources( + const TJobResources& resources, + const TJobResources& totalLimits, + double zeroDivByZero = 0.0, + double oneDivByZero = 0.0); + + static constexpr TResourceVector SmallEpsilon() + { + return FromDouble(RatioComputationPrecision); + } + + static constexpr TResourceVector Epsilon() + { + return FromDouble(RatioComparisonPrecision); + } + + static constexpr TResourceVector Infinity() + { + return FromDouble(InfiniteResourceAmount); + } + + Y_FORCE_INLINE static constexpr int GetIdByResourceType(EJobResourceType resourceType) + { + return static_cast<int>(resourceType); + } + + Y_FORCE_INLINE static constexpr EJobResourceType GetResourceTypeById(int resourceId) + { + return static_cast<EJobResourceType>(resourceId); + } +}; + +inline TJobResources operator*(const TJobResources& lhs, const TResourceVector& rhs) +{ + using std::round; + + TJobResources result; + #define XX(name, Name) do { \ + auto newValue = round(lhs.Get##Name() * rhs[EJobResourceType::Name]); \ + result.Set##Name(static_cast<decltype(lhs.Get##Name())>(newValue)); \ + } while (false); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +using TVectorPiecewiseSegment = TPiecewiseSegment<TResourceVector>; +using TScalarPiecewiseSegment = TPiecewiseSegment<double>; +using TVectorPiecewiseLinearFunction = TPiecewiseLinearFunction<TResourceVector>; +using TScalarPiecewiseLinearFunction = TPiecewiseLinearFunction<double>; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf + diff --git a/yt/yt/library/vector_hdrf/resource_volume.cpp b/yt/yt/library/vector_hdrf/resource_volume.cpp new file mode 100644 index 0000000000..8629a1f01c --- /dev/null +++ b/yt/yt/library/vector_hdrf/resource_volume.cpp @@ -0,0 +1,133 @@ +#include "resource_volume.h" + +namespace 
NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +using std::round; + +TResourceVolume::TResourceVolume(const TJobResources& jobResources, TDuration duration) +{ + auto seconds = duration.SecondsFloat(); + + #define XX(name, Name) Name##_ = static_cast<decltype(Name##_)>(jobResources.Get##Name() * seconds); + ITERATE_JOB_RESOURCES(XX) + #undef XX +} + +double TResourceVolume::GetMinResourceRatio(const TJobResources& denominator) const +{ + double result = std::numeric_limits<double>::max(); + bool updated = false; + auto update = [&] (auto a, auto b) { + if (static_cast<double>(b) > 0.0) { + result = std::min(result, static_cast<double>(a) / static_cast<double>(b)); + updated = true; + } + }; + #define XX(name, Name) update(Get##Name(), denominator.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return updated ? result : 0.0; +} + +bool TResourceVolume::IsZero() const +{ + bool result = true; + TResourceVolume::ForEachResource([&] (EJobResourceType /*resourceType*/, auto TResourceVolume::* resourceDataMember) { + result = result && this->*resourceDataMember == 0; + }); + return result; +} + +TResourceVolume Max(const TResourceVolume& lhs, const TResourceVolume& rhs) +{ + TResourceVolume result; + #define XX(name, Name) result.Set##Name(std::max(lhs.Get##Name(), rhs.Get##Name())); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +TResourceVolume Min(const TResourceVolume& lhs, const TResourceVolume& rhs) +{ + TResourceVolume result; + #define XX(name, Name) result.Set##Name(std::min(lhs.Get##Name(), rhs.Get##Name())); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return result; +} + +bool operator == (const TResourceVolume& lhs, const TResourceVolume& rhs) +{ + return + #define XX(name, Name) lhs.Get##Name() == rhs.Get##Name() && + ITERATE_JOB_RESOURCES(XX) + #undef XX + true; +} + +TResourceVolume& operator += (TResourceVolume& lhs, const TResourceVolume& rhs) +{ + #define XX(name, Name) 
lhs.Set##Name(lhs.Get##Name() + rhs.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return lhs; +} + +TResourceVolume& operator -= (TResourceVolume& lhs, const TResourceVolume& rhs) +{ + #define XX(name, Name) lhs.Set##Name(lhs.Get##Name() - rhs.Get##Name()); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return lhs; +} + +TResourceVolume& operator *= (TResourceVolume& lhs, double rhs) +{ + #define XX(name, Name) lhs.Set##Name(lhs.Get##Name() * rhs); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return lhs; +} + +TResourceVolume& operator /= (TResourceVolume& lhs, double rhs) +{ + #define XX(name, Name) lhs.Set##Name(lhs.Get##Name() / rhs); + ITERATE_JOB_RESOURCES(XX) + #undef XX + return lhs; +} + +TResourceVolume operator + (const TResourceVolume& lhs, const TResourceVolume& rhs) +{ + TResourceVolume result = lhs; + result += rhs; + return result; +} + +TResourceVolume operator - (const TResourceVolume& lhs, const TResourceVolume& rhs) +{ + TResourceVolume result = lhs; + result -= rhs; + return result; +} + +TResourceVolume operator * (const TResourceVolume& lhs, double rhs) +{ + TResourceVolume result = lhs; + result *= rhs; + return result; +} + +TResourceVolume operator / (const TResourceVolume& lhs, double rhs) +{ + TResourceVolume result = lhs; + result /= rhs; + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf + diff --git a/yt/yt/library/vector_hdrf/resource_volume.h b/yt/yt/library/vector_hdrf/resource_volume.h new file mode 100644 index 0000000000..01c6c2b9da --- /dev/null +++ b/yt/yt/library/vector_hdrf/resource_volume.h @@ -0,0 +1,57 @@ +#pragma once + +#include <yt/yt/library/vector_hdrf/job_resources.h> + +#include <library/cpp/yt/misc/property.h> + +#include <util/datetime/base.h> + +namespace NYT::NVectorHdrf { + +//////////////////////////////////////////////////////////////////////////////// + +class TResourceVolume +{ +public: + 
DEFINE_BYVAL_RW_PROPERTY(double, UserSlots); + DEFINE_BYVAL_RW_PROPERTY(TCpuResource, Cpu); + DEFINE_BYVAL_RW_PROPERTY(double, Gpu); + DEFINE_BYVAL_RW_PROPERTY(double, Memory); + DEFINE_BYVAL_RW_PROPERTY(double, Network); + + TResourceVolume() = default; + + explicit TResourceVolume(const TJobResources& jobResources, TDuration duration); + + double GetMinResourceRatio(const TJobResources& denominator) const; + + bool IsZero() const; + + template <class TFunction> + static void ForEachResource(TFunction processResource) + { + processResource(EJobResourceType::UserSlots, &TResourceVolume::UserSlots_); + processResource(EJobResourceType::Cpu, &TResourceVolume::Cpu_); + processResource(EJobResourceType::Network, &TResourceVolume::Network_); + processResource(EJobResourceType::Memory, &TResourceVolume::Memory_); + processResource(EJobResourceType::Gpu, &TResourceVolume::Gpu_); + } +}; + +TResourceVolume Max(const TResourceVolume& lhs, const TResourceVolume& rhs); +TResourceVolume Min(const TResourceVolume& lhs, const TResourceVolume& rhs); + +bool operator == (const TResourceVolume& lhs, const TResourceVolume& rhs); +TResourceVolume& operator += (TResourceVolume& lhs, const TResourceVolume& rhs); +TResourceVolume& operator -= (TResourceVolume& lhs, const TResourceVolume& rhs); +TResourceVolume& operator *= (TResourceVolume& lhs, double rhs); +TResourceVolume& operator /= (TResourceVolume& lhs, double rhs); +TResourceVolume operator + (const TResourceVolume& lhs, const TResourceVolume& rhs); +TResourceVolume operator - (const TResourceVolume& lhs, const TResourceVolume& rhs); +TResourceVolume operator * (const TResourceVolume& lhs, double rhs); +TResourceVolume operator / (const TResourceVolume& lhs, double rhs); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NVectorHdrf + diff --git a/yt/yt/library/vector_hdrf/ya.make b/yt/yt/library/vector_hdrf/ya.make new file mode 100644 index 0000000000..f60c2e43e8 --- 
/dev/null +++ b/yt/yt/library/vector_hdrf/ya.make @@ -0,0 +1,26 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + job_resources.cpp + piecewise_linear_function_helpers.cpp + resource_vector.cpp + resource_volume.cpp + # Files below this line, depends on core/ + resource_helpers.cpp + fair_share_update.cpp +) + +PEERDIR( + yt/yt/library/numeric + # Core dependencies. + yt/yt/library/numeric/serialize + yt/yt/core +) + +END() + +RECURSE_FOR_TESTS( + unittests +) diff --git a/yt/yt/library/xor_filter/public.h b/yt/yt/library/xor_filter/public.h new file mode 100644 index 0000000000..230e50f3f5 --- /dev/null +++ b/yt/yt/library/xor_filter/public.h @@ -0,0 +1,14 @@ +#pragma once + +#include <library/cpp/yt/memory/ref_counted.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +struct TXorFilterMeta; +class TXorFilter; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/xor_filter/xor_filter.cpp b/yt/yt/library/xor_filter/xor_filter.cpp new file mode 100644 index 0000000000..7b2389a13d --- /dev/null +++ b/yt/yt/library/xor_filter/xor_filter.cpp @@ -0,0 +1,300 @@ +#include "xor_filter.h" + +#include <yt/yt/core/misc/numeric_helpers.h> +#include <yt/yt/core/misc/error.h> +#include <yt/yt/core/misc/serialize.h> + +#include <library/cpp/iterator/enumerate.h> + +#include <util/digest/multi.h> + +#include <queue> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +bool TXorFilter::IsInitialized() const +{ + return static_cast<bool>(Data_); +} + +void TXorFilter::Initialize(TSharedRef data) +{ + YT_VERIFY(!IsInitialized()); + Data_ = std::move(data); + LoadMeta(); +} + +bool TXorFilter::Contains(TFingerprint key) const +{ + YT_ASSERT(IsInitialized()); + + ui64 actualXorFingerprint = 0; + for (int hashIndex = 0; hashIndex < 3; ++hashIndex) { + actualXorFingerprint ^= 
GetEntry(Data_, GetSlot(key, hashIndex)); + } + return actualXorFingerprint == GetExpectedXorFingerprint(key); +} + +int TXorFilter::ComputeSlotCount(int keyCount) +{ + int slotCount = keyCount * LoadFactor + LoadFactorIncrement; + + // Make slotCount a multiple of 3. + slotCount = slotCount / 3 * 3; + + return slotCount; +} + +int TXorFilter::ComputeByteSize(int keyCount, int bitsPerKey) +{ + return DivCeil(ComputeSlotCount(keyCount) * bitsPerKey, WordSize) * sizeof(ui64); +} + +int TXorFilter::ComputeAllocationSize() const +{ + int dataSize = DivCeil(SlotCount_ * BitsPerKey_, WordSize) * sizeof(ui64); + return FormatVersionSize + dataSize + MetaSize; +} + +ui64 TXorFilter::GetUi64Word(TRef data, int index) const +{ + ui64 result; + std::memcpy( + &result, + data.begin() + index * sizeof(result) + FormatVersionSize, + sizeof(result)); + return result; +} + +void TXorFilter::SetUi64Word(TMutableRef data, int index, ui64 value) const +{ + std::memcpy( + data.begin() + index * sizeof(value) + FormatVersionSize, + &value, + sizeof(value)); +} + +ui64 TXorFilter::GetHash(ui64 key, int hashIndex) const +{ + ui64 hash = Salts_[hashIndex]; + HashCombine(hash, key); + return hash; +} + +ui64 TXorFilter::GetEntry(TRef data, int index) const +{ + // Fast path. + if (BitsPerKey_ == 8) { + return static_cast<ui8>(data[index + FormatVersionSize]); + } + + int startBit = index * BitsPerKey_; + int wordIndex = startBit / WordSize; + int offset = startBit % WordSize; + + auto loWord = GetUi64Word(data, wordIndex); + auto result = loWord >> offset; + + if (offset + BitsPerKey_ > WordSize) { + auto hiWord = GetUi64Word(data, wordIndex + 1); + result |= hiWord << (WordSize - offset); + } + + return result & MaskLowerBits(BitsPerKey_); +} + +void TXorFilter::SetEntry(TMutableRef data, int index, ui64 value) const +{ + // Fast path. 
+ if (BitsPerKey_ == 8) { + data[index + FormatVersionSize] = static_cast<ui8>(value); + } + + int startBit = index * BitsPerKey_; + int wordIndex = startBit / WordSize; + int offset = startBit % WordSize; + + auto loWord = GetUi64Word(data, wordIndex); + loWord &= ~(MaskLowerBits(BitsPerKey_) << offset); + loWord ^= value << offset; + SetUi64Word(data, wordIndex, loWord); + + if (offset + BitsPerKey_ > WordSize) { + auto hiWord = GetUi64Word(data, wordIndex + 1); + hiWord &= ~(MaskLowerBits(BitsPerKey_) >> (WordSize - offset)); + hiWord ^= value >> (WordSize - offset); + SetUi64Word(data, wordIndex + 1, hiWord); + } +} + +int TXorFilter::GetSlot(ui64 key, int hashIndex) const +{ + auto hash = GetHash(key, hashIndex); + + // A faster way to generate an almost uniform integer in [0, SlotCount_ / 3). + // Note the "hash >> 32" part. Somehow higher 32 bits are distributed much + // better than lower ones, and that turned out to be critical for the filter + // building success probability. + auto res = static_cast<ui64>((hash >> 32) * (SlotCount_ / 3)) >> 32; + + return res + (SlotCount_ / 3 * hashIndex); +} + +ui64 TXorFilter::GetExpectedXorFingerprint(ui64 key) const +{ + return GetHash(key, 3) & MaskLowerBits(BitsPerKey_); +} + +void TXorFilter::SaveMeta(TMutableRef data) const +{ + { + char* ptr = data.begin(); + WritePod(ptr, static_cast<i32>(FormatVersion)); + } + + { + char* ptr = data.end() - MetaSize; + WritePod(ptr, Salts_); + WritePod(ptr, BitsPerKey_); + WritePod(ptr, SlotCount_); + YT_VERIFY(ptr == data.end()); + } +} + +void TXorFilter::LoadMeta() +{ + YT_ASSERT(IsInitialized()); + + int formatVersion; + { + const char* ptr = Data_.begin(); + ReadPod(ptr, formatVersion); + } + + if (formatVersion != 1) { + THROW_ERROR_EXCEPTION("Invalid XOR filter format version %v", + formatVersion); + } + + { + const char* ptr = Data_.end() - MetaSize; + ReadPod(ptr, Salts_); + ReadPod(ptr, BitsPerKey_); + ReadPod(ptr, SlotCount_); + YT_VERIFY(ptr == Data_.end()); + } 
+} + +TXorFilter::TXorFilter(int bitsPerKey, int slotCount) + : BitsPerKey_(bitsPerKey) + , SlotCount_(slotCount) +{ + if (bitsPerKey >= WordSize) { + THROW_ERROR_EXCEPTION("Cannot create xor filter: expected bits_per_key < %v, got %v", + WordSize, + bitsPerKey); + } + + for (int i = 0; i < 4; ++i) { + Salts_[i] = RandomNumber<ui64>(); + } +} + +TSharedRef TXorFilter::Build(TRange<TFingerprint> keys, int bitsPerKey, int trialCount) +{ + for (int trialIndex = 0; trialIndex < trialCount; ++trialIndex) { + if (auto data = DoBuild(keys, bitsPerKey)) { + return data; + } + } + + THROW_ERROR_EXCEPTION("Failed to build XOR filter in %v attempts", + trialCount); +} + +TSharedRef TXorFilter::DoBuild(TRange<TFingerprint> keys, int bitsPerKey) +{ + int slotCount = ComputeSlotCount(std::ssize(keys)); + + TXorFilter filter(bitsPerKey, slotCount); + auto data = TSharedMutableRef::Allocate(filter.ComputeAllocationSize()); + + std::vector<int> assignedKeysXor(slotCount); + std::vector<int> hitCount(slotCount); + + for (auto [keyIndex, key] : Enumerate(keys)) { + for (int hashIndex = 0; hashIndex < 3; ++hashIndex) { + int slot = filter.GetSlot(key, hashIndex); + assignedKeysXor[slot] ^= keyIndex; + ++hitCount[slot]; + } + } + + std::vector<char> inQueue(slotCount); + std::queue<int> queue; + for (int slot = 0; slot < slotCount; ++slot) { + if (hitCount[slot] == 1) { + queue.push(slot); + inQueue[slot] = true; + } + } + + std::vector<std::pair<int, int>> order; + order.reserve(keys.Size()); + + while (!queue.empty()) { + int candidateSlot = queue.front(); + queue.pop(); + + if (hitCount[candidateSlot] == 0) { + continue; + } + + YT_VERIFY(hitCount[candidateSlot] == 1); + + int keyIndex = assignedKeysXor[candidateSlot]; + YT_VERIFY(keyIndex != -1); + order.emplace_back(keyIndex, candidateSlot); + + auto key = keys[keyIndex]; + for (int hashIndex = 0; hashIndex < 3; ++hashIndex) { + int slot = filter.GetSlot(key, hashIndex); + assignedKeysXor[slot] ^= keyIndex; + if (--hitCount[slot] 
== 1 && !inQueue[slot]) { + inQueue[slot] = true; + queue.push(slot); + } + } + } + + if (std::ssize(order) < std::ssize(keys)) { + return {}; + } + + std::reverse(order.begin(), order.end()); + + for (auto [keyIndex, candidateSlot] : order) { + auto key = keys[keyIndex]; + + YT_VERIFY(filter.GetEntry(data, candidateSlot) == 0); + + ui64 expectedXor = filter.GetExpectedXorFingerprint(key); + ui64 actualXor = 0; + + for (int hashIndex = 0; hashIndex < 3; ++hashIndex) { + actualXor ^= filter.GetEntry(data, filter.GetSlot(key, hashIndex)); + } + + filter.SetEntry(data, candidateSlot, actualXor ^ expectedXor); + } + + filter.SaveMeta(data); + + return data; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/xor_filter/xor_filter.h b/yt/yt/library/xor_filter/xor_filter.h new file mode 100644 index 0000000000..60769c1759 --- /dev/null +++ b/yt/yt/library/xor_filter/xor_filter.h @@ -0,0 +1,80 @@ +#pragma once + +#include "public.h" + +#include <library/cpp/yt/farmhash/farm_hash.h> + +#include <library/cpp/yt/memory/range.h> +#include <library/cpp/yt/memory/ref.h> +#include <library/cpp/yt/memory/ref_counted.h> + +#include <array> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TXorFilter +{ +public: + TXorFilter() = default; + + bool IsInitialized() const; + void Initialize(TSharedRef data); + + bool Contains(TFingerprint key) const; + + static int ComputeByteSize(int keyCount, int bitsPerKey); + + static TSharedRef Build(TRange<TFingerprint> keys, int bitsPerKey, int trialCount = 10); + +private: + constexpr static int WordSize = 64; + static_assert(WordSize % sizeof(ui64) == 0); + + constexpr static double LoadFactor = 1.23; + constexpr static int LoadFactorIncrement = 32; + + constexpr static int FormatVersionSize = sizeof(i32); + static_assert(FormatVersionSize == 4); + + constexpr static int FormatVersion = 1; + + // 
First three salts are used for computing slots of a certain key. + // The fourth one is used to generate the expected fingerprint of the key. + std::array<ui64, 4> Salts_; + int BitsPerKey_; + int SlotCount_; + + constexpr static int MetaSize = sizeof(BitsPerKey_) + sizeof(Salts_) + sizeof(SlotCount_); + static_assert(MetaSize == 40, "Consider changing FormatVersion"); + + TSharedRef Data_; + + + //! Used when building filter. + TXorFilter(int bitsPerKey, int slotCount); + + void LoadMeta(); + void SaveMeta(TMutableRef data) const; + + ui64 GetUi64Word(TRef data, int index) const; + void SetUi64Word(TMutableRef data, int index, ui64 value) const; + + ui64 GetEntry(TRef data, int index) const; + void SetEntry(TMutableRef data, int index, ui64 value) const; + + ui64 GetHash(ui64 key, int hashIndex) const; + + int GetSlot(ui64 key, int hashIndex) const; + + ui64 GetExpectedXorFingerprint(ui64 key) const; + + static TSharedRef DoBuild(TRange<TFingerprint> keys, int bitsPerKey); + static int ComputeSlotCount(int keyCount); + int ComputeAllocationSize() const; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/library/xor_filter/ya.make b/yt/yt/library/xor_filter/ya.make new file mode 100644 index 0000000000..8a0c8aaa88 --- /dev/null +++ b/yt/yt/library/xor_filter/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +SRCS( + xor_filter.cpp +) + +PEERDIR( + yt/yt/core +) + +END() + +RECURSE( + unittests +) diff --git a/yt/yt/library/ytprof/http/handler.cpp b/yt/yt/library/ytprof/http/handler.cpp new file mode 100644 index 0000000000..382ffc1fec --- /dev/null +++ b/yt/yt/library/ytprof/http/handler.cpp @@ -0,0 +1,311 @@ +#include "handler.h" + +#include <yt/yt/core/concurrency/async_stream.h> + +#include <yt/yt/core/http/http.h> +#include <yt/yt/core/http/server.h> + +#include <yt/yt/library/ytprof/cpu_profiler.h> +#include <yt/yt/library/ytprof/spinlock_profiler.h> +#include 
<yt/yt/library/ytprof/heap_profiler.h> +#include <yt/yt/library/ytprof/profile.h> +#include <yt/yt/library/ytprof/symbolize.h> +#include <yt/yt/library/ytprof/external_pprof.h> + +#include <yt/yt/library/process/subprocess.h> + +#include <yt/yt/core/misc/finally.h> + +#include <library/cpp/cgiparam/cgiparam.h> + +#include <util/system/mutex.h> + +namespace NYT::NYTProf { + +using namespace NHttp; +using namespace NConcurrency; + +//////////////////////////////////////////////////////////////////////////////// + +class TBaseHandler + : public IHttpHandler +{ +public: + explicit TBaseHandler(const TBuildInfo& buildInfo) + : BuildInfo_(buildInfo) + { } + + virtual NProto::Profile BuildProfile(const TCgiParameters& params) = 0; + + void HandleRequest(const IRequestPtr& req, const IResponseWriterPtr& rsp) override + { + try { + TTryGuard guard(Lock_); + if (!guard) { + rsp->SetStatus(EStatusCode::TooManyRequests); + WaitFor(rsp->WriteBody(TSharedRef::FromString("Profile fetch already running"))) + .ThrowOnError(); + return; + } + + TCgiParameters params(req->GetUrl().RawQuery); + auto profile = BuildProfile(params); + Symbolize(&profile, true); + AddBuildInfo(&profile, BuildInfo_); + + if (auto it = params.Find("symbolize"); it == params.end() || it->second != "0") { + SymbolizeByExternalPProf(&profile, TSymbolizationOptions{ + .RunTool = RunSubprocess, + }); + } + + TStringStream profileBlob; + WriteProfile(&profileBlob, profile); + + rsp->SetStatus(EStatusCode::OK); + WaitFor(rsp->WriteBody(TSharedRef::FromString(profileBlob.Str()))) + .ThrowOnError(); + } catch (const std::exception& ex) { + if (rsp->AreHeadersFlushed()) { + throw; + } + + rsp->SetStatus(EStatusCode::InternalServerError); + WaitFor(rsp->WriteBody(TSharedRef::FromString(ex.what()))) + .ThrowOnError(); + + throw; + } + } + +protected: + const TBuildInfo BuildInfo_; + +private: + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, Lock_); +}; + +class TCpuProfilerHandler + : public TBaseHandler +{ +public: + 
using TBaseHandler::TBaseHandler; + + NProto::Profile BuildProfile(const TCgiParameters& params) override + { + auto duration = TDuration::Seconds(15); + if (auto it = params.Find("d"); it != params.end()) { + duration = TDuration::Parse(it->second); + } + + TCpuProfilerOptions options; + if (auto it = params.Find("freq"); it != params.end()) { + options.SamplingFrequency = FromString<int>(it->second); + } + + if (auto it = params.Find("record_action_run_time"); it != params.end()) { + options.RecordActionRunTime = true; + } + + if (auto it = params.Find("action_min_exec_time"); it != params.end()) { + options.SampleFilters.push_back(GetActionMinExecTimeFilter(TDuration::Parse(it->second))); + } + + TCpuProfiler profiler{options}; + profiler.Start(); + TDelayedExecutor::WaitForDuration(duration); + profiler.Stop(); + + return profiler.ReadProfile(); + } +}; + +class TSpinlockProfilerHandler + : public TBaseHandler +{ +public: + TSpinlockProfilerHandler(const TBuildInfo& buildInfo, bool yt) + : TBaseHandler(buildInfo) + , YT_(yt) + { } + + NProto::Profile BuildProfile(const TCgiParameters& params) override + { + auto duration = TDuration::Seconds(15); + if (auto it = params.Find("d"); it != params.end()) { + duration = TDuration::Parse(it->second); + } + + TSpinlockProfilerOptions options; + if (auto it = params.Find("frac"); it != params.end()) { + options.ProfileFraction = FromString<int>(it->second); + } + + if (YT_) { + TBlockingProfiler profiler{options}; + profiler.Start(); + TDelayedExecutor::WaitForDuration(duration); + profiler.Stop(); + + return profiler.ReadProfile(); + } else { + TSpinlockProfiler profiler{options}; + profiler.Start(); + TDelayedExecutor::WaitForDuration(duration); + profiler.Stop(); + + return profiler.ReadProfile(); + } + } + +private: + const bool YT_; +}; + +class TTCMallocSnapshotProfilerHandler + : public TBaseHandler +{ +public: + TTCMallocSnapshotProfilerHandler(const TBuildInfo& buildInfo, tcmalloc::ProfileType profileType) + : 
TBaseHandler(buildInfo) + , ProfileType_(profileType) + { } + + NProto::Profile BuildProfile(const TCgiParameters& /*params*/) override + { + return ReadHeapProfile(ProfileType_); + } + +private: + tcmalloc::ProfileType ProfileType_; +}; + +class TTCMallocAllocationProfilerHandler + : public TBaseHandler +{ +public: + using TBaseHandler::TBaseHandler; + + NProto::Profile BuildProfile(const TCgiParameters& params) override + { + auto duration = TDuration::Seconds(15); + if (auto it = params.Find("d"); it != params.end()) { + duration = TDuration::Parse(it->second); + } + + auto token = tcmalloc::MallocExtension::StartAllocationProfiling(); + TDelayedExecutor::WaitForDuration(duration); + return ConvertAllocationProfile(std::move(token).Stop()); + } +}; + +class TTCMallocStatHandler + : public IHttpHandler +{ +public: + void HandleRequest(const IRequestPtr& /* req */, const IResponseWriterPtr& rsp) override + { + auto stat = tcmalloc::MallocExtension::GetStats(); + rsp->SetStatus(EStatusCode::OK); + WaitFor(rsp->WriteBody(TSharedRef::FromString(TString{stat}))) + .ThrowOnError(); + } +}; + +class TBinaryHandler + : public IHttpHandler +{ +public: + void HandleRequest(const IRequestPtr& req, const IResponseWriterPtr& rsp) override + { + try { + auto buildId = GetBuildId(); + TCgiParameters params(req->GetUrl().RawQuery); + + if (auto it = params.Find("check_build_id"); it != params.end()) { + if (it->second != buildId) { + THROW_ERROR_EXCEPTION("Wrong build id: %v != %v", it->second, buildId); + } + } + + rsp->SetStatus(EStatusCode::OK); + + TFileInput file{"/proc/self/exe"}; + auto adapter = CreateBufferedSyncAdapter(rsp); + file.ReadAll(*adapter); + adapter->Finish(); + + WaitFor(rsp->Close()) + .ThrowOnError(); + } catch (const std::exception& ex) { + if (rsp->AreHeadersFlushed()) { + throw; + } + + rsp->SetStatus(EStatusCode::InternalServerError); + WaitFor(rsp->WriteBody(TSharedRef::FromString(ex.what()))) + .ThrowOnError(); + + throw; + } + } +}; + +class 
TVersionHandler + : public IHttpHandler +{ +public: + void HandleRequest(const IRequestPtr& /* req */, const IResponseWriterPtr& rsp) override + { + rsp->SetStatus(EStatusCode::OK); + WaitFor(rsp->WriteBody(TSharedRef::FromString(GetVersion()))) + .ThrowOnError(); + } +}; + +class TBuildIdHandler + : public IHttpHandler +{ +public: + void HandleRequest(const IRequestPtr& /* req */, const IResponseWriterPtr& rsp) override + { + rsp->SetStatus(EStatusCode::OK); + WaitFor(rsp->WriteBody(TSharedRef::FromString(GetVersion()))) + .ThrowOnError(); + } +}; + +void Register( + const NHttp::IServerPtr& server, + const TString& prefix, + const TBuildInfo& buildInfo) +{ + Register(server->GetPathMatcher(), prefix, buildInfo); +} + +void Register( + const IRequestPathMatcherPtr& handlers, + const TString& prefix, + const TBuildInfo& buildInfo) +{ + handlers->Add(prefix + "/profile", New<TCpuProfilerHandler>(buildInfo)); + + handlers->Add(prefix + "/lock", New<TSpinlockProfilerHandler>(buildInfo, false)); + handlers->Add(prefix + "/block", New<TSpinlockProfilerHandler>(buildInfo, true)); + + handlers->Add(prefix + "/heap", New<TTCMallocSnapshotProfilerHandler>(buildInfo, tcmalloc::ProfileType::kHeap)); + handlers->Add(prefix + "/peak", New<TTCMallocSnapshotProfilerHandler>(buildInfo, tcmalloc::ProfileType::kPeakHeap)); + handlers->Add(prefix + "/fragmentation", New<TTCMallocSnapshotProfilerHandler>(buildInfo, tcmalloc::ProfileType::kFragmentation)); + handlers->Add(prefix + "/allocations", New<TTCMallocAllocationProfilerHandler>(buildInfo)); + + handlers->Add(prefix + "/tcmalloc", New<TTCMallocStatHandler>()); + + handlers->Add(prefix + "/binary", New<TBinaryHandler>()); + + handlers->Add(prefix + "/version", New<TVersionHandler>()); + handlers->Add(prefix + "/buildid", New<TBuildIdHandler>()); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NYTProf diff --git a/yt/yt/library/ytprof/http/handler.h 
b/yt/yt/library/ytprof/http/handler.h new file mode 100644 index 0000000000..fa96412d95 --- /dev/null +++ b/yt/yt/library/ytprof/http/handler.h @@ -0,0 +1,24 @@ +#pragma once + +#include <yt/yt/core/http/public.h> + +#include <yt/yt/library/ytprof/build_info.h> + +namespace NYT::NYTProf { + +//////////////////////////////////////////////////////////////////////////////// + +//! Register profiling handlers. +void Register( + const NHttp::IServerPtr& server, + const TString& prefix, + const TBuildInfo& buildInfo = TBuildInfo::GetDefault()); + +void Register( + const NHttp::IRequestPathMatcherPtr& handlers, + const TString& prefix, + const TBuildInfo& buildInfo = TBuildInfo::GetDefault()); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NYTProf diff --git a/yt/yt/library/ytprof/http/ya.make b/yt/yt/library/ytprof/http/ya.make new file mode 100644 index 0000000000..1a1f3ff20c --- /dev/null +++ b/yt/yt/library/ytprof/http/ya.make @@ -0,0 +1,16 @@ +LIBRARY() + +INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) + +SRCS( + handler.cpp +) + +PEERDIR( + library/cpp/cgiparam + yt/yt/core/http + yt/yt/library/ytprof + yt/yt/library/process +) + +END() |