diff options
author | Alexander Smirnov <alex@ydb.tech> | 2025-05-02 21:33:53 +0000 |
---|---|---|
committer | Alexander Smirnov <alex@ydb.tech> | 2025-05-02 21:33:53 +0000 |
commit | 726e4fe93a06affb8a5805f80f779e1ebc891ffc (patch) | |
tree | 0a22ac4b5a192f4cfc89252997f3d555396954d2 | |
parent | cfede7fd10c5032b322bc335caff4d30c7674e6f (diff) | |
parent | 940be57633df4940e96f5754ce1bc0d4e5934dc8 (diff) | |
download | ydb-726e4fe93a06affb8a5805f80f779e1ebc891ffc.tar.gz |
Merge pull request #17944 from ydb-platform/merge-libs-250501-0050
122 files changed, 2597 insertions, 696 deletions
diff --git a/build/conf/java.conf b/build/conf/java.conf index 31a84a38eca..cfa2e3cf384 100644 --- a/build/conf/java.conf +++ b/build/conf/java.conf @@ -20,6 +20,15 @@ module JAVA_LIBRARY: JAR_LIBRARY { .SEM=_BUILD_JAR_SEM } +module JAVA_TEST_LIBRARY: JAR_LIBRARY { + .SEM=_BUILD_JAR_SEM + + SET(_SEM_JAR_TARGET jar_test) + SET(MODULE_TYPE JAVA_TEST_LIBRARY) + SET(MODULE_TAG JAVA_TEST_LIBRARY) + SET_APPEND(PEERDIR_TAGS JAVA_TEST_LIBRARY) +} + module _JAR_PROGRAM_COMPILATION: JAR_LIBRARY { .IGNORED=JAVA_RUNTIME_PEERDIR JAVA_RUNTIME_EXCLUDE .ALIASES=JAVA_RUNTIME_PEERDIR=_NOOP_MACRO JAVA_RUNTIME_EXCLUDE=_NOOP_MACRO @@ -106,6 +115,8 @@ multimodule JAVA_ANNOTATION_PROCESSOR { } module JAR_COMPILATION: _JAR_PROGRAM_COMPILATION { .SEM=_BUILD_JAVA_ANNOTATION_PROCESSOR_SEM + + SET(MODULE_TYPE JAVA_ANNOTATION_PROCESSOR) } } @@ -165,6 +176,8 @@ multimodule JUNIT5 { PEERDIR(devtools/jtest-annotations/junit5) SET(MODULE_TYPE JUNIT5) + SET(MODULE_TAG JUNIT5) + SET_APPEND(PEERDIR_TAGS JUNIT5 JAVA_TEST_LIBRARY) } module JAR_COMPILATION: JAR_LIBRARY { .ALLOWED=YT_SPEC @@ -180,8 +193,11 @@ multimodule JUNIT5 { when ($OPENSOURCE != "yes") { PEERDIR+=devtools/jtest-annotations/junit5 } - SET(MODULE_TYPE JUNIT5) + SET(_SEM_JAR_TARGET junit5) + SET(MODULE_TYPE JUNIT5) + SET(MODULE_TAG JUNIT5) + SET_APPEND(PEERDIR_TAGS JUNIT5 JAVA_TEST_LIBRARY) when ($OPENSOURCE == "yes" && $AUTOCHECK == "yes") { # FIXME: Replace AUTOCHECK == yes with _not a host platform_ check after YMAKE-218 @@ -220,6 +236,8 @@ multimodule JTEST { .SEM=_SEM_IGNORED SET(MODULE_TYPE JTEST) + SET(MODULE_TAG JTEST) + SET_APPEND(PEERDIR_TAGS JTEST JAVA_TEST_LIBRARY) PEERDIR(devtools/junit-runner devtools/jtest-annotations/junit4) DEPENDENCY_MANAGEMENT(contrib/java/com/google/code/gson/gson/2.10.1 contrib/java/com/beust/jcommander/1.72 contrib/java/junit/junit/4.12) } @@ -233,6 +251,9 @@ multimodule JTEST { SET(_SEM_JAR_TARGET junit4) SET(MODULE_TYPE JTEST) + SET(MODULE_TAG JTEST) + SET_APPEND(PEERDIR_TAGS JTEST JAVA_TEST_LIBRARY) + DEPENDS(contrib/java/org/sonarsource/scanner/cli/sonar-scanner-cli/2.8) PEERDIR+=build/platform/java/jacoco-agent @@ -276,6 +297,8 @@ multimodule JTEST_FOR { .SEM=_SEM_IGNORED SET(MODULE_TYPE JTEST_FOR) + SET(MODULE_TAG JTEST) + SET_APPEND(PEERDIR_TAGS JTEST JAVA_TEST_LIBRARY) PEERDIR(${UNITTEST_DIR} devtools/junit-runner) DEPENDENCY_MANAGEMENT(contrib/java/com/google/code/gson/gson/2.8.6 contrib/java/com/beust/jcommander/1.72 contrib/java/junit/junit/4.12) } @@ -287,7 +310,11 @@ multimodule JTEST_FOR { .SEM=_BUILD_JUNIT4_JAR_SEM .RESTRICTED=JUNIT_TESTS_JAR + SET(_SEM_JAR_TARGET junit4) SET(MODULE_TYPE JTEST_FOR) + SET(MODULE_TAG JTEST) + SET_APPEND(PEERDIR_TAGS JTEST JAVA_TEST_LIBRARY) + DEPENDS(contrib/java/org/sonarsource/scanner/cli/sonar-scanner-cli/2.8) PEERDIR(devtools/junit-runner build/platform/java/jacoco-agent ${UNITTEST_DIR}) } @@ -544,7 +571,8 @@ module _JAR_BASE: _BARE_UNIT { SET(MODULE_TAG JAVA) - PEERDIR_TAGS=JAVA JAVA_PROTO JAVA_PROTO_FROM_SCHEMA JAVA_FBS JAVA_IDL DLL JAR_COMPILATION __EMPTY__ RESOURCE_LIB + # TODO: Remove JUNIT5 JTEST JAVA_TEST_LIBRARY + PEERDIR_TAGS=JAVA JAVA_PROTO JAVA_PROTO_FROM_SCHEMA JAVA_FBS JAVA_IDL DLL JAR_COMPILATION __EMPTY__ RESOURCE_LIB JUNIT5 JTEST JAVA_TEST_LIBRARY HAS_MANAGEABLE_PEERS=yes DYNAMIC_LINK=yes diff --git a/build/export_generators/gradle/generator.toml b/build/export_generators/gradle/generator.toml index be72a215542..426656dc2cc 100644 --- a/build/export_generators/gradle/generator.toml +++ b/build/export_generators/gradle/generator.toml @@ -16,13 +16,15 @@ template="build.gradle.kts.jinja" [targets.jar_proto] template={ path="build.gradle.kts.proto.jinja", dest="build.gradle.kts" } +[targets.jar_test] +is_test = true +is_extra_target = true + [targets.junit4] -template="build.gradle.kts.jinja" is_test = true is_extra_target = true [targets.junit5] -template="build.gradle.kts.jinja" is_test = true is_extra_target = true @@ -101,12 +103,3 @@ consumer-classpath="str" consumer-jar="str" consumer-type="str" consumer-prebuilt="flag" - -[merge] -test=[ - "/ut", - "/src/test", - "/src/test/java", - "/src/test-integration", - "/src/test-multicell" -] diff --git a/build/export_generators/ide-gradle/build.gradle.kts.any.jinja b/build/export_generators/ide-gradle/build.gradle.kts.any.jinja index 79936b6822c..d5126e76721 100644 --- a/build/export_generators/ide-gradle/build.gradle.kts.any.jinja +++ b/build/export_generators/ide-gradle/build.gradle.kts.any.jinja @@ -30,6 +30,7 @@ {%- if proto_template -%} {%- include "[generator]/proto_vars.jinja" -%} +{%- include "[generator]/import.jinja" -%} {%- include "[generator]/proto_import.jinja" -%} {%- include "[generator]/proto_builddir.jinja" -%} {%- include "[generator]/proto_plugins.jinja" -%} diff --git a/build/export_generators/ide-gradle/codegen_copy_file.jinja b/build/export_generators/ide-gradle/codegen_copy_file.jinja index 2b054b7ce1f..7689c4db1c4 100644 --- a/build/export_generators/ide-gradle/codegen_copy_file.jinja +++ b/build/export_generators/ide-gradle/codegen_copy_file.jinja @@ -16,7 +16,7 @@ val {{ varprefix }}{{ copy['_object_index'] }} = tasks.register<Copy>("{{ varpre into({{ PatchRoots(dst_path, false, true) }}) {%- if src_name != dst_name %} rename("{{ src_name }}", "{{ dst_name }}") +{%- endif %} } -{% endif -%} -{%- endfor -%} +{% endfor -%} {%- endif -%} diff --git a/build/export_generators/ide-gradle/common_vars.jinja b/build/export_generators/ide-gradle/common_vars.jinja new file mode 100644 index 00000000000..11662f423a4 --- /dev/null +++ b/build/export_generators/ide-gradle/common_vars.jinja @@ -0,0 +1,22 @@ +{%- if target is defined -%} +{%- set publish = target.publish -%} +{%- set with_kotlin = target.with_kotlin -%} +{%- if with_kotlin -%} +{%- set kotlin_version = target.kotlin_version -%} +{%- set with_kotlinc_plugin_allopen = target.with_kotlinc_plugin_allopen -%} +{%- set with_kotlinc_plugin_noarg = target.with_kotlinc_plugin_noarg -%} +{%- endif -%} +{%- set has_errorprone = target.use_errorprone and not disable_errorprone and target.consumer|selectattr('jar', 'startsWith', 'contrib/java/com/google/errorprone/error_prone_annotations')|length -%} +{%- else -%} +{#- No target, only extra_targets, get main features from extra_targets -#} +{%- set publish = extra_targets|selectattr('publish', 'eq', true)|length -%} +{%- set with_kotlin = extra_targets|selectattr('with_kotlin', 'eq', true)|length -%} +{%- if with_kotlin -%} +{%- set kotlin_version = extra_targets|selectattr('kotlin_version')|map(attribute='kotlin_version')|first -%} +{%- set with_kotlinc_plugin_allopen = extra_targets|selectattr('with_kotlinc_plugin_allopen')|map(attribute='with_kotlinc_plugin_allopen')|sum -%} +{%- set with_kotlinc_plugin_noarg = extra_targets|selectattr('with_kotlinc_plugin_noarg')|map(attribute='with_kotlinc_plugin_noarg')|sum -%} +{%- endif -%} +{%- set has_errorprone = extra_targets|selectattr('use_errorprone', 'eq', true)|length and not disable_errorprone and extra_targets|selectattr('consumer')|map(attribute='consumer')|sum|selectattr('jar', 'startsWith', 'contrib/java/com/google/errorprone/error_prone_annotations')|length -%} +{%- endif -%} + +{%- include "[generator]/jdk.jinja" -%} diff --git a/build/export_generators/ide-gradle/dependencies.jinja b/build/export_generators/ide-gradle/dependencies.jinja index deb64a4ee7f..3e89198b5f8 100644 --- a/build/export_generators/ide-gradle/dependencies.jinja +++ b/build/export_generators/ide-gradle/dependencies.jinja @@ -13,9 +13,9 @@ {%- endif -%} {%- endmacro -%} -{%- macro AddFileDeps(file_deps) -%} -{%- for file_dep in file_deps %} - "$arcadia_root/{{ file_dep.jar }}"{%- if not loop.last -%},{%- endif -%} +{%- macro AddFileJars(file_jars) -%} +{%- for file_jar in file_jars %} + "$arcadia_root/{{ file_jar }}"{%- if not loop.last -%},{%- endif -%} {%- endfor -%} {%- endmacro -%} @@ -56,13 +56,15 @@ {%- if not build_contribs -%} {%- set file_deps = file_deps|selectattr('type', 'ne', 'contrib') -%} {%- endif -%} -{%- set file_classpaths = file_deps|map(attribute='classpath') -%} +{%- set file_classpaths = file_deps|map(attribute='classpath')|unique|sort -%} +{%- set file_jars = file_deps|map(attribute='jar')|unique|sort -%} {%- set test_file_deps = extra_targets|selectattr('consumer')|map(attribute='consumer')|sum|selectattr('classpath')|selectattr('jar')|selectattr('prebuilt', 'eq', true) -%} {%- if not build_contribs -%} {%- set test_file_deps = test_file_deps|selectattr('type', 'ne', 'contrib') -%} {%- endif -%} -{%- set test_file_classpaths = test_file_deps|map(attribute='classpath') -%} +{%- set test_file_classpaths = test_file_deps|map(attribute='classpath')|unique|sort -%} +{%- set test_file_jars = test_file_deps|map(attribute='jar')|reject("in", file_jars)|unique|sort -%} dependencies { {%- if has_errorprone -%} @@ -86,15 +88,15 @@ dependencies { {{ AddNonFileDeps(extra_target, test_file_classpaths, "testImplementation", "testImplementation") }} {%- endfor -%} -{%- if file_deps|length %} +{%- if file_jars|length %} implementation(files(listOf({#- glue -#} -{{ AddFileDeps(file_deps) }} +{{ AddFileJars(file_jars) }} ))) {%- endif -%} -{%- if test_file_deps|length %} +{%- if test_file_jars|length %} testImplementation(files(listOf({#- glue -#} -{{ AddFileDeps(test_file_deps) }} +{{ AddFileJars(test_file_jars) }} ))) {%- endif %} } diff --git a/build/export_generators/ide-gradle/generator.toml b/build/export_generators/ide-gradle/generator.toml index e085cd71b68..61c247cba2f 100644 --- a/build/export_generators/ide-gradle/generator.toml +++ b/build/export_generators/ide-gradle/generator.toml @@ -19,13 +19,15 @@ template="build.gradle.kts.jinja" [targets.jar_proto] template={ path="build.gradle.kts.proto.jinja", dest="build.gradle.kts" } +[targets.jar_test] +is_test = true +is_extra_target = true + [targets.junit4] -template="build.gradle.kts.jinja" is_test = true is_extra_target = true [targets.junit5] -template="build.gradle.kts.jinja" is_test = true is_extra_target = true @@ -111,14 +113,3 @@ consumer-classpath="str" consumer-jar="str" consumer-type="str" consumer-prebuilt="flag" - -[merge] -test=[ - "/ut", - "/src/test", - "/src/test/java", - "/src/test-integration", - "/src/integration-test", - "/src/testFixtures", - "/src/intTest", -] diff --git a/build/export_generators/ide-gradle/javac_flags.jinja b/build/export_generators/ide-gradle/javac_flags.jinja index d8e40b3f6e4..c8c536cdba6 100644 --- a/build/export_generators/ide-gradle/javac_flags.jinja +++ b/build/export_generators/ide-gradle/javac_flags.jinja @@ -1,17 +1,17 @@ {%- set javac_flags = [] -%} -{%- set jvm_flags = [] -%} +{%- set compiler_jvm_flags = [] -%} {%- if target.javac.flags|length -%} -{#- skip errorprone options -#} +{#- skip errorprone options and JVM args -#} {%- set javac_flags = target.javac.flags|reject('startsWith', '-Xep')|reject('startsWith', '-J') -%} {%- if javac_flags|length -%} {%- if (javac_flags|length == 1) and (javac_flags|first == '-parameters') -%} {%- set javac_flags = [] -%} {%- endif -%} {%- endif -%} -{%- set jvm_flags = target.javac.flags|select('startsWith', '-J') -%} +{%- set compiler_jvm_flags = target.javac.flags|select('startsWith', '-J') -%} {%- endif -%} -{%- if javac_flags|length or jvm_flags|length or has_errorprone %} +{%- if javac_flags|length or compiler_jvm_flags|length or has_errorprone %} tasks.withType<JavaCompile> { {%- if javac_flags|length -%} @@ -20,18 +20,18 @@ tasks.withType<JavaCompile> { {%- endfor -%} {%- endif -%} -{%- if jvm_flags|length %} +{%- if compiler_jvm_flags|length %} options.isFork = true options.forkOptions.jvmArgs = listOf( -{%- for jvm_flag in jvm_flags -%} -"{{ jvm_flag|replace("-J", "") }}"{% if not loop.last %}, {% endif -%} +{%- for compiler_jvm_flag in compiler_jvm_flags -%} +"{{ compiler_jvm_flag|replace("-J", "") }}"{% if not loop.last %}, {% endif -%} {%- endfor -%}) {% endif -%} {%- if has_errorprone -%} {%- set ep_checks = target.javac.flags|select('startsWith', '-Xep:') -%} {%- set ep_checkopts = target.javac.flags|select('startsWith', '-XepOpt:') -%} -{%- set ep_props = target.javac.flags|reject('startsWith', '-XepOpt:')|select('startsWith', '-Xep') -%} +{%- set ep_props = target.javac.flags|reject('startsWith', '-Xep:')|reject('startsWith', '-XepOpt:')|select('startsWith', '-Xep') -%} {%- if ep_checks|length %} options.errorprone.checks.set( mapOf( diff --git a/build/export_generators/ide-gradle/kotlin_plugins.jinja b/build/export_generators/ide-gradle/kotlin_plugins.jinja index 823fa5243d4..0a2e7cd48d2 100644 --- a/build/export_generators/ide-gradle/kotlin_plugins.jinja +++ b/build/export_generators/ide-gradle/kotlin_plugins.jinja @@ -1,20 +1,20 @@ -{%- if target.with_kotlinc_plugin_allopen|length -%} +{%- if with_kotlinc_plugin_allopen|length -%} {%- set allopen_annotations = [] -%} -{%- if target.with_kotlinc_plugin_allopen|select('eq', 'preset=spring')|length -%} +{%- if with_kotlinc_plugin_allopen|select('eq', 'preset=spring')|length -%} {%- set allopen_annotations = allopen_annotations + ['org.springframework.stereotype.Component', 'org.springframework.transaction.annotation.Transactional', 'org.springframework.scheduling.annotation.Async', 'org.springframework.cache.annotation.Cacheable', 'org.springframework.boot.test.context.SpringBootTest', 'org.springframework.validation.annotation.Validated'] -%} {%- endif -%} -{%- if target.with_kotlinc_plugin_allopen|select('eq', 'preset=quarkus')|length -%} +{%- if with_kotlinc_plugin_allopen|select('eq', 'preset=quarkus')|length -%} {%- set allopen_annotations = allopen_annotations + ['javax.enterprise.context.ApplicationScoped', 'javax.enterprise.context.RequestScoped'] -%} {%- endif -%} -{%- if target.with_kotlinc_plugin_allopen|select('eq', 'preset=micronaut')|length -%} +{%- if with_kotlinc_plugin_allopen|select('eq', 'preset=micronaut')|length -%} {%- set allopen_annotations = allopen_annotations + ['io.micronaut.aop.Around', 'io.micronaut.aop.Introduction', 'io.micronaut.aop.InterceptorBinding', 'io.micronaut.aop.InterceptorBindingDefinitions'] -%} {%- endif -%} -{%- if target.with_kotlinc_plugin_allopen|select('startsWith', 'annotation=')|length -%} -{%- set sannotations = target.with_kotlinc_plugin_allopen|select('startsWith', 'annotation=')|join('|')|replace('annotation=','') -%} +{%- if with_kotlinc_plugin_allopen|select('startsWith', 'annotation=')|length -%} +{%- set sannotations = with_kotlinc_plugin_allopen|select('startsWith', 'annotation=')|join('|')|replace('annotation=','') -%} {%- set annotations = split(sannotations, '|') -%} {%- set allopen_annotations = allopen_annotations + annotations -%} {%- endif -%} -{%- set allopen_options = target.with_kotlinc_plugin_allopen|reject('startsWith', 'preset=')|reject('startsWith', 'annotation=')|reject('eq', 'default') %} +{%- set allopen_options = with_kotlinc_plugin_allopen|reject('startsWith', 'preset=')|reject('startsWith', 'annotation=')|reject('eq', 'default') %} allOpen { {%- if allopen_options|length -%} @@ -30,17 +30,17 @@ allOpen { } {% endif -%} -{%- if target.with_kotlinc_plugin_noarg|length -%} +{%- if with_kotlinc_plugin_noarg|length -%} {%- set noarg_annotations = [] -%} -{%- if target.with_kotlinc_plugin_noarg|select('eq', 'preset=jpa')|length -%} +{%- if with_kotlinc_plugin_noarg|select('eq', 'preset=jpa')|length -%} {%- set noarg_annotations = noarg_annotations + ['javax.persistence.Entity', 'javax.persistence.Embeddable', 'javax.persistence.MappedSuperclass', 'jakarta.persistence.Entity', 'jakarta.persistence.Embeddable', 'jakarta.persistence.MappedSuperclass'] -%} {%- endif -%} -{%- if target.with_kotlinc_plugin_noarg|select('startsWith', 'annotation=')|length -%} -{%- set sannotations = target.with_kotlinc_plugin_noarg|select('startsWith', 'annotation=')|join('|')|replace('annotation=','') -%} +{%- if with_kotlinc_plugin_noarg|select('startsWith', 'annotation=')|length -%} +{%- set sannotations = with_kotlinc_plugin_noarg|select('startsWith', 'annotation=')|join('|')|replace('annotation=','') -%} {%- set annotations = split(sannotations, '|') -%} {%- set noarg_annotations = noarg_annotations + annotations -%} {%- endif -%} -{%- set noarg_options = target.with_kotlinc_plugin_noarg|reject('startsWith', 'preset=')|reject('startsWith', 'annotation=')|reject('eq', 'default') %} +{%- set noarg_options = with_kotlinc_plugin_noarg|reject('startsWith', 'preset=')|reject('startsWith', 'annotation=')|reject('eq', 'default') %} noArg { {%- if noarg_options|length -%} diff --git a/build/export_generators/ide-gradle/kotlinc_flags.jinja b/build/export_generators/ide-gradle/kotlinc_flags.jinja index 3f6c8d90ad1..0aa895f668f 100644 --- a/build/export_generators/ide-gradle/kotlinc_flags.jinja +++ b/build/export_generators/ide-gradle/kotlinc_flags.jinja @@ -1,18 +1,18 @@ {%- if with_kotlin -%} {%- set kotlinc_flags = [] -%} -{%- if task.kotlinc.flags|length -%} -{%- set kotlinc_flags = kotlinc_flags -%} +{%- if target.kotlinc.flags|length -%} +{%- set kotlinc_flags = target.kotlinc.flags|unique -%} {%- endif -%} {%- set extra_kotlinc_flags = extra_targets|selectattr('kotlinc')|map(attribute='kotlinc')|map(attribute='flags')|sum -%} {%- if extra_kotlinc_flags|length -%} -{%- set kotlinc_flags = kotlinc_flags + extra_kotlinc_flags -%} +{%- set kotlinc_flags = kotlinc_flags + extra_kotlinc_flags|unique -%} {%- endif -%} {%- if kotlinc_flags|length %} -tasks.withType<KotlinCompile> { +tasks.withType<KotlinCompile>() { compilerOptions { {%- for kotlinc_flag in kotlinc_flags|unique %} - freeCompilerArgs.add("{{ kotlinc_flag|replace(export_root, "$arcadia_root")|replace(arcadia_root, "$arcadia_root") }}") + freeCompilerArgs.add({{ PatchRoots(kotlinc_flag, true) }}) {%- endfor %} } } diff --git a/build/export_generators/ide-gradle/proto_vars.jinja b/build/export_generators/ide-gradle/proto_vars.jinja index 5b3cfb6425d..6daaa87975b 100644 --- a/build/export_generators/ide-gradle/proto_vars.jinja +++ b/build/export_generators/ide-gradle/proto_vars.jinja @@ -1,8 +1,5 @@ -{%- set publish = target.publish -%} -{%- set with_kotlin = target.with_kotlin -%} -{%- set kotlin_version = target.kotlin_version -%} +{%- include "[generator]/common_vars.jinja" -%} + {%- set prepareProtosTask = target.proto_files|length or target.runs|length or target.custom_runs|length -%} {%- set libraries = target.consumer|selectattr('type', 'eq', 'library') -%} {%- set extractLibrariesProtosTask = libraries|length -%} - -{%- include "[generator]/jdk.jinja" -%} diff --git a/build/export_generators/ide-gradle/source_sets.jinja b/build/export_generators/ide-gradle/source_sets.jinja index e127569c458..a0086568b6f 100644 --- a/build/export_generators/ide-gradle/source_sets.jinja +++ b/build/export_generators/ide-gradle/source_sets.jinja @@ -46,34 +46,37 @@ sourceSets { {%- if target.proto_grpc %} java.srcDir("$buildDir/generated/source/proto/test/grpc") {%- endif -%} -{%- else %} - java.srcDir("ut/java") - resources.srcDir("ut/resources") - java.srcDir("src/test-integration/java") - resources.srcDir("src/test-integration/resources") - java.srcDir("src/integration-test/java") - resources.srcDir("src/integration-test/resources") - java.srcDir("src/testFixtures/java") - resources.srcDir("src/testFixtures/resources") - java.srcDir("src/intTest/java") - resources.srcDir("src/intTest/resources") - -{%- set extra_target_source_sets = extra_targets|selectattr('jar_source_set')|map(attribute='jar_source_set')|sum|reject('startsWith', 'src/test/java:')|unique -%} -{%- if extra_target_source_sets|length -%} -{%- for source_set in extra_target_source_sets -%} -{%- set srcdir_glob = split(source_set, ':') -%} -{%- set srcdir = srcdir_glob[0] %} +{%- elif extra_targets|length -%} +{%- for extra_target in extra_targets -%} +{%- set reldir = "" -%} +{%- if extra_target.test_reldir -%} +{%- set reldir = extra_target.test_reldir + "/" -%} +{%- endif -%} +{%- for source_set in extra_target.jar_source_set -%} +{%- set srcdir_glob = split(source_set, ':', 2) -%} +{%- if srcdir_glob[0][0] == "/" -%} +{#- Absolute path in glob -#} +{%- set srcdir = srcdir_glob[0] -%} +{%- else -%} +{%- set srcdir = reldir + srcdir_glob[0] -%} +{%- endif -%} +{%- if srcdir != "src/test/java" %} java.srcDir({{ PatchRoots(srcdir) }}) +{%- endif -%} {%- endfor -%} -{%- endif %} -{%- set extra_target_resource_sets = extra_targets|selectattr('jar_resource_set')|map(attribute='jar_resource_set')|sum|reject('startsWith', 'src/test/resources:')|unique -%} -{%- if extra_target_resource_sets|length -%} -{%- for resource_set in extra_target_resource_sets -%} -{%- set resdir_glob = split(resource_set, ':') -%} -{%- set resdir = resdir_glob[0] %} +{%- for resource_set in extra_target.jar_resource_set -%} +{%- set resdir_glob = split(resource_set, ':', 2) -%} +{%- if resdir_glob[0][0] == "/" -%} +{#- Absolute path in glob -#} +{%- set srcdir = resdir_glob[0] -%} +{%- else -%} +{%- set resdir = reldir + resdir_glob[0] -%} +{%- endif -%} +{%- if resdir != "src/test/resources" %} resources.srcDir({{ PatchRoots(resdir) }}) +{%- endif -%} {%- endfor -%} -{%- endif -%} +{%- endfor -%} {%- endif %} } } diff --git a/build/export_generators/ide-gradle/vars.jinja b/build/export_generators/ide-gradle/vars.jinja index 26f7e3621c7..f7bb31c3849 100644 --- a/build/export_generators/ide-gradle/vars.jinja +++ b/build/export_generators/ide-gradle/vars.jinja @@ -1,8 +1,4 @@ +{%- include "[generator]/common_vars.jinja" -%} + {%- set mainClass = target.app_main_class -%} -{%- set publish = target.publish -%} -{%- set with_kotlin = target.with_kotlin -%} -{%- set kotlin_version = target.kotlin_version -%} {%- set has_junit5_test = extra_targets|selectattr('junit5_test') -%} -{%- set has_errorprone = target.use_errorprone and not disable_errorprone and target.consumer|selectattr('jar', 'startsWith', 'contrib/java/com/google/errorprone/error_prone_annotations')|length -%} - -{%- include "[generator]/jdk.jinja" -%} diff --git a/build/external_resources/yexport/public.resources.json b/build/external_resources/yexport/public.resources.json index d71ac157e52..02b4f3014ee 100644 --- a/build/external_resources/yexport/public.resources.json +++ b/build/external_resources/yexport/public.resources.json @@ -1,13 +1,13 @@ { "by_platform": { "darwin": { - "uri": "sbr:8477042967" + "uri": "sbr:8632444017" }, "darwin-arm64": { - "uri": "sbr:8477040717" + "uri": "sbr:8632441113" }, "linux": { - "uri": "sbr:8477038461" + "uri": "sbr:8632437510" } } } diff --git a/build/external_resources/yexport/resources.json b/build/external_resources/yexport/resources.json index 82fc2a3c959..b29b46e18cf 100644 --- a/build/external_resources/yexport/resources.json +++ b/build/external_resources/yexport/resources.json @@ -1,13 +1,13 @@ { "by_platform": { "darwin": { - "uri": "sbr:8476957408" + "uri": "sbr:8632348873" }, "darwin-arm64": { - "uri": "sbr:8476954744" + "uri": "sbr:8632345760" }, "linux": { - "uri": "sbr:8476952778" + "uri": "sbr:8632343302" } } } diff --git a/build/mapping.conf.json b/build/mapping.conf.json index dc72a1ec161..19ad3a545d3 100644 --- a/build/mapping.conf.json +++ b/build/mapping.conf.json @@ -611,6 +611,8 @@ "8450186648": "{registry_endpoint}/8450186648", "8457945767": "{registry_endpoint}/8457945767", "8477042967": "{registry_endpoint}/8477042967", + "8628672485": "{registry_endpoint}/8628672485", + "8632444017": "{registry_endpoint}/8632444017", "5811823398": "{registry_endpoint}/5811823398", "5840611310": "{registry_endpoint}/5840611310", "5860185593": "{registry_endpoint}/5860185593", @@ -641,6 +643,8 @@ "8450185847": "{registry_endpoint}/8450185847", "8457944672": "{registry_endpoint}/8457944672", "8477040717": "{registry_endpoint}/8477040717", + "8628670306": "{registry_endpoint}/8628670306", + "8632441113": "{registry_endpoint}/8632441113", "5811822876": "{registry_endpoint}/5811822876", "5840610640": "{registry_endpoint}/5840610640", "5860184285": "{registry_endpoint}/5860184285", @@ -671,6 +675,8 @@ "8450184998": "{registry_endpoint}/8450184998", "8457943496": "{registry_endpoint}/8457943496", "8477038461": "{registry_endpoint}/8477038461", + "8628668563": "{registry_endpoint}/8628668563", + "8632437510": "{registry_endpoint}/8632437510", "5766172292": "{registry_endpoint}/5766172292", "5805431504": "{registry_endpoint}/5805431504", "5829027626": "{registry_endpoint}/5829027626", @@ -1997,6 +2003,8 @@ "8450186648": "devtools/yexport/bin/yexport for darwin", "8457945767": "devtools/yexport/bin/yexport for darwin", "8477042967": "devtools/yexport/bin/yexport for darwin", + "8628672485": "devtools/yexport/bin/yexport for darwin", + "8632444017": "devtools/yexport/bin/yexport for darwin", "5811823398": "devtools/yexport/bin/yexport for darwin-arm64", "5840611310": "devtools/yexport/bin/yexport for darwin-arm64", "5860185593": "devtools/yexport/bin/yexport for darwin-arm64", @@ -2027,6 +2035,8 @@ "8450185847": "devtools/yexport/bin/yexport for darwin-arm64", "8457944672": "devtools/yexport/bin/yexport for darwin-arm64", "8477040717": "devtools/yexport/bin/yexport for darwin-arm64", + "8628670306": "devtools/yexport/bin/yexport for darwin-arm64", + "8632441113": "devtools/yexport/bin/yexport for darwin-arm64", "5811822876": "devtools/yexport/bin/yexport for linux", "5840610640": "devtools/yexport/bin/yexport for linux", "5860184285": "devtools/yexport/bin/yexport for linux", @@ -2057,6 +2067,8 @@ "8450184998": "devtools/yexport/bin/yexport for linux", "8457943496": "devtools/yexport/bin/yexport for linux", "8477038461": "devtools/yexport/bin/yexport for linux", + "8628668563": "devtools/yexport/bin/yexport for linux", + "8632437510": "devtools/yexport/bin/yexport for linux", "5766172292": "devtools/ymake/bin/ymake for darwin", "5805431504": "devtools/ymake/bin/ymake for darwin", "5829027626": "devtools/ymake/bin/ymake for darwin", diff --git a/build/ymake_conf.py b/build/ymake_conf.py index 103a559cdf3..3810caa3b10 100755 --- a/build/ymake_conf.py +++ b/build/ymake_conf.py @@ -2528,7 +2528,7 @@ class CuDNN(object): return self.cudnn_version.value in ('7.6.5', '8.0.5', '8.6.0', '8.9.7', '9.0.0') def auto_cudnn_version(self): - return '9.0.0' + return '8.6.0' def print_(self): if self.cuda.have_cuda.value and self.have_cudnn(): diff --git a/contrib/python/google-auth/py3/.dist-info/METADATA b/contrib/python/google-auth/py3/.dist-info/METADATA index 952e020e3cd..53da88bdec8 100644 --- a/contrib/python/google-auth/py3/.dist-info/METADATA +++ b/contrib/python/google-auth/py3/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: google-auth -Version: 2.38.0 +Version: 2.39.0 Summary: Google Authentication Library Home-page: https://github.com/googleapis/google-auth-library-python Author: Google Cloud Platform @@ -14,6 +14,7 @@ Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License @@ -28,21 +29,49 @@ Requires-Dist: cachetools<6.0,>=2.0.0 Requires-Dist: pyasn1-modules>=0.2.1 Requires-Dist: rsa<5,>=3.1.4 Provides-Extra: aiohttp -Requires-Dist: aiohttp<4.0.0.dev0,>=3.6.2; extra == "aiohttp" -Requires-Dist: requests<3.0.0.dev0,>=2.20.0; extra == "aiohttp" +Requires-Dist: aiohttp<4.0.0,>=3.6.2; extra == "aiohttp" +Requires-Dist: requests<3.0.0,>=2.20.0; extra == "aiohttp" Provides-Extra: enterprise_cert Requires-Dist: cryptography; extra == "enterprise-cert" Requires-Dist: pyopenssl; extra == "enterprise-cert" Provides-Extra: pyjwt Requires-Dist: pyjwt>=2.0; extra == "pyjwt" Requires-Dist: cryptography>=38.0.3; extra == "pyjwt" +Requires-Dist: cryptography<39.0.0; python_version < "3.8" and extra == "pyjwt" Provides-Extra: pyopenssl Requires-Dist: pyopenssl>=20.0.0; extra == "pyopenssl" Requires-Dist: cryptography>=38.0.3; extra == "pyopenssl" +Requires-Dist: cryptography<39.0.0; python_version < "3.8" and extra == "pyopenssl" Provides-Extra: reauth Requires-Dist: pyu2f>=0.1.5; extra == "reauth" Provides-Extra: requests -Requires-Dist: requests<3.0.0.dev0,>=2.20.0; extra == "requests" +Requires-Dist: requests<3.0.0,>=2.20.0; extra == "requests" +Provides-Extra: testing +Requires-Dist: grpcio; extra == "testing" +Requires-Dist: flask; extra == "testing" +Requires-Dist: freezegun; extra == "testing" +Requires-Dist: mock; extra == "testing" +Requires-Dist: oauth2client; extra == "testing" +Requires-Dist: pyjwt>=2.0; extra == "testing" +Requires-Dist: cryptography>=38.0.3; extra == "testing" +Requires-Dist: pytest; extra == "testing" +Requires-Dist: pytest-cov; extra == "testing" +Requires-Dist: pytest-localserver; extra == "testing" +Requires-Dist: pyopenssl>=20.0.0; extra == "testing" +Requires-Dist: pyu2f>=0.1.5; extra == "testing" +Requires-Dist: responses; extra == "testing" +Requires-Dist: urllib3; extra == "testing" +Requires-Dist: packaging; extra == "testing" +Requires-Dist: aiohttp<4.0.0,>=3.6.2; extra == "testing" +Requires-Dist: requests<3.0.0,>=2.20.0; extra == "testing" +Requires-Dist: aioresponses; extra == "testing" +Requires-Dist: pytest-asyncio; extra == "testing" +Requires-Dist: pyopenssl<24.3.0; extra == "testing" +Requires-Dist: aiohttp<3.10.0; extra == "testing" +Requires-Dist: cryptography<39.0.0; python_version < "3.8" and extra == "testing" +Provides-Extra: urllib3 +Requires-Dist: urllib3; extra == "urllib3" +Requires-Dist: packaging; extra == "urllib3" Google Auth Python Library ========================== diff --git a/contrib/python/google-auth/py3/google/auth/_default.py b/contrib/python/google-auth/py3/google/auth/_default.py index 1234fb25d78..cf0cdd77298 100644 --- a/contrib/python/google-auth/py3/google/auth/_default.py +++ b/contrib/python/google-auth/py3/google/auth/_default.py @@ -484,42 +484,8 @@ def _get_impersonated_service_account_credentials(filename, info, scopes): from google.auth import impersonated_credentials try: - source_credentials_info = info.get("source_credentials") - source_credentials_type = source_credentials_info.get("type") - if source_credentials_type == _AUTHORIZED_USER_TYPE: - source_credentials, _ = _get_authorized_user_credentials( - filename, source_credentials_info - ) - elif source_credentials_type == _SERVICE_ACCOUNT_TYPE: - source_credentials, _ = _get_service_account_credentials( - filename, source_credentials_info - ) - elif source_credentials_type == _EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE: - source_credentials, _ = _get_external_account_authorized_user_credentials( - filename, source_credentials_info - ) - else: - raise exceptions.InvalidType( - "source credential of type {} is not supported.".format( - source_credentials_type - ) - ) - impersonation_url = info.get("service_account_impersonation_url") - start_index = impersonation_url.rfind("/") - end_index = impersonation_url.find(":generateAccessToken") - if start_index == -1 or end_index == -1 or start_index > end_index: - raise exceptions.InvalidValue( - "Cannot extract target principal from {}".format(impersonation_url) - ) - target_principal = impersonation_url[start_index + 1 : end_index] - delegates = info.get("delegates") - quota_project_id = info.get("quota_project_id") - credentials = impersonated_credentials.Credentials( - source_credentials, - target_principal, - scopes, - delegates, - quota_project_id=quota_project_id, + credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info( + info, scopes=scopes ) except ValueError as caught_exc: msg = "Failed to load impersonated service account credentials from {}".format( diff --git a/contrib/python/google-auth/py3/google/auth/compute_engine/_metadata.py b/contrib/python/google-auth/py3/google/auth/compute_engine/_metadata.py index 06f99de0e2c..ddbe8ac2f70 100644 --- a/contrib/python/google-auth/py3/google/auth/compute_engine/_metadata.py +++ b/contrib/python/google-auth/py3/google/auth/compute_engine/_metadata.py @@ -159,6 +159,7 @@ def get( retry_count=5, headers=None, return_none_for_not_found_error=False, + timeout=_METADATA_DEFAULT_TIMEOUT, ): """Fetch a resource from the metadata server. @@ -178,6 +179,7 @@ def get( headers (Optional[Mapping[str, str]]): Headers for the request. return_none_for_not_found_error (Optional[bool]): If True, returns None for 404 error instead of throwing an exception. + timeout (int): How long to wait, in seconds for the metadata server to respond. Returns: Union[Mapping, str]: If the metadata server returns JSON, a mapping of @@ -204,7 +206,9 @@ def get( failure_reason = None for attempt in backoff: try: - response = request(url=url, method="GET", headers=headers_to_use) + response = request( + url=url, method="GET", headers=headers_to_use, timeout=timeout + ) if response.status in transport.DEFAULT_RETRYABLE_STATUS_CODES: _LOGGER.warning( "Compute Engine Metadata server unavailable on " diff --git a/contrib/python/google-auth/py3/google/auth/identity_pool.py b/contrib/python/google-auth/py3/google/auth/identity_pool.py index 47f9a55715c..c06f8842870 100644 --- a/contrib/python/google-auth/py3/google/auth/identity_pool.py +++ b/contrib/python/google-auth/py3/google/auth/identity_pool.py @@ -41,6 +41,7 @@ try: except ImportError: # pragma: NO COVER from collections import Mapping # type: ignore import abc +import base64 import json import os from typing import NamedTuple @@ -145,9 +146,88 @@ class _UrlSupplier(SubjectTokenSupplier): class _X509Supplier(SubjectTokenSupplier): """Internal supplier for X509 workload credentials. This class is used internally and always returns an empty string as the subject token.""" + def __init__(self, trust_chain_path, leaf_cert_callback): + self._trust_chain_path = trust_chain_path + self._leaf_cert_callback = leaf_cert_callback + @_helpers.copy_docstring(SubjectTokenSupplier) def get_subject_token(self, context, request): - return "" + # Import OpennSSL inline because it is an extra import only required by customers + # using mTLS. + from OpenSSL import crypto + + leaf_cert = crypto.load_certificate( + crypto.FILETYPE_PEM, self._leaf_cert_callback() + ) + trust_chain = self._read_trust_chain() + cert_chain = [] + + cert_chain.append(_X509Supplier._encode_cert(leaf_cert)) + + if trust_chain is None or len(trust_chain) == 0: + return json.dumps(cert_chain) + + # Append the first cert if it is not the leaf cert. + first_cert = _X509Supplier._encode_cert(trust_chain[0]) + if first_cert != cert_chain[0]: + cert_chain.append(first_cert) + + for i in range(1, len(trust_chain)): + encoded = _X509Supplier._encode_cert(trust_chain[i]) + # Check if the current cert is the leaf cert and raise an exception if it is. + if encoded == cert_chain[0]: + raise exceptions.RefreshError( + "The leaf certificate must be at the top of the trust chain file" + ) + else: + cert_chain.append(encoded) + return json.dumps(cert_chain) + + def _read_trust_chain(self): + # Import OpennSSL inline because it is an extra import only required by customers + # using mTLS. + from OpenSSL import crypto + + certificate_trust_chain = [] + # If no trust chain path was provided, return an empty list. + if self._trust_chain_path is None or self._trust_chain_path == "": + return certificate_trust_chain + try: + # Open the trust chain file. + with open(self._trust_chain_path, "rb") as f: + trust_chain_data = f.read() + # Split PEM data into individual certificates. + cert_blocks = trust_chain_data.split(b"-----BEGIN CERTIFICATE-----") + for cert_block in cert_blocks: + # Skip empty blocks. + if cert_block.strip(): + cert_data = b"-----BEGIN CERTIFICATE-----" + cert_block + try: + # Load each certificate and add it to the trust chain. + cert = crypto.load_certificate( + crypto.FILETYPE_PEM, cert_data + ) + certificate_trust_chain.append(cert) + except Exception as e: + raise exceptions.RefreshError( + "Error loading PEM certificates from the trust chain file '{}'".format( + self._trust_chain_path + ) + ) from e + return certificate_trust_chain + except FileNotFoundError: + raise exceptions.RefreshError( + "Trust chain file '{}' was not found.".format(self._trust_chain_path) + ) + + def _encode_cert(cert): + # Import OpennSSL inline because it is an extra import only required by customers + # using mTLS. + from OpenSSL import crypto + + return base64.b64encode( + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) + ).decode("utf-8") def _parse_token_data(token_content, format_type="text", subject_token_field_name=None): @@ -296,7 +376,9 @@ class Credentials(external_account.Credentials): self._credential_source_headers, ) else: # self._credential_source_certificate - self._subject_token_supplier = _X509Supplier() + self._subject_token_supplier = _X509Supplier( + self._trust_chain_path, self._get_cert_bytes + ) @_helpers.copy_docstring(external_account.Credentials) def retrieve_subject_token(self, request): @@ -314,6 +396,10 @@ class Credentials(external_account.Credentials): self._certificate_config_location ) + def _get_cert_bytes(self): + cert_path, _ = self._get_mtls_cert_and_key_paths() + return _mtls_helper._read_cert_file(cert_path) + def _mtls_required(self): return self._credential_source_certificate is not None @@ -350,6 +436,9 @@ class Credentials(external_account.Credentials): use_default = self._credential_source_certificate.get( "use_default_certificate_config" ) + self._trust_chain_path = self._credential_source_certificate.get( + "trust_chain_path" + ) if self._certificate_config_location and use_default: raise exceptions.MalformedError( "Invalid certificate configuration, certificate_config_location cannot be specified when use_default_certificate_config = true." diff --git a/contrib/python/google-auth/py3/google/auth/impersonated_credentials.py b/contrib/python/google-auth/py3/google/auth/impersonated_credentials.py index ed7e3f00b1c..d49998cfbdc 100644 --- a/contrib/python/google-auth/py3/google/auth/impersonated_credentials.py +++ b/contrib/python/google-auth/py3/google/auth/impersonated_credentials.py @@ -47,6 +47,12 @@ _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds _GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" +_SOURCE_CREDENTIAL_AUTHORIZED_USER_TYPE = "authorized_user" +_SOURCE_CREDENTIAL_SERVICE_ACCOUNT_TYPE = "service_account" +_SOURCE_CREDENTIAL_EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE = ( + "external_account_authorized_user" +) + def _make_iam_token_request( request, @@ -410,6 +416,75 @@ class Credentials( cred._target_scopes = scopes or default_scopes return cred + @classmethod + def from_impersonated_service_account_info(cls, info, scopes=None): + """Creates a Credentials instance from parsed impersonated service account credentials info. + + Args: + info (Mapping[str, str]): The impersonated service account credentials info in Google + format. + scopes (Sequence[str]): Optional list of scopes to include in the + credentials. + + Returns: + google.oauth2.credentials.Credentials: The constructed + credentials. + + Raises: + InvalidType: If the info["source_credentials"] are not a supported impersonation type + InvalidValue: If the info["service_account_impersonation_url"] is not in the expected format. + ValueError: If the info is not in the expected format. + """ + + source_credentials_info = info.get("source_credentials") + source_credentials_type = source_credentials_info.get("type") + if source_credentials_type == _SOURCE_CREDENTIAL_AUTHORIZED_USER_TYPE: + from google.oauth2 import credentials + + source_credentials = credentials.Credentials.from_authorized_user_info( + source_credentials_info + ) + elif source_credentials_type == _SOURCE_CREDENTIAL_SERVICE_ACCOUNT_TYPE: + from google.oauth2 import service_account + + source_credentials = service_account.Credentials.from_service_account_info( + source_credentials_info + ) + elif ( + source_credentials_type + == _SOURCE_CREDENTIAL_EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE + ): + from google.auth import external_account_authorized_user + + source_credentials = external_account_authorized_user.Credentials.from_info( + source_credentials_info + ) + else: + raise exceptions.InvalidType( + "source credential of type {} is not supported.".format( + source_credentials_type + ) + ) + + impersonation_url = info.get("service_account_impersonation_url") + start_index = impersonation_url.rfind("/") + end_index = impersonation_url.find(":generateAccessToken") + if start_index == -1 or end_index == -1 or start_index > end_index: + raise exceptions.InvalidValue( + "Cannot extract target principal from {}".format(impersonation_url) + ) + target_principal = impersonation_url[start_index + 1 : end_index] + delegates = info.get("delegates") + quota_project_id = info.get("quota_project_id") + + return cls( + source_credentials, + target_principal, + scopes, + delegates, + quota_project_id=quota_project_id, + ) + class IDTokenCredentials(credentials.CredentialsWithQuotaProject): """Open ID Connect ID Token-based service account credentials. diff --git a/contrib/python/google-auth/py3/google/auth/transport/urllib3.py b/contrib/python/google-auth/py3/google/auth/transport/urllib3.py index 63144f5fffa..db4fa93ff11 100644 --- a/contrib/python/google-auth/py3/google/auth/transport/urllib3.py +++ b/contrib/python/google-auth/py3/google/auth/transport/urllib3.py @@ -34,13 +34,21 @@ except ImportError: # pragma: NO COVER try: import urllib3 # type: ignore import urllib3.exceptions # type: ignore + from packaging import version # type: ignore except ImportError as caught_exc: # pragma: NO COVER raise ImportError( - "The urllib3 library is not installed from please install the " - "urllib3 package to use the urllib3 transport." + "" + f"Error: {caught_exc}." + " The 'google-auth' library requires the extras installed " + "for urllib3 network transport." + "\n" + "Please install the necessary dependencies using pip:\n" + " pip install google-auth[urllib3]\n" + "\n" + "(Note: Using '[urllib3]' ensures the specific dependencies needed for this feature are installed. " + "We recommend running this command in your virtual environment.)" ) from caught_exc -from packaging import version # type: ignore from google.auth import environment_vars from google.auth import exceptions @@ -414,7 +422,7 @@ class AuthorizedHttp(RequestMethods): # type: ignore body=body, headers=headers, _credential_refresh_attempt=_credential_refresh_attempt + 1, - **kwargs + **kwargs, ) return response diff --git a/contrib/python/google-auth/py3/google/auth/version.py b/contrib/python/google-auth/py3/google/auth/version.py index 41a80e6c676..393caa8ad44 100644 --- a/contrib/python/google-auth/py3/google/auth/version.py +++ b/contrib/python/google-auth/py3/google/auth/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.38.0" +__version__ = "2.39.0" diff --git a/contrib/python/google-auth/py3/google/oauth2/id_token.py b/contrib/python/google-auth/py3/google/oauth2/id_token.py index b68ab6b303a..a6c51ce6381 100644 --- a/contrib/python/google-auth/py3/google/oauth2/id_token.py +++ b/contrib/python/google-auth/py3/google/oauth2/id_token.py @@ -284,6 +284,18 @@ def fetch_id_token_credentials(audience, request=None): return service_account.IDTokenCredentials.from_service_account_info( info, target_audience=audience ) + elif info.get("type") == "impersonated_service_account": + from google.auth import impersonated_credentials + + target_credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info( + info + ) + + return impersonated_credentials.IDTokenCredentials( + target_credentials=target_credentials, + target_audience=audience, + include_email=True, + ) except ValueError as caught_exc: new_exc = exceptions.DefaultCredentialsError( "GOOGLE_APPLICATION_CREDENTIALS is not valid service account credentials.", diff --git a/contrib/python/google-auth/py3/patches/01-fix-tests.patch b/contrib/python/google-auth/py3/patches/01-fix-tests.patch index 1065289c64b..0494eab3354 100644 --- a/contrib/python/google-auth/py3/patches/01-fix-tests.patch +++ b/contrib/python/google-auth/py3/patches/01-fix-tests.patch @@ -47,7 +47,7 @@ +DATA_DIR = os.path.join(os.path.dirname(yc.source_path(__file__)), "..", "data") --- contrib/python/google-auth/py3/tests/oauth2/test_id_token.py (index) +++ contrib/python/google-auth/py3/tests/oauth2/test_id_token.py (working tree) -@@ -24,8 +24,9 @@ import google.auth.compute_engine._metadata +@@ -24,12 +24,13 @@ import google.auth.compute_engine._metadata from google.oauth2 import id_token from google.oauth2 import service_account @@ -56,7 +56,12 @@ - os.path.dirname(__file__), "../data/service_account.json" + os.path.dirname(yc.source_path(__file__)), "../data/service_account.json" ) - ID_TOKEN_AUDIENCE = "https://pubsub.googleapis.com" + + IMPERSONATED_SERVICE_ACCOUNT_FILE = os.path.join( +- os.path.dirname(__file__), ++ os.path.dirname(yc.source_path(__file__)), + "../data/impersonated_service_account_authorized_user_source.json", + ) @@ -265,1 +266,1 @@ def test_fetch_id_token_no_cred_exists(monkeypatch): - os.path.dirname(__file__), "../data/authorized_user.json" diff --git a/contrib/python/google-auth/py3/tests/compute_engine/test__metadata.py b/contrib/python/google-auth/py3/tests/compute_engine/test__metadata.py index a768b17fa0d..98d08fe4505 100644 --- a/contrib/python/google-auth/py3/tests/compute_engine/test__metadata.py +++ b/contrib/python/google-auth/py3/tests/compute_engine/test__metadata.py @@ -176,6 +176,7 @@ def test_get_success_json(): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert result[key] == value @@ -194,6 +195,7 @@ def test_get_success_json_content_type_charset(): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert result[key] == value @@ -213,6 +215,7 @@ def test_get_success_retry(mock_sleep): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert request.call_count == 2 assert result[key] == value @@ -228,6 +231,7 @@ def test_get_success_text(): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert result == data @@ -243,6 +247,7 @@ def test_get_success_params(): method="GET", url=_metadata._METADATA_ROOT + PATH + "?recursive=true", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert result == data @@ -257,6 +262,7 @@ def test_get_success_recursive_and_params(): method="GET", url=_metadata._METADATA_ROOT + PATH + "?recursive=true", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert result == data @@ -271,6 +277,7 @@ def test_get_success_recursive(): method="GET", url=_metadata._METADATA_ROOT + PATH + "?recursive=true", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert result == data @@ -292,6 +299,7 @@ def _test_get_success_custom_root_new_variable(): method="GET", url="http://{}/computeMetadata/v1/{}".format(fake_root, PATH), headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -312,6 +320,7 @@ def _test_get_success_custom_root_old_variable(): method="GET", url="http://{}/computeMetadata/v1/{}".format(fake_root, PATH), headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -328,6 +337,7 @@ def test_get_failure(mock_sleep): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -340,6 +350,7 @@ def test_get_return_none_for_not_found_error(): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -359,6 +370,7 @@ def test_get_failure_connection_failed(mock_sleep): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert request.call_count == 5 @@ -377,6 +389,7 @@ def test_get_too_many_requests_retryable_error_failure(): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert request.call_count == 5 @@ -393,6 +406,7 @@ def test_get_failure_bad_json(): method="GET", url=_metadata._METADATA_ROOT + PATH, headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -406,6 +420,7 @@ def test_get_project_id(): method="GET", url=_metadata._METADATA_ROOT + "project/project-id", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert project_id == project @@ -421,6 +436,7 @@ def test_get_universe_domain_success(): method="GET", url=_metadata._METADATA_ROOT + "universe/universe-domain", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert universe_domain == "fake_universe_domain" @@ -434,6 +450,7 @@ def test_get_universe_domain_success_empty_response(): method="GET", url=_metadata._METADATA_ROOT + "universe/universe-domain", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert universe_domain == "googleapis.com" @@ -449,6 +466,7 @@ def test_get_universe_domain_not_found(): method="GET", url=_metadata._METADATA_ROOT + "universe/universe-domain", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert universe_domain == "googleapis.com" @@ -469,6 +487,7 @@ def test_get_universe_domain_retryable_error_failure(): method="GET", url=_metadata._METADATA_ROOT + "universe/universe-domain", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert request.call_count == 5 @@ -511,11 +530,13 @@ def test_get_universe_domain_retryable_error_success(): method="GET", url=_metadata._METADATA_ROOT + "universe/universe-domain", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) request_ok.assert_called_once_with( method="GET", url=_metadata._METADATA_ROOT + "universe/universe-domain", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert universe_domain == "fake_universe_domain" @@ -535,6 +556,7 @@ def test_get_universe_domain_other_error(): method="GET", url=_metadata._METADATA_ROOT + "universe/universe-domain", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -559,6 +581,7 @@ def test_get_service_account_token(utcnow, mock_metrics_header_value): "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, }, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert token == "token" assert expiry == utcnow() + datetime.timedelta(seconds=ttl) @@ -585,6 +608,7 @@ def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, }, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert token == "token" assert expiry == utcnow() + datetime.timedelta(seconds=ttl) @@ -613,6 +637,7 @@ def test_get_service_account_token_with_scopes_string( "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, }, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert token == "token" assert expiry == utcnow() + datetime.timedelta(seconds=ttl) @@ -630,6 +655,7 @@ def test_get_service_account_info(): method="GET", url=_metadata._METADATA_ROOT + PATH + "/?recursive=true", headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert info[key] == value diff --git a/contrib/python/google-auth/py3/tests/data/trust_chain_with_leaf.pem b/contrib/python/google-auth/py3/tests/data/trust_chain_with_leaf.pem new file mode 100644 index 00000000000..250387d9d59 --- /dev/null +++ b/contrib/python/google-auth/py3/tests/data/trust_chain_with_leaf.pem @@ -0,0 +1,52 @@ +-----BEGIN CERTIFICATE----- +MIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV +BAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV +MRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM +7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer +uQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp +gyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4 ++WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3 +ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O +gN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh +GaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD +AQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr +odJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk ++JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9 +ovNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql +ybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT +cDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIFtTCCA52gAwIBAgIJAPBsLZmNGfKtMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTYwOTIxMDI0NTEyWhcNMTYxMDIxMDI0NTEyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIIC +CgKCAgEAsiMC7mTsmUXwZoYlT4aHY1FLw8bxIXC+z3IqA+TY1WqfbeiZRo8MA5Zx +lTTxYMKPCZUE1XBc7jvD8GJhWIj6pToPYHn73B01IBkLBxq4kF1yV2Z7DVmkvc6H +EcxXXq8zkCx0j6XOfiI4+qkXnuQn8cvrk8xfhtnMMZM7iVm6VSN93iRP/8ey6xuL +XTHrDX7ukoRce1hpT8O+15GXNrY0irhhYQz5xKibNCJF3EjV28WMry8y7I8uYUFU +RWDiQawwK9ec1zhZ94v92+GZDlPevmcFmSERKYQ0NsKcT0Y3lGuGnaExs8GyOpnC +oksu4YJGXQjg7lkv4MxzsNbRqmCkUwxw1Mg6FP0tsCNsw9qTrkvWCRA9zp/aU+sZ +IBGh1t4UGCub8joeQFvHxvr/3F7mH/dyvCjA34u0Lo1VPx+jYUIi9i0odltMspDW +xOpjqdGARZYmlJP5Au9q5cQjPMcwS/EBIb8cwNl32mUE6WnFlep+38mNR/FghIjO +ViAkXuKQmcHe6xppZAoHFsO/t3l4Tjek5vNW7erI1rgrFku/fvkIW/G8V1yIm/+Q +F+CE4maQzCJfhftpkhM/sPC/FuLNBmNE8BHVX8y58xG4is/cQxL4Z9TsFIw0C5+3 +uTrFW9D0agysahMVzPGtCqhDQqJdIJrBQqlS6bztpzBA8zEI0skCAwEAAaOBpzCB +pDAdBgNVHQ4EFgQUz/8FmW6TfqXyNJZr7rhc+Tn5sKQwdQYDVR0jBG4wbIAUz/8F +mW6TfqXyNJZr7rhc+Tn5sKShSaRHMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpT +b21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGSCCQDw +bC2ZjRnyrTAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4ICAQCQmrcfhurX +riR3Q0Y+nq040/3dJIAJXjyI9CEtxaU0nzCNTng7PwgZ0CKmCelQfInuwWFwBSHS +6kBfC1rgJeFnjnTt8a3RCgRlIgUr9NCdPSEccB7TurobwPJ2h6cJjjR8urcb0CXh +CEMvPneyPj0xUFY8vVKXMGWahz/kyfwIiVqcX/OtMZ29fUu1onbWl71g2gVLtUZl +sECdZ+AC/6HDCVpYIVETMl1T7N/XyqXZQiDLDNRDeZhnapz8w9fsW1KVujAZLNQR +pVnw2qa2UK1dSf2FHX+lQU5mFSYM4vtwaMlX/LgfdLZ9I796hFh619WwTVz+LO2N +vHnwBMabld3XSPuZRqlbBulDQ07Vbqdjv8DYSLA2aKI4ZkMMKuFLG/oS28V2ZYmv +/KpGEs5UgKY+P9NulYpTDwCU/6SomuQpP795wbG6sm7Hzq82r2RmB61GupNRGeqi +pXKsy69T388zBxYu6zQrosXiDl5YzaViH7tm0J7opye8dCWjjpnahki0vq2znti7 +6cWla2j8Xz1glvLz+JI/NCOMfxUInb82T7ijo80N0VJ2hzf7p2GxRZXAxAV9knLI +nM4F5TLjSd7ZhOOZ7ni/eZFueTMisWfypt2nc41whGjHMX/Zp1kPfhB4H2bLKIX/ +lSrwNr3qbGTEJX8JqpDBNVAd96XkMvDNyA== +-----END CERTIFICATE-----
\ No newline at end of file diff --git a/contrib/python/google-auth/py3/tests/data/trust_chain_without_leaf.pem b/contrib/python/google-auth/py3/tests/data/trust_chain_without_leaf.pem new file mode 100644 index 00000000000..9da0f37fedf --- /dev/null +++ b/contrib/python/google-auth/py3/tests/data/trust_chain_without_leaf.pem @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFtTCCA52gAwIBAgIJAPBsLZmNGfKtMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTYwOTIxMDI0NTEyWhcNMTYxMDIxMDI0NTEyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIIC +CgKCAgEAsiMC7mTsmUXwZoYlT4aHY1FLw8bxIXC+z3IqA+TY1WqfbeiZRo8MA5Zx +lTTxYMKPCZUE1XBc7jvD8GJhWIj6pToPYHn73B01IBkLBxq4kF1yV2Z7DVmkvc6H +EcxXXq8zkCx0j6XOfiI4+qkXnuQn8cvrk8xfhtnMMZM7iVm6VSN93iRP/8ey6xuL +XTHrDX7ukoRce1hpT8O+15GXNrY0irhhYQz5xKibNCJF3EjV28WMry8y7I8uYUFU +RWDiQawwK9ec1zhZ94v92+GZDlPevmcFmSERKYQ0NsKcT0Y3lGuGnaExs8GyOpnC +oksu4YJGXQjg7lkv4MxzsNbRqmCkUwxw1Mg6FP0tsCNsw9qTrkvWCRA9zp/aU+sZ +IBGh1t4UGCub8joeQFvHxvr/3F7mH/dyvCjA34u0Lo1VPx+jYUIi9i0odltMspDW +xOpjqdGARZYmlJP5Au9q5cQjPMcwS/EBIb8cwNl32mUE6WnFlep+38mNR/FghIjO +ViAkXuKQmcHe6xppZAoHFsO/t3l4Tjek5vNW7erI1rgrFku/fvkIW/G8V1yIm/+Q +F+CE4maQzCJfhftpkhM/sPC/FuLNBmNE8BHVX8y58xG4is/cQxL4Z9TsFIw0C5+3 +uTrFW9D0agysahMVzPGtCqhDQqJdIJrBQqlS6bztpzBA8zEI0skCAwEAAaOBpzCB +pDAdBgNVHQ4EFgQUz/8FmW6TfqXyNJZr7rhc+Tn5sKQwdQYDVR0jBG4wbIAUz/8F +mW6TfqXyNJZr7rhc+Tn5sKShSaRHMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpT +b21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGSCCQDw +bC2ZjRnyrTAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4ICAQCQmrcfhurX +riR3Q0Y+nq040/3dJIAJXjyI9CEtxaU0nzCNTng7PwgZ0CKmCelQfInuwWFwBSHS +6kBfC1rgJeFnjnTt8a3RCgRlIgUr9NCdPSEccB7TurobwPJ2h6cJjjR8urcb0CXh +CEMvPneyPj0xUFY8vVKXMGWahz/kyfwIiVqcX/OtMZ29fUu1onbWl71g2gVLtUZl +sECdZ+AC/6HDCVpYIVETMl1T7N/XyqXZQiDLDNRDeZhnapz8w9fsW1KVujAZLNQR +pVnw2qa2UK1dSf2FHX+lQU5mFSYM4vtwaMlX/LgfdLZ9I796hFh619WwTVz+LO2N +vHnwBMabld3XSPuZRqlbBulDQ07Vbqdjv8DYSLA2aKI4ZkMMKuFLG/oS28V2ZYmv +/KpGEs5UgKY+P9NulYpTDwCU/6SomuQpP795wbG6sm7Hzq82r2RmB61GupNRGeqi +pXKsy69T388zBxYu6zQrosXiDl5YzaViH7tm0J7opye8dCWjjpnahki0vq2znti7 +6cWla2j8Xz1glvLz+JI/NCOMfxUInb82T7ijo80N0VJ2hzf7p2GxRZXAxAV9knLI +nM4F5TLjSd7ZhOOZ7ni/eZFueTMisWfypt2nc41whGjHMX/Zp1kPfhB4H2bLKIX/ +lSrwNr3qbGTEJX8JqpDBNVAd96XkMvDNyA== +-----END CERTIFICATE-----
\ No newline at end of file diff --git a/contrib/python/google-auth/py3/tests/data/trust_chain_wrong_order.pem b/contrib/python/google-auth/py3/tests/data/trust_chain_wrong_order.pem new file mode 100644 index 00000000000..e8dc5d35931 --- /dev/null +++ b/contrib/python/google-auth/py3/tests/data/trust_chain_wrong_order.pem @@ -0,0 +1,52 @@ +-----BEGIN CERTIFICATE----- +MIIFtTCCA52gAwIBAgIJAPBsLZmNGfKtMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTYwOTIxMDI0NTEyWhcNMTYxMDIxMDI0NTEyWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIIC +CgKCAgEAsiMC7mTsmUXwZoYlT4aHY1FLw8bxIXC+z3IqA+TY1WqfbeiZRo8MA5Zx +lTTxYMKPCZUE1XBc7jvD8GJhWIj6pToPYHn73B01IBkLBxq4kF1yV2Z7DVmkvc6H +EcxXXq8zkCx0j6XOfiI4+qkXnuQn8cvrk8xfhtnMMZM7iVm6VSN93iRP/8ey6xuL +XTHrDX7ukoRce1hpT8O+15GXNrY0irhhYQz5xKibNCJF3EjV28WMry8y7I8uYUFU +RWDiQawwK9ec1zhZ94v92+GZDlPevmcFmSERKYQ0NsKcT0Y3lGuGnaExs8GyOpnC +oksu4YJGXQjg7lkv4MxzsNbRqmCkUwxw1Mg6FP0tsCNsw9qTrkvWCRA9zp/aU+sZ +IBGh1t4UGCub8joeQFvHxvr/3F7mH/dyvCjA34u0Lo1VPx+jYUIi9i0odltMspDW +xOpjqdGARZYmlJP5Au9q5cQjPMcwS/EBIb8cwNl32mUE6WnFlep+38mNR/FghIjO +ViAkXuKQmcHe6xppZAoHFsO/t3l4Tjek5vNW7erI1rgrFku/fvkIW/G8V1yIm/+Q +F+CE4maQzCJfhftpkhM/sPC/FuLNBmNE8BHVX8y58xG4is/cQxL4Z9TsFIw0C5+3 +uTrFW9D0agysahMVzPGtCqhDQqJdIJrBQqlS6bztpzBA8zEI0skCAwEAAaOBpzCB +pDAdBgNVHQ4EFgQUz/8FmW6TfqXyNJZr7rhc+Tn5sKQwdQYDVR0jBG4wbIAUz/8F +mW6TfqXyNJZr7rhc+Tn5sKShSaRHMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpT +b21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGSCCQDw +bC2ZjRnyrTAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBBQUAA4ICAQCQmrcfhurX +riR3Q0Y+nq040/3dJIAJXjyI9CEtxaU0nzCNTng7PwgZ0CKmCelQfInuwWFwBSHS +6kBfC1rgJeFnjnTt8a3RCgRlIgUr9NCdPSEccB7TurobwPJ2h6cJjjR8urcb0CXh +CEMvPneyPj0xUFY8vVKXMGWahz/kyfwIiVqcX/OtMZ29fUu1onbWl71g2gVLtUZl +sECdZ+AC/6HDCVpYIVETMl1T7N/XyqXZQiDLDNRDeZhnapz8w9fsW1KVujAZLNQR +pVnw2qa2UK1dSf2FHX+lQU5mFSYM4vtwaMlX/LgfdLZ9I796hFh619WwTVz+LO2N +vHnwBMabld3XSPuZRqlbBulDQ07Vbqdjv8DYSLA2aKI4ZkMMKuFLG/oS28V2ZYmv +/KpGEs5UgKY+P9NulYpTDwCU/6SomuQpP795wbG6sm7Hzq82r2RmB61GupNRGeqi +pXKsy69T388zBxYu6zQrosXiDl5YzaViH7tm0J7opye8dCWjjpnahki0vq2znti7 +6cWla2j8Xz1glvLz+JI/NCOMfxUInb82T7ijo80N0VJ2hzf7p2GxRZXAxAV9knLI +nM4F5TLjSd7ZhOOZ7ni/eZFueTMisWfypt2nc41whGjHMX/Zp1kPfhB4H2bLKIX/ +lSrwNr3qbGTEJX8JqpDBNVAd96XkMvDNyA== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDIzCCAgugAwIBAgIJAMfISuBQ5m+5MA0GCSqGSIb3DQEBBQUAMBUxEzARBgNV +BAMTCnVuaXQtdGVzdHMwHhcNMTExMjA2MTYyNjAyWhcNMjExMjAzMTYyNjAyWjAV +MRMwEQYDVQQDEwp1bml0LXRlc3RzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZgkdmM +7oVK2OfgrSj/FCTkInKPqaCR0gD7K80q+mLBrN3PUkDrJQZpvRZIff3/xmVU1Wer +uQLFJjnFb2dqu0s/FY/2kWiJtBCakXvXEOb7zfbINuayL+MSsCGSdVYsSliS5qQp +gyDap+8b5fpXZVJkq92hrcNtbkg7hCYUJczt8n9hcCTJCfUpApvaFQ18pe+zpyl4 ++WzkP66I28hniMQyUlA1hBiskT7qiouq0m8IOodhv2fagSZKjOTTU2xkSBc//fy3 +ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQABo3YwdDAdBgNVHQ4EFgQU2RQ8yO+O +gN8oVW2SW7RLrfYd9jEwRQYDVR0jBD4wPIAU2RQ8yO+OgN8oVW2SW7RLrfYd9jGh +GaQXMBUxEzARBgNVBAMTCnVuaXQtdGVzdHOCCQDHyErgUOZvuTAMBgNVHRMEBTAD +AQH/MA0GCSqGSIb3DQEBBQUAA4IBAQBRv+M/6+FiVu7KXNjFI5pSN17OcW5QUtPr +odJMlWrJBtynn/TA1oJlYu3yV5clc/71Vr/AxuX5xGP+IXL32YDF9lTUJXG/uUGk ++JETpKmQviPbRsvzYhz4pf6ZIOZMc3/GIcNq92ECbseGO+yAgyWUVKMmZM0HqXC9 +ovNslqe0M8C1sLm1zAR5z/h/litE7/8O2ietija3Q/qtl2TOXJdCA6sgjJX2WUql +ybrC55ct18NKf3qhpcEkGQvFU40rVYApJpi98DiZPYFdx1oBDp/f4uZ3ojpxRVFT +cDwcJLfNRCPUhormsY7fDS9xSyThiHsW9mjJYdcaKQkwYZ0F11yB +-----END CERTIFICATE-----
\ No newline at end of file diff --git a/contrib/python/google-auth/py3/tests/oauth2/test_id_token.py b/contrib/python/google-auth/py3/tests/oauth2/test_id_token.py index 65189df128c..5dc125fb566 100644 --- a/contrib/python/google-auth/py3/tests/oauth2/test_id_token.py +++ b/contrib/python/google-auth/py3/tests/oauth2/test_id_token.py @@ -20,6 +20,7 @@ import pytest # type: ignore from google.auth import environment_vars from google.auth import exceptions +from google.auth import impersonated_credentials from google.auth import transport from google.oauth2 import id_token from google.oauth2 import service_account @@ -28,6 +29,12 @@ import yatest.common as yc SERVICE_ACCOUNT_FILE = os.path.join( os.path.dirname(yc.source_path(__file__)), "../data/service_account.json" ) + +IMPERSONATED_SERVICE_ACCOUNT_FILE = os.path.join( + os.path.dirname(yc.source_path(__file__)), + "../data/impersonated_service_account_authorized_user_source.json", +) + ID_TOKEN_AUDIENCE = "https://pubsub.googleapis.com" @@ -263,6 +270,14 @@ def test_fetch_id_token_credentials_from_explicit_cred_json_file(monkeypatch): assert cred._target_audience == ID_TOKEN_AUDIENCE +def test_fetch_id_token_credentials_from_impersonated_cred_json_file(monkeypatch): + monkeypatch.setenv(environment_vars.CREDENTIALS, IMPERSONATED_SERVICE_ACCOUNT_FILE) + + cred = id_token.fetch_id_token_credentials(ID_TOKEN_AUDIENCE) + assert isinstance(cred, impersonated_credentials.IDTokenCredentials) + assert cred._target_audience == ID_TOKEN_AUDIENCE + + def test_fetch_id_token_credentials_no_cred_exists(monkeypatch): monkeypatch.delenv(environment_vars.CREDENTIALS, raising=False) diff --git a/contrib/python/google-auth/py3/tests/test__oauth2client.py b/contrib/python/google-auth/py3/tests/test__oauth2client.py index 1db595fd9ac..61eaf17c2de 100644 --- a/contrib/python/google-auth/py3/tests/test__oauth2client.py +++ b/contrib/python/google-auth/py3/tests/test__oauth2client.py @@ -117,6 +117,14 @@ def _test__convert_appengine_app_assertion_credentials( app_identity, mock_oauth2client_gae_imports ): + # `oauth2client` requires `cgi` which was removed in Python 3.13 + # See https://github.com/googleapis/oauth2client/blob/50d20532a748f18e53f7d24ccbe6647132c979a9/oauth2client/contrib/appengine.py#L20 + # oauth2client is no longer being updated so this test must be skipped on newer Python Runtimes + if sys.version_info >= (3, 13): # pragma: NO COVER + pytest.skip( + "Skipping test for Python 3.13+ due to oauth2client incompatibility." + ) + import oauth2client.contrib.appengine # type: ignore service_account_id = "service_account_id" @@ -166,6 +174,14 @@ def reset__oauth2client_module(): def _test_import_has_app_engine( mock_oauth2client_gae_imports, reset__oauth2client_module ): + # `oauth2client` requires `cgi` which was removed in Python 3.13 + # See https://github.com/googleapis/oauth2client/blob/50d20532a748f18e53f7d24ccbe6647132c979a9/oauth2client/contrib/appengine.py#L20 + # oauth2client is no longer being updated so this test must be skipped on newer Python Runtimes + if sys.version_info >= (3, 13): # pragma: NO COVER + pytest.skip( + "Skipping test for Python 3.13+ due to oauth2client incompatibility." + ) + importlib.reload(_oauth2client) assert _oauth2client._HAS_APPENGINE diff --git a/contrib/python/google-auth/py3/tests/test_identity_pool.py b/contrib/python/google-auth/py3/tests/test_identity_pool.py index cc6cbf08827..4d78a5c22ea 100644 --- a/contrib/python/google-auth/py3/tests/test_identity_pool.py +++ b/contrib/python/google-auth/py3/tests/test_identity_pool.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import datetime import http.client as http_client import json @@ -19,6 +20,7 @@ import os import urllib import mock +from OpenSSL import crypto import pytest # type: ignore from google.auth import _helpers, external_account @@ -49,6 +51,13 @@ import yatest.common as yc DATA_DIR = os.path.join(os.path.dirname(yc.source_path(__file__)), "data") SUBJECT_TOKEN_TEXT_FILE = os.path.join(DATA_DIR, "external_subject_token.txt") SUBJECT_TOKEN_JSON_FILE = os.path.join(DATA_DIR, "external_subject_token.json") +TRUST_CHAIN_WITH_LEAF_FILE = os.path.join(DATA_DIR, "trust_chain_with_leaf.pem") +TRUST_CHAIN_WITHOUT_LEAF_FILE = os.path.join(DATA_DIR, "trust_chain_without_leaf.pem") +TRUST_CHAIN_WRONG_ORDER_FILE = os.path.join(DATA_DIR, "trust_chain_wrong_order.pem") +CERT_FILE = os.path.join(DATA_DIR, "public_cert.pem") +KEY_FILE = os.path.join(DATA_DIR, "privatekey.pem") +OTHER_CERT_FILE = os.path.join(DATA_DIR, "other_cert.pem") + SUBJECT_TOKEN_FIELD_NAME = "access_token" with open(SUBJECT_TOKEN_TEXT_FILE) as fh: @@ -58,6 +67,20 @@ with open(SUBJECT_TOKEN_JSON_FILE) as fh: JSON_FILE_CONTENT = json.load(fh) JSON_FILE_SUBJECT_TOKEN = JSON_FILE_CONTENT.get(SUBJECT_TOKEN_FIELD_NAME) +with open(CERT_FILE, "rb") as f: + CERT_FILE_CONTENT = base64.b64encode( + crypto.dump_certificate( + crypto.FILETYPE_ASN1, crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) + ) + ).decode("utf-8") + +with open(OTHER_CERT_FILE, "rb") as f: + OTHER_CERT_FILE_CONTENT = base64.b64encode( + crypto.dump_certificate( + crypto.FILETYPE_ASN1, crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) + ) + ).decode("utf-8") + TOKEN_URL = "https://sts.googleapis.com/v1/token" TOKEN_INFO_URL = "https://sts.googleapis.com/v1/introspect" SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" @@ -186,6 +209,24 @@ class TestCredentials(object): CREDENTIAL_SOURCE_CERTIFICATE_NOT_DEFAULT = { "certificate": {"certificate_config_location": "path/to/config"} } + CREDENTIAL_SOURCE_CERTIFICATE_TRUST_CHAIN_WITH_LEAF = { + "certificate": { + "use_default_certificate_config": "true", + "trust_chain_path": TRUST_CHAIN_WITH_LEAF_FILE, + } + } + CREDENTIAL_SOURCE_CERTIFICATE_TRUST_CHAIN_WITHOUT_LEAF = { + "certificate": { + "use_default_certificate_config": "true", + "trust_chain_path": TRUST_CHAIN_WITHOUT_LEAF_FILE, + } + } + CREDENTIAL_SOURCE_CERTIFICATE_TRUST_CHAIN_WRONG_ORDER = { + "certificate": { + "use_default_certificate_config": "true", + "trust_chain_path": TRUST_CHAIN_WRONG_ORDER_FILE, + } + } SUCCESS_RESPONSE = { "access_token": "ACCESS_TOKEN", "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", @@ -937,14 +978,126 @@ class TestCredentials(object): assert subject_token == JSON_FILE_SUBJECT_TOKEN - def test_retrieve_subject_token_certificate(self): + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key_paths", + return_value=(CERT_FILE, KEY_FILE), + ) + def test_retrieve_subject_token_certificate_default( + self, mock_get_workload_cert_and_key_paths + ): credentials = self.make_credentials( credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE ) subject_token = credentials.retrieve_subject_token(None) - assert subject_token == "" + assert subject_token == json.dumps([CERT_FILE_CONTENT]) + + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key_paths", + return_value=(CERT_FILE, KEY_FILE), + ) + def test_retrieve_subject_token_certificate_non_default_path( + self, mock_get_workload_cert_and_key_paths + ): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE_NOT_DEFAULT + ) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == json.dumps([CERT_FILE_CONTENT]) + + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key_paths", + return_value=(CERT_FILE, KEY_FILE), + ) + def test_retrieve_subject_token_certificate_trust_chain_with_leaf( + self, mock_get_workload_cert_and_key_paths + ): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE_TRUST_CHAIN_WITH_LEAF + ) + + subject_token = credentials.retrieve_subject_token(None) + assert subject_token == json.dumps([CERT_FILE_CONTENT, OTHER_CERT_FILE_CONTENT]) + + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key_paths", + return_value=(CERT_FILE, KEY_FILE), + ) + def test_retrieve_subject_token_certificate_trust_chain_without_leaf( + self, mock_get_workload_cert_and_key_paths + ): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE_TRUST_CHAIN_WITHOUT_LEAF + ) + + subject_token = credentials.retrieve_subject_token(None) + assert subject_token == json.dumps([CERT_FILE_CONTENT, OTHER_CERT_FILE_CONTENT]) + + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key_paths", + return_value=(CERT_FILE, KEY_FILE), + ) + def test_retrieve_subject_token_certificate_trust_chain_invalid_order( + self, mock_get_workload_cert_and_key_paths + ): + + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_CERTIFICATE_TRUST_CHAIN_WRONG_ORDER + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) + + assert excinfo.match( + "The leaf certificate must be at the top of the trust chain file" + ) + + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key_paths", + return_value=(CERT_FILE, KEY_FILE), + ) + def test_retrieve_subject_token_certificate_trust_chain_file_does_not_exist( + self, mock_get_workload_cert_and_key_paths + ): + + credentials = self.make_credentials( + credential_source={ + "certificate": { + "use_default_certificate_config": "true", + "trust_chain_path": "fake.pem", + } + } + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) + + assert excinfo.match("Trust chain file 'fake.pem' was not found.") + + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key_paths", + return_value=(CERT_FILE, KEY_FILE), + ) + def test_retrieve_subject_token_certificate_invalid_trust_chain_file( + self, mock_get_workload_cert_and_key_paths + ): + + credentials = self.make_credentials( + credential_source={ + "certificate": { + "use_default_certificate_config": "true", + "trust_chain_path": SUBJECT_TOKEN_TEXT_FILE, + } + } + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) + + assert excinfo.match("Error loading PEM certificates from the trust chain file") def test_retrieve_subject_token_json_file_invalid_field_name(self): credential_source = { diff --git a/contrib/python/google-auth/py3/tests/test_impersonated_credentials.py b/contrib/python/google-auth/py3/tests/test_impersonated_credentials.py index 0321a1a1d7b..9aeb505fdd9 100644 --- a/contrib/python/google-auth/py3/tests/test_impersonated_credentials.py +++ b/contrib/python/google-auth/py3/tests/test_impersonated_credentials.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import datetime import http.client as http_client import json @@ -36,6 +37,9 @@ with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: PRIVATE_KEY_BYTES = fh.read() SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") +IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_FILE = os.path.join( + DATA_DIR, "impersonated_service_account_authorized_user_source.json" +) ID_TOKEN_DATA = ( "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRmMzc1ODkwOGI3OTIyOTNhZDk3N2Ew" @@ -50,6 +54,9 @@ ID_TOKEN_EXPIRY = 1564475051 with open(SERVICE_ACCOUNT_JSON_FILE, "rb") as fh: SERVICE_ACCOUNT_INFO = json.load(fh) +with open(IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_FILE, "rb") as fh: + IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_INFO = json.load(fh) + SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1") TOKEN_URI = "https://example.com/oauth2/token" @@ -149,6 +156,38 @@ class TestImpersonatedCredentials(object): iam_endpoint_override=iam_endpoint_override, ) + def test_from_impersonated_service_account_info(self): + credentials = impersonated_credentials.Credentials.from_impersonated_service_account_info( + IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_INFO + ) + assert isinstance(credentials, impersonated_credentials.Credentials) + + def test_from_impersonated_service_account_info_with_invalid_source_credentials_type( + self + ): + info = copy.deepcopy(IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_INFO) + assert "source_credentials" in info + # Set the source_credentials to an invalid type + info["source_credentials"]["type"] = "invalid_type" + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + impersonated_credentials.Credentials.from_impersonated_service_account_info( + info + ) + assert excinfo.match( + "source credential of type {} is not supported".format("invalid_type") + ) + + def test_from_impersonated_service_account_info_with_invalid_impersonation_url( + self + ): + info = copy.deepcopy(IMPERSONATED_SERVICE_ACCOUNT_AUTHORIZED_USER_SOURCE_INFO) + info["service_account_impersonation_url"] = "invalid_url" + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + impersonated_credentials.Credentials.from_impersonated_service_account_info( + info + ) + assert excinfo.match(r"Cannot extract target principal from") + def test_get_cred_info(self): credentials = self.make_credentials() assert not credentials.get_cred_info() diff --git a/contrib/python/google-auth/py3/ya.make b/contrib/python/google-auth/py3/ya.make index bb762919dc6..eddecb079e5 100644 --- a/contrib/python/google-auth/py3/ya.make +++ b/contrib/python/google-auth/py3/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(2.38.0) +VERSION(2.39.0) LICENSE(Apache-2.0) diff --git a/contrib/python/prompt-toolkit/py3/.dist-info/METADATA b/contrib/python/prompt-toolkit/py3/.dist-info/METADATA index 8d4f5d343d2..265fb8183bd 100644 --- a/contrib/python/prompt-toolkit/py3/.dist-info/METADATA +++ b/contrib/python/prompt-toolkit/py3/.dist-info/METADATA @@ -1,8 +1,7 @@ -Metadata-Version: 2.2 +Metadata-Version: 2.4 Name: prompt_toolkit -Version: 3.0.50 +Version: 3.0.51 Summary: Library for building powerful interactive command lines in Python -Home-page: https://github.com/prompt-toolkit/python-prompt-toolkit Author: Jonathan Slenders Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers @@ -18,19 +17,12 @@ Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: 3 :: Only Classifier: Programming Language :: Python Classifier: Topic :: Software Development -Requires-Python: >=3.8.0 +Requires-Python: >=3.8 Description-Content-Type: text/x-rst License-File: LICENSE License-File: AUTHORS.rst Requires-Dist: wcwidth -Dynamic: author -Dynamic: classifier -Dynamic: description -Dynamic: description-content-type -Dynamic: home-page -Dynamic: requires-dist -Dynamic: requires-python -Dynamic: summary +Dynamic: license-file Python Prompt Toolkit ===================== diff --git a/contrib/python/prompt-toolkit/py3/prompt_toolkit/__init__.py b/contrib/python/prompt-toolkit/py3/prompt_toolkit/__init__.py index 94727e7cb22..ebaa57dc81b 100644 --- a/contrib/python/prompt-toolkit/py3/prompt_toolkit/__init__.py +++ b/contrib/python/prompt-toolkit/py3/prompt_toolkit/__init__.py @@ -17,6 +17,7 @@ Probably, to get started, you might also want to have a look at from __future__ import annotations import re +from importlib import metadata # note: this is a bit more lax than the actual pep 440 to allow for a/b/rc/dev without a number pep440 = re.compile( @@ -28,7 +29,7 @@ from .formatted_text import ANSI, HTML from .shortcuts import PromptSession, print_formatted_text, prompt # Don't forget to update in `docs/conf.py`! -__version__ = "3.0.50" +__version__ = metadata.version("prompt_toolkit") assert pep440.match(__version__) diff --git a/contrib/python/prompt-toolkit/py3/prompt_toolkit/completion/nested.py b/contrib/python/prompt-toolkit/py3/prompt_toolkit/completion/nested.py index 8569bd2cff7..b72b69ee212 100644 --- a/contrib/python/prompt-toolkit/py3/prompt_toolkit/completion/nested.py +++ b/contrib/python/prompt-toolkit/py3/prompt_toolkit/completion/nested.py @@ -69,7 +69,7 @@ class NestedCompleter(Completer): elif isinstance(value, dict): options[key] = cls.from_nested_dict(value) elif isinstance(value, set): - options[key] = cls.from_nested_dict({item: None for item in value}) + options[key] = cls.from_nested_dict(dict.fromkeys(value)) else: assert value is None options[key] = None diff --git a/contrib/python/prompt-toolkit/py3/prompt_toolkit/formatted_text/utils.py b/contrib/python/prompt-toolkit/py3/prompt_toolkit/formatted_text/utils.py index 43228c3cda1..a6f78cb4e06 100644 --- a/contrib/python/prompt-toolkit/py3/prompt_toolkit/formatted_text/utils.py +++ b/contrib/python/prompt-toolkit/py3/prompt_toolkit/formatted_text/utils.py @@ -89,8 +89,7 @@ def split_lines( parts = string.split("\n") for part in parts[:-1]: - if part: - line.append(cast(OneStyleAndTextTuple, (style, part, *mouse_handler))) + line.append(cast(OneStyleAndTextTuple, (style, part, *mouse_handler))) yield line line = [] diff --git a/contrib/python/prompt-toolkit/py3/tests/test_filter.py b/contrib/python/prompt-toolkit/py3/tests/test_filter.py index f7184c286f2..70ddd5cdf59 100644 --- a/contrib/python/prompt-toolkit/py3/tests/test_filter.py +++ b/contrib/python/prompt-toolkit/py3/tests/test_filter.py @@ -16,7 +16,7 @@ def test_always(): def test_invert(): assert not (~Always())() - assert ~Never()() + assert (~Never())() c = ~Condition(lambda: False) assert c() diff --git a/contrib/python/prompt-toolkit/py3/tests/test_formatted_text.py b/contrib/python/prompt-toolkit/py3/tests/test_formatted_text.py index 843aac16191..60f9cdf459b 100644 --- a/contrib/python/prompt-toolkit/py3/tests/test_formatted_text.py +++ b/contrib/python/prompt-toolkit/py3/tests/test_formatted_text.py @@ -274,7 +274,7 @@ def test_split_lines_3(): lines = list(split_lines([("class:a", "\n")])) assert lines == [ - [], + [("class:a", "")], [("class:a", "")], ] @@ -284,3 +284,15 @@ def test_split_lines_3(): assert lines == [ [("class:a", "")], ] + + +def test_split_lines_4(): + "Edge cases: inputs starting and ending with newlines." + # -1- + lines = list(split_lines([("class:a", "\nline1\n")])) + + assert lines == [ + [("class:a", "")], + [("class:a", "line1")], + [("class:a", "")], + ] diff --git a/contrib/python/prompt-toolkit/py3/ya.make b/contrib/python/prompt-toolkit/py3/ya.make index 5eed9c2519b..fd9a72b9e6e 100644 --- a/contrib/python/prompt-toolkit/py3/ya.make +++ b/contrib/python/prompt-toolkit/py3/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(3.0.50) +VERSION(3.0.51) LICENSE(BSD-3-Clause) diff --git a/contrib/python/pythran/bin/pythran/ya.make b/contrib/python/pythran/bin/pythran/ya.make index 7a4a0b9fd6f..370f7347bd6 100644 --- a/contrib/python/pythran/bin/pythran/ya.make +++ b/contrib/python/pythran/bin/pythran/ya.make @@ -10,4 +10,119 @@ PEERDIR( PY_MAIN(pythran.run:run) +INDUCED_DEPS(h+cpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/builtins/assert.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/builtins/getattr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/builtins/int_.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/builtins/max.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/builtins/range.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/builtins/tuple.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/core.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/builtins/assert.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/builtins/getattr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/builtins/int_.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/builtins/max.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/builtins/min.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/builtins/range.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/builtins/tuple.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/numpy/empty_like.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/numpy/float64.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/numpy/square.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/add.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/div.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/eq.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/floordiv.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/gt.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/iadd.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/le.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/lt.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/mod.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/mul.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/neg.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/operator_/sub.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/types/float.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/types/float64.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/types/int.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/types/ndarray.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/types/numpy_texpr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/include/types/str.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/numpy/empty_like.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/numpy/float64.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/numpy/square.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/add.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/div.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/eq.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/floordiv.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/gt.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/iadd.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/le.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/lt.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/mod.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/mul.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/neg.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/operator_/sub.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/python/exception_handler.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/NoneType.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/array.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/assignable.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/attr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/bool.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/cfun.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/clongdouble.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/combined.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/complex.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/complex128.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/complex256.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/complex64.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/dict.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/dynamic_tuple.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/empty_iterator.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/exceptions.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/file.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/finfo.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/float.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/float128.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/float32.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/float64.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/generator.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/int.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/int16.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/int32.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/int64.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/int8.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/intc.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/intp.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/list.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/longdouble.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/ndarray.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/nditerator.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_binary_op.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_broadcast.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_expr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_gexpr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_iexpr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_nary_expr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_op_helper.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_operators.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_texpr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_unary_op.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/numpy_vexpr.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/pointer.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/raw_array.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/set.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/slice.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/static_if.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/str.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/traits.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/tuple.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/uint16.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/uint32.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/uint64.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/uint8.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/uintc.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/uintp.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/variant_functor.hpp + ${ARCADIA_ROOT}/contrib/python/pythran/pythran/pythonic/types/vectorizable_type.hpp +) + END() diff --git a/contrib/python/ydb/py3/.dist-info/METADATA b/contrib/python/ydb/py3/.dist-info/METADATA index 8a15b64dc4b..aa485a3f4c0 100644 --- a/contrib/python/ydb/py3/.dist-info/METADATA +++ b/contrib/python/ydb/py3/.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: ydb -Version: 3.21.0 +Version: 3.21.1 Summary: YDB Python SDK Home-page: http://github.com/ydb-platform/ydb-python-sdk Author: Yandex LLC diff --git a/contrib/python/ydb/py3/ya.make b/contrib/python/ydb/py3/ya.make index b938db3247b..5f37e78726e 100644 --- a/contrib/python/ydb/py3/ya.make +++ b/contrib/python/ydb/py3/ya.make @@ -2,7 +2,7 @@ PY3_LIBRARY() -VERSION(3.21.0) +VERSION(3.21.1) LICENSE(Apache-2.0) diff --git a/contrib/python/ydb/py3/ydb/types.py b/contrib/python/ydb/py3/ydb/types.py index a48548640c0..47c9c48c2e2 100644 --- a/contrib/python/ydb/py3/ydb/types.py +++ b/contrib/python/ydb/py3/ydb/types.py @@ -39,6 +39,19 @@ def _to_date(pb: ydb_value_pb2.Value, value: typing.Union[date, int]) -> None: pb.uint32_value = value +def _from_date32(x: ydb_value_pb2.Value, table_client_settings: table.TableClientSettings) -> typing.Union[date, int]: + if table_client_settings is not None and table_client_settings._native_date_in_result_sets: + return _EPOCH.date() + timedelta(days=x.int32_value) + return x.int32_value + + +def _to_date32(pb: ydb_value_pb2.Value, value: typing.Union[date, int]) -> None: + if isinstance(value, date): + pb.int32_value = (value - _EPOCH.date()).days + else: + pb.int32_value = value + + def _from_datetime_number( x: typing.Union[float, datetime], table_client_settings: table.TableClientSettings ) -> datetime: @@ -62,6 +75,10 @@ def _from_uuid(pb: ydb_value_pb2.Value, value: uuid.UUID): pb.high_128 = struct.unpack("Q", value.bytes_le[8:16])[0] +def _timedelta_to_microseconds(value: timedelta) -> int: + return (value.days * _SECONDS_IN_DAY + value.seconds) * 1000000 + value.microseconds + + def _from_interval( value_pb: ydb_value_pb2.Value, table_client_settings: table.TableClientSettings ) -> typing.Union[timedelta, int]: @@ -70,10 +87,6 @@ def _from_interval( return value_pb.int64_value -def _timedelta_to_microseconds(value: timedelta) -> int: - return (value.days * _SECONDS_IN_DAY + value.seconds) * 1000000 + value.microseconds - - def _to_interval(pb: ydb_value_pb2.Value, value: typing.Union[timedelta, int]): if isinstance(value, timedelta): pb.int64_value = _timedelta_to_microseconds(value) @@ -100,6 +113,25 @@ def _to_timestamp(pb: ydb_value_pb2.Value, value: typing.Union[datetime, int]): pb.uint64_value = value +def _from_timestamp64( + value_pb: ydb_value_pb2.Value, table_client_settings: table.TableClientSettings +) -> typing.Union[datetime, int]: + if table_client_settings is not None and table_client_settings._native_timestamp_in_result_sets: + return _EPOCH + timedelta(microseconds=value_pb.int64_value) + return value_pb.int64_value + + +def _to_timestamp64(pb: ydb_value_pb2.Value, value: typing.Union[datetime, int]): + if isinstance(value, datetime): + if value.tzinfo: + epoch = _EPOCH_UTC + else: + epoch = _EPOCH + pb.int64_value = _timedelta_to_microseconds(value - epoch) + else: + pb.int64_value = value + + @enum.unique class PrimitiveType(enum.Enum): """ @@ -132,23 +164,46 @@ class PrimitiveType(enum.Enum): _from_date, _to_date, ) + Date32 = ( + _apis.primitive_types.DATE32, + None, + _from_date32, + _to_date32, + ) Datetime = ( _apis.primitive_types.DATETIME, "uint32_value", _from_datetime_number, ) + Datetime64 = ( + _apis.primitive_types.DATETIME64, + "int64_value", + _from_datetime_number, + ) Timestamp = ( _apis.primitive_types.TIMESTAMP, None, _from_timestamp, _to_timestamp, ) + Timestamp64 = ( + _apis.primitive_types.TIMESTAMP64, + None, + _from_timestamp64, + _to_timestamp64, + ) Interval = ( _apis.primitive_types.INTERVAL, None, _from_interval, _to_interval, ) + Interval64 = ( + _apis.primitive_types.INTERVAL64, + None, + _from_interval, + _to_interval, + ) DyNumber = _apis.primitive_types.DYNUMBER, "text_value" @@ -365,6 +420,32 @@ class DictType(AbstractTypeBuilder): return self._repr +class SetType(AbstractTypeBuilder): + __slots__ = ("__repr", "__proto") + + def __init__( + self, + key_type: typing.Union[AbstractTypeBuilder, PrimitiveType], + ): + """ + :param key_type: Key type builder + """ + self._repr = "Set<%s>" % (str(key_type)) + self._proto = _apis.ydb_value.Type( + dict_type=_apis.ydb_value.DictType( + key=key_type.proto, + payload=_apis.ydb_value.Type(void_type=struct_pb2.NULL_VALUE), + ) + ) + + @property + def proto(self): + return self._proto + + def __str__(self): + return self._repr + + class TupleType(AbstractTypeBuilder): __slots__ = ("__elements_repr", "__proto") diff --git a/contrib/python/ydb/py3/ydb/ydb_version.py b/contrib/python/ydb/py3/ydb/ydb_version.py index 6b71007009d..3c62627bc85 100644 --- a/contrib/python/ydb/py3/ydb/ydb_version.py +++ b/contrib/python/ydb/py3/ydb/ydb_version.py @@ -1 +1 @@ -VERSION = "3.21.0" +VERSION = "3.21.1" diff --git a/contrib/tools/python3/Python/_warnings.c b/contrib/tools/python3/Python/_warnings.c index 1f91edbf5cb..f0ab47efb64 100644 --- a/contrib/tools/python3/Python/_warnings.c +++ b/contrib/tools/python3/Python/_warnings.c @@ -791,6 +791,10 @@ is_internal_filename(PyObject *filename) } } + if (_PyUnicode_EqualToASCIIString(filename, "library/python/runtime_py3/__res.py")) { + return true; + } + return false; } diff --git a/contrib/tools/python3/Python/import.c b/contrib/tools/python3/Python/import.c index daac00593d7..1b550d0451c 100644 --- a/contrib/tools/python3/Python/import.c +++ b/contrib/tools/python3/Python/import.c @@ -2523,7 +2523,7 @@ remove_importlib_frames(PyThreadState *tstate) { const char *importlib_filename = "<frozen importlib._bootstrap>"; const char *external_filename = "<frozen importlib._bootstrap_external>"; - const char *importer_filename = "library/python/runtime_py3/importer.pxi"; + const char *importer_filename = "library/python/runtime_py3/__res.py"; const char *remove_frames = "_call_with_frames_removed"; int always_trim = 0; int in_importlib = 0; diff --git a/contrib/tools/python3/patches/cut-backtrace.patch b/contrib/tools/python3/patches/cut-backtrace.patch index 538ce57d9cb..3d734f05cdc 100644 --- a/contrib/tools/python3/patches/cut-backtrace.patch +++ b/contrib/tools/python3/patches/cut-backtrace.patch @@ -16,7 +16,7 @@ revision: 5216645 { const char *importlib_filename = "<frozen importlib._bootstrap>"; const char *external_filename = "<frozen importlib._bootstrap_external>"; -+ const char *importer_filename = "library/python/runtime_py3/importer.pxi"; ++ const char *importer_filename = "library/python/runtime_py3/__res.py"; const char *remove_frames = "_call_with_frames_removed"; int always_trim = 0; int in_importlib = 0; diff --git a/contrib/tools/python3/patches/cut-warnings-trace.patch b/contrib/tools/python3/patches/cut-warnings-trace.patch new file mode 100644 index 00000000000..3366ae97ee7 --- /dev/null +++ b/contrib/tools/python3/patches/cut-warnings-trace.patch @@ -0,0 +1,13 @@ +--- contrib/tools/python3/Python/_warnings.c (index) ++++ contrib/tools/python3/Python/_warnings.c (working tree) +@@ -791,6 +791,10 @@ is_internal_filename(PyObject *filename) + } + } + ++ if (_PyUnicode_EqualToASCIIString(filename, "library/python/runtime_py3/__res.py")) { ++ return true; ++ } ++ + return false; + } + diff --git a/library/python/runtime_py3/__res.cpp b/library/python/runtime_py3/__res.cpp new file mode 100644 index 00000000000..e4b01ce227f --- /dev/null +++ b/library/python/runtime_py3/__res.cpp @@ -0,0 +1,224 @@ +#include <library/cpp/resource/resource.h> + +#include <util/generic/scope.h> +#include <util/generic/strbuf.h> + +#include <Python.h> +#include <marshal.h> + +#include <type_traits> +#include <concepts> + +namespace { + +namespace NWrap { + +template<typename F> + requires std::convertible_to<std::invoke_result_t<F>, PyObject*> +PyObject* CallWithErrorTranslation(F&& f) noexcept { + try { + return std::forward<F>(f)(); + } catch (const std::bad_alloc& err) { + PyErr_SetString(PyExc_MemoryError, err.what()); + } catch(const std::out_of_range& err) { + PyErr_SetString(PyExc_IndexError, err.what()); + } catch (const std::exception& err) { + PyErr_SetString(PyExc_RuntimeError, err.what()); + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "Unhandled C++ exception of unknown type"); + } + return nullptr; +} + +PyObject* Count(PyObject* self[[maybe_unused]], PyObject *args[[maybe_unused]]) noexcept { + static_assert( + noexcept(NResource::Count()), + "Python3 Arcadia runtime binding assumes that NResource::Count do not throw exception. If this function start " + "to throw someone must add code translating C++ exceptions into Python exceptions here." + ); + return PyLong_FromLong(NResource::Count()); +} + +PyObject* KeyByIndex(PyObject* self[[maybe_unused]], PyObject *const *args, Py_ssize_t nargs) noexcept { + if (nargs != 1) { + PyErr_Format(PyExc_TypeError, "__res.key_by_index takes 1 positional arguments but %z were given", nargs); + return nullptr; + } + if (PyFloat_Check(args[0])) { + PyErr_SetString(PyExc_TypeError, "integer argument expected, got float"); + return nullptr; + } + PyObject* asNum = PyNumber_Index(args[0]); + if (!asNum) { + return nullptr; + } + const auto idx = PyLong_AsSize_t(asNum); + Py_DECREF(asNum); + if (idx == static_cast<size_t>(-1)) { + return nullptr; + } + return CallWithErrorTranslation([&]{ + const auto res = NResource::KeyByIndex(idx); + return PyBytes_FromStringAndSize(res.data(), res.size()); + }); +} + +PyObject* Find(PyObject* self[[maybe_unused]], PyObject *const* args, Py_ssize_t nargs) noexcept { + if (nargs != 1) { + PyErr_Format(PyExc_TypeError, "__res.find takes 1 positional arguments but %z were given", nargs); + return nullptr; + } + + TStringBuf key; + if (PyUnicode_Check(args[0])) { + Py_ssize_t sz; + const char* data = PyUnicode_AsUTF8AndSize(args[0], &sz); + if (sz < 0) { + return nullptr; + } + key = {data, static_cast<size_t>(sz)}; + } else { + char* data = nullptr; + Py_ssize_t sz; + if (PyBytes_AsStringAndSize(args[0], &data, &sz) != 0) { + return nullptr; + } + key = {data, static_cast<size_t>(sz)}; + } + + return CallWithErrorTranslation([&]{ + TString res; + if (!NResource::FindExact(key, &res)) { + Py_RETURN_NONE; + } + return PyBytes_FromStringAndSize(res.data(), res.size()); + }); +} + +PyObject* Has(PyObject* self[[maybe_unused]], PyObject *const* args, Py_ssize_t nargs) noexcept { + if (nargs != 1) { + PyErr_Format(PyExc_TypeError, "__res.has takes 1 positional arguments but %z were given", nargs); + return nullptr; + } + + TStringBuf key; + if (PyUnicode_Check(args[0])) { + Py_ssize_t sz; + const char* data = PyUnicode_AsUTF8AndSize(args[0], &sz); + if (sz < 0) { + return nullptr; + } + key = {data, static_cast<size_t>(sz)}; + } else { + char* data = nullptr; + Py_ssize_t sz; + if (PyBytes_AsStringAndSize(args[0], &data, &sz) != 0) { + return nullptr; + } + key = {data, static_cast<size_t>(sz)}; + } + + return CallWithErrorTranslation([&]{ + int res = NResource::Has(key); + return PyBool_FromLong(res); + }); +} + +} + +const unsigned char res_importer_pyc[] = { + #include "__res.pyc.inc" +}; + +int mod__res_exec(PyObject *mod) noexcept { + PyObject* modules = PySys_GetObject("modules"); + Y_ASSERT(modules); + Y_ASSERT(PyMapping_Check(modules)); + if (PyMapping_SetItemString(modules, "run_import_hook", mod) == -1) { + return -1; + } + + PyObject *bytecode = PyMarshal_ReadObjectFromString( + reinterpret_cast<const char*>(res_importer_pyc), + std::size(res_importer_pyc) + ); + if (bytecode == NULL) { + return -1; + } + + // The code below which sets "__builtins__" is a workarownd for issue + // reported here https://github.com/python/cpython/issues/130272 . + // The problem can be seen for Y_PYTHON_SOURCE_ROOT mode when trying + // compiling the code wich contains non-ascii identifiers. In this case + // call to `compile` in get_code function raises the exception + // KeyError: '__builtins__' inside `PyImport_Import` function. + PyObject* builtinsKey = NULL; + Y_DEFER { + Py_DECREF(bytecode); + Py_DECREF(builtinsKey); + }; + PyObject* modns = PyModule_GetDict(mod); + if (!modns) { + return -1; + } + builtinsKey = PyUnicode_FromString("__builtins__"); + if (builtinsKey == NULL) { + return -1; + } + int r = PyDict_Contains(modns, builtinsKey); + if (r < 0) { + return -1; + } if (r == 0) { + PyObject* builtins = PyEval_GetBuiltins(); + if (builtins == NULL) { + return -1; + } + if (PyDict_SetItem(modns, builtinsKey, builtins) < 0) { + return -1; + } + } + + if (PyObject* evalRes = PyEval_EvalCode(bytecode, modns, modns)) { + Py_DECREF(evalRes); + } + if (PyErr_Occurred()) { + return -1; + } + return 0; +} + +PyDoc_STRVAR(mod__res_doc, +"resfs python bindings module with importer hook supporting hermetic python programs."); + +PyMethodDef mod__res_methods[] = { + {"count", _PyCFunction_CAST(NWrap::Count), METH_NOARGS, PyDoc_STR("Returns number of embedded resources.")}, + {"key_by_index", _PyCFunction_CAST(NWrap::KeyByIndex), METH_FASTCALL, PyDoc_STR("Returns resource key by resource index.")}, + {"find", _PyCFunction_CAST(NWrap::Find), METH_FASTCALL, PyDoc_STR("Finds resource content by key.")}, + {"has", _PyCFunction_CAST(NWrap::Has), METH_FASTCALL, PyDoc_STR("Checks if the resource with the given key exists.")}, + {nullptr, nullptr, 0, nullptr} +}; + +PyModuleDef_Slot mod__res_slots[] = { + {Py_mod_exec, reinterpret_cast<void*>(&mod__res_exec)}, + {Py_mod_multiple_interpreters, Py_MOD_PER_INTERPRETER_GIL_SUPPORTED}, + {0, nullptr}, +}; + +PyModuleDef mod__res = { + .m_base = PyModuleDef_HEAD_INIT, + .m_name = "__res", + .m_doc = mod__res_doc, + .m_size = 0, + .m_methods = mod__res_methods, + .m_slots = mod__res_slots, + .m_traverse = nullptr, + .m_clear = nullptr, + .m_free = nullptr +}; + +} + +PyMODINIT_FUNC +PyInit___res() noexcept { + return PyModuleDef_Init(&mod__res); +} diff --git a/library/python/runtime_py3/importer.pxi b/library/python/runtime_py3/__res.py index 51bc3020a3f..9fb2d09481a 100644 --- a/library/python/runtime_py3/importer.pxi +++ b/library/python/runtime_py3/__res.py @@ -1,3 +1,15 @@ +# def count() -> int: +# # implemented in C++ part of this module +# +# def key_by_index(idx: key) -> bytes: +# # implemented in C++ part of this module +# +# def find(key: str | bytes) -> bytes: +# # implemented in C++ part of this module +# +# def has(key: str | bytes) -> bool: +# # implemented in C++ part of this module + import marshal import sys from _codecs import utf_8_decode, utf_8_encode @@ -16,8 +28,6 @@ from _frozen_importlib_external import ( from _io import FileIO -import __res as __resource - _b = lambda x: x if isinstance(x, bytes) else utf_8_encode(x)[0] _s = lambda x: x if isinstance(x, str) else utf_8_decode(x)[0] env_source_root = b'Y_PYTHON_SOURCE_ROOT' @@ -25,7 +35,6 @@ cfg_source_root = b'arcadia-source-root' env_extended_source_search = b'Y_PYTHON_EXTENDED_SOURCE_SEARCH' res_ya_ide_venv = b'YA_IDE_VENV' executable = sys.executable or 'Y_PYTHON' -sys.modules['run_import_hook'] = __resource def _probe(environ_dict, key, default_value=None): """ Probe bytes and str variants for environ. @@ -47,9 +56,10 @@ def _probe(environ_dict, key, default_value=None): py_prefix = b'py/' py_prefix_len = len(py_prefix) -EXTERNAL_PY_FILES_MODE = __resource.find(b'py/conf/ENABLE_EXTERNAL_PY_FILES') in (b'1', b'yes') +EXTERNAL_PY_FILES_MODE = find(b'py/conf/ENABLE_EXTERNAL_PY_FILES') in (b'1', b'yes') + +YA_IDE_VENV = find(res_ya_ide_venv) -YA_IDE_VENV = __resource.find(res_ya_ide_venv) Y_PYTHON_EXTENDED_SOURCE_SEARCH = _probe(_os.environ, env_extended_source_search) or YA_IDE_VENV @@ -182,8 +192,8 @@ def _print(*xs): def iter_keys(prefix): l = len(prefix) - for idx in range(__resource.count()): - key = __resource.key_by_index(idx) + for idx in range(count()): + key = key_by_index(idx) if key.startswith(prefix): yield key, key[l:] @@ -229,14 +239,14 @@ def resfs_src(key, resfs_file=False): """ if resfs_file: key = b'resfs/file/' + _b(key) - return __resource.find(b'resfs/src/' + _b(key)) + return find(b'resfs/src/' + _b(key)) def resfs_has(path): """ Return true if the requested file is embedded in the program """ - return __resource.has(b'resfs/file/' + _b(path)) + return has(b'resfs/file/' + _b(path)) def resfs_read(path, builtin=None): @@ -253,7 +263,7 @@ def resfs_read(path, builtin=None): return file_bytes(fspath) if builtin is not False: - return __resource.find(b'resfs/file/' + _b(path)) + return find(b'resfs/file/' + _b(path)) def resfs_files(prefix=b''): @@ -605,7 +615,7 @@ class ArcadiaSourceFinder: for key, dirty_path in iter_keys(self.NAMESPACE_PREFIX): # dirty_path contains unique prefix to prevent repeatable keys in the resource storage path = dirty_path.split(b'/', 1)[1] - namespaces = __resource.find(key).split(b':') + namespaces = find(key).split(b':') for n in namespaces: package_name = _s(n.rstrip(b'.')) self.module_path_cache.setdefault(package_name, set()).add(_s(path)) diff --git a/library/python/runtime_py3/__res.pyx b/library/python/runtime_py3/__res.pyx deleted file mode 100644 index 2c1d0c3ab4d..00000000000 --- a/library/python/runtime_py3/__res.pyx +++ /dev/null @@ -1,44 +0,0 @@ -from _codecs import utf_8_decode, utf_8_encode - -from libcpp cimport bool - -from util.generic.string cimport TString, TStringBuf - - -cdef extern from "library/cpp/resource/resource.h" namespace "NResource": - cdef bool Has(const TStringBuf key) except + - cdef size_t Count() except + - cdef TStringBuf KeyByIndex(size_t idx) except + - cdef bool FindExact(const TStringBuf key, TString* result) nogil except + - - -def count(): - return Count() - - -def key_by_index(idx): - cdef TStringBuf ret = KeyByIndex(idx) - - return ret.Data()[:ret.Size()] - - -def find(s): - cdef TString res - - if isinstance(s, str): - s = utf_8_encode(s)[0] - - if FindExact(TStringBuf(s, len(s)), &res): - return res.c_str()[:res.length()] - - return None - - -def has(s): - if isinstance(s, str): - s = utf_8_encode(s)[0] - - return Has(s) - - -include "importer.pxi" diff --git a/library/python/runtime_py3/runtime_reg_py3.cpp b/library/python/runtime_py3/runtime_reg_py3.cpp new file mode 100644 index 00000000000..283fa70254d --- /dev/null +++ b/library/python/runtime_py3/runtime_reg_py3.cpp @@ -0,0 +1,17 @@ +#include <Python.h> + +extern "C" PyObject* PyInit___res(); +extern "C" PyObject* PyInit_sitecustomize(); + +namespace { + struct TRegistrar { + inline TRegistrar() { + _inittab mods[] = { + {"__res", PyInit___res}, + {"sitecustomize", PyInit_sitecustomize}, + {nullptr, nullptr} + }; + PyImport_ExtendInittab(mods); + } + } REG; +} diff --git a/library/python/runtime_py3/sitecustomize.cpp b/library/python/runtime_py3/sitecustomize.cpp new file mode 100644 index 00000000000..be68ad39780 --- /dev/null +++ b/library/python/runtime_py3/sitecustomize.cpp @@ -0,0 +1,63 @@ +#include <Python.h> +#include <marshal.h> + +#include <iterator> + +namespace { + +const unsigned char sitecustomize_pyc[] = { + #include "sitecustomize.pyc.inc" +}; + +int modsitecustomize_exec(PyObject *mod) noexcept { + PyObject *bytecode = PyMarshal_ReadObjectFromString( + reinterpret_cast<const char*>(sitecustomize_pyc), + std::size(sitecustomize_pyc) + ); + if (!bytecode) { + return -1; + } + PyObject* modns = PyModule_GetDict(mod); + if (!modns) { + return -1; + } + if (PyObject* evalRes = PyEval_EvalCode(bytecode, modns, modns)) { + Py_DECREF(evalRes); + } + if (PyErr_Occurred()) { + return -1; + } + return 0; +} + +PyDoc_STRVAR(modsitecustomize_doc, +"bridge between Arcadia resource system and python importlib resources interface."); + +PyMethodDef modsitecustomize_methods[] = { + {nullptr, nullptr, 0, nullptr} +}; + +PyModuleDef_Slot modsitecustomize_slots[] = { + {Py_mod_exec, reinterpret_cast<void*>(&modsitecustomize_exec)}, + {Py_mod_multiple_interpreters, Py_MOD_PER_INTERPRETER_GIL_SUPPORTED}, + {0, nullptr}, +}; + +PyModuleDef modsitecustomize = { + .m_base = PyModuleDef_HEAD_INIT, + .m_name = "sitecustomize", + .m_doc = modsitecustomize_doc, + .m_size = 0, + .m_methods = modsitecustomize_methods, + .m_slots = modsitecustomize_slots, + .m_traverse = nullptr, + .m_clear = nullptr, + .m_free = nullptr +}; + +} + +PyMODINIT_FUNC +PyInit_sitecustomize() noexcept { + return PyModuleDef_Init(&modsitecustomize); +} diff --git a/library/python/runtime_py3/sitecustomize.pyx b/library/python/runtime_py3/sitecustomize.py index 8d30073d7d1..3b30b8807e8 100644 --- a/library/python/runtime_py3/sitecustomize.pyx +++ b/library/python/runtime_py3/sitecustomize.py @@ -11,7 +11,7 @@ from importlib.metadata import ( ) from importlib.resources.abc import Traversable -import __res +from __res import _ResfsResourceReader, find, iter_keys, resfs_read, resfs_files METADATA_NAME = re.compile("^Name: (.*)$", re.MULTILINE) @@ -55,7 +55,7 @@ class ArcadiaResource(ArcadiaTraversable): return False def open(self, mode="r", *args, **kwargs): - data = __res.find(self._resfs.encode("utf-8")) + data = find(self._resfs.encode("utf-8")) if data is None: raise FileNotFoundError(self._resfs) @@ -85,7 +85,7 @@ class ArcadiaResourceContainer(ArcadiaTraversable): def iterdir(self): seen = set() - for key, path_without_prefix in __res.iter_keys(self._resfs.encode("utf-8")): + for key, path_without_prefix in iter_keys(self._resfs.encode("utf-8")): if b"/" in path_without_prefix: subdir = path_without_prefix.split(b"/", maxsplit=1)[0].decode("utf-8") if subdir not in seen: @@ -127,7 +127,7 @@ class ArcadiaDistribution(Distribution): self._path = pathlib.Path(prefix) def read_text(self, filename): - data = __res.resfs_read(f"{self._prefix}{filename}") + data = resfs_read(f"{self._prefix}{filename}") if data is not None: return data.decode("utf-8") @@ -149,11 +149,11 @@ class MetadataArcadiaFinder(DistributionFinder): def _init_prefixes(cls): cls.prefixes.clear() - for resource in __res.resfs_files(): + for resource in resfs_files(): resource = resource.decode("utf-8") if not resource.endswith("METADATA"): continue - data = __res.resfs_read(resource).decode("utf-8") + data = resfs_read(resource).decode("utf-8") metadata_name = METADATA_NAME.search(data) if metadata_name: cls.prefixes[Prepared(metadata_name.group(1)).normalized] = resource.removesuffix("METADATA") diff --git a/library/python/runtime_py3/stage0pycc/main.cpp b/library/python/runtime_py3/stage0pycc/main.cpp new file mode 100644 index 00000000000..f66ea33b6e1 --- /dev/null +++ b/library/python/runtime_py3/stage0pycc/main.cpp @@ -0,0 +1,70 @@ +#include <util/folder/path.h> +#include <util/generic/scope.h> +#include <util/generic/string.h> +#include <util/stream/file.h> +#include <util/stream/output.h> + +#include <Python.h> +#include <marshal.h> + +#include <cstdio> +#include <system_error> + +struct TPyObjDeleter { + static void Destroy(PyObject* o) noexcept { + Py_XDECREF(o); + } +}; +using TPyObject = THolder<PyObject, TPyObjDeleter>; + +constexpr TStringBuf modPrefix = "mod="; + +int main(int argc, char** argv) { + if ((argc - 1) % 3 != 0) { + Cerr << "Usage:\n\t" << argv[0] << " (mod=SRC_PATH_X SRC OUT)+" << Endl; + return 1; + } + + PyConfig cfg{}; + PyConfig_InitIsolatedConfig(&cfg); + cfg._install_importlib = 0; + Y_SCOPE_EXIT(&cfg) {PyConfig_Clear(&cfg);}; + + for (int i = 0; i < (argc - 1)/3; ++i) { + const TString srcpath{TStringBuf{argv[3*i + 1]}.substr(modPrefix.size())}; + const TFsPath inPath{argv[3*i + 2]}; + const char* outPath = argv[3*i + 3]; + + const auto status = Py_InitializeFromConfig(&cfg); + if (PyStatus_Exception(status)) { + Py_ExitStatusException(status); + } + Y_SCOPE_EXIT() {Py_Finalize();}; + + TPyObject bytecode{Py_CompileString( + TFileInput{inPath}.ReadAll().c_str(), + srcpath.c_str(), + Py_file_input + )}; + if (!bytecode) { + Cerr << "Failed to compile " << outPath << Endl; + PyErr_Print(); + return 1; + } + + if (FILE* out = fopen(outPath, "wb")) { + PyMarshal_WriteObjectToFile(bytecode.Get(), out, Py_MARSHAL_VERSION); + fclose(out); + if (PyErr_Occurred()) { + Cerr << "Failed to marshal " << outPath << Endl; + PyErr_Print(); + return 1; + } + } else { + Cerr << "Failed to write " << outPath << ": " << std::error_code{errno, std::system_category()}.message() << Endl; + return 1; + } + } + + return 0; +} diff --git a/library/python/runtime_py3/stage0pycc/ya.make b/library/python/runtime_py3/stage0pycc/ya.make new file mode 100644 index 00000000000..b182365da87 --- /dev/null +++ b/library/python/runtime_py3/stage0pycc/ya.make @@ -0,0 +1,10 @@ +PROGRAM() + +PYTHON3_ADDINCL() +PEERDIR( + contrib/tools/python3 +) + +SRCS(main.cpp) + +END() diff --git a/library/python/runtime_py3/test/subinterpreter/py3_subinterpreters.cpp b/library/python/runtime_py3/test/subinterpreter/py3_subinterpreters.cpp new file mode 100644 index 00000000000..0a934d4db50 --- /dev/null +++ b/library/python/runtime_py3/test/subinterpreter/py3_subinterpreters.cpp @@ -0,0 +1,82 @@ +#include "stdout_interceptor.h" + +#include <util/stream/str.h> + +#include <library/cpp/testing/gtest/gtest.h> + +#include <Python.h> + +#include <thread> +#include <algorithm> + +struct TSubinterpreters: ::testing::Test { + static void SetUpTestSuite() { + Py_InitializeEx(0); + EXPECT_TRUE(TPyStdoutInterceptor::SetupInterceptionSupport()); + } + static void TearDownTestSuite() { + Py_Finalize(); + } + + static void ThreadPyRun(PyInterpreterState* interp, IOutputStream& pyout, const char* pycode) { + PyThreadState* state = PyThreadState_New(interp); + PyEval_RestoreThread(state); + + { + TPyStdoutInterceptor interceptor{pyout}; + PyRun_SimpleString(pycode); + } + + PyThreadState_Clear(state); + PyThreadState_DeleteCurrent(); + } +}; + +TEST_F(TSubinterpreters, NonSubinterpreterFlowStillWorks) { + TStringStream pyout; + TPyStdoutInterceptor interceptor{pyout}; + + PyRun_SimpleString("print('Hello World')"); + EXPECT_EQ(pyout.Str(), "Hello World\n"); +} + +TEST_F(TSubinterpreters, ThreadedSubinterpretersFlowWorks) { + TStringStream pyout[2]; + + PyInterpreterConfig cfg = { + .use_main_obmalloc = 0, + .allow_fork = 0, + .allow_exec = 0, + .allow_threads = 1, + .allow_daemon_threads = 0, + .check_multi_interp_extensions = 1, + .gil = PyInterpreterConfig_OWN_GIL, + }; + + PyThreadState* mainState = PyThreadState_Get(); + PyThreadState *sub[2] = {nullptr, nullptr}; + Py_NewInterpreterFromConfig(&sub[0], &cfg); + ASSERT_NE(sub[0], nullptr); + Py_NewInterpreterFromConfig(&sub[1], &cfg); + ASSERT_NE(sub[1], nullptr); + PyThreadState_Swap(mainState); + + PyThreadState* savedState = PyEval_SaveThread(); + std::array<std::thread, 2> threads{ + std::thread{ThreadPyRun, sub[0]->interp, std::ref(pyout[0]), "print('Hello Thread 0')"}, + std::thread{ThreadPyRun, sub[1]->interp, std::ref(pyout[1]), "print('Hello Thread 1')"} + }; + std::ranges::for_each(threads, &std::thread::join); + PyEval_RestoreThread(savedState); + + PyThreadState_Swap(sub[0]); + Py_EndInterpreter(sub[0]); + + PyThreadState_Swap(sub[1]); + Py_EndInterpreter(sub[1]); + + PyThreadState_Swap(mainState); + + EXPECT_EQ(pyout[0].Str(), "Hello Thread 0\n"); + EXPECT_EQ(pyout[1].Str(), "Hello Thread 1\n"); +} diff --git a/library/python/runtime_py3/test/subinterpreter/stdout_interceptor.cpp b/library/python/runtime_py3/test/subinterpreter/stdout_interceptor.cpp new file mode 100644 index 00000000000..3cd4b69d012 --- /dev/null +++ b/library/python/runtime_py3/test/subinterpreter/stdout_interceptor.cpp @@ -0,0 +1,77 @@ +#include "stdout_interceptor.h" + +#include <util/stream/output.h> + +namespace { + +struct TOStreamWrapper { + PyObject_HEAD + IOutputStream* Stm = nullptr; +}; + +PyObject* Write(TOStreamWrapper *self, PyObject *const *args, Py_ssize_t nargs) noexcept { + try { + Py_buffer view; + for (Py_ssize_t i = 0; i < nargs; ++i) { + PyObject* buf = args[i]; + if (PyUnicode_Check(args[i])) { + buf = PyUnicode_AsUTF8String(buf); + if (!buf) { + return nullptr; + } + } + + if (PyObject_GetBuffer(buf, &view, PyBUF_SIMPLE | PyBUF_C_CONTIGUOUS) == -1) { + return nullptr; + } + self->Stm->Write(reinterpret_cast<const char*>(view.buf), view.len); + PyBuffer_Release(&view); + } + + return Py_None; + } catch(const std::exception& err) { + PyErr_SetString(PyExc_IOError, err.what()); + } catch (...) { + PyErr_SetString(PyExc_RuntimeError, "Unhandled C++ exception of unknown type"); + } + return nullptr; +} + +PyMethodDef TOStreamWrapperMethods[] = { + {"write", reinterpret_cast<PyCFunction>(Write), METH_FASTCALL, PyDoc_STR("write buffer to wrapped C++ stream")}, + {} +}; + +PyTypeObject TOStreamWrapperType { + .ob_base = PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "testwrap.OStream", + .tp_basicsize = sizeof(TOStreamWrapper), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = PyDoc_STR("C++ IOStream wrapper"), + .tp_methods = TOStreamWrapperMethods, + .tp_new = PyType_GenericNew, +}; + +} + +TPyStdoutInterceptor::TPyStdoutInterceptor(IOutputStream& redirectionStream) noexcept + : RealStdout_{PySys_GetObject("stdout")} +{ + Py_INCREF(RealStdout_); + + PyObject* redirect = TOStreamWrapperType.tp_alloc(&TOStreamWrapperType, 0); + reinterpret_cast<TOStreamWrapper*>(redirect)->Stm = &redirectionStream; + + PySys_SetObject("stdout", redirect); + Py_DECREF(redirect); +} + +TPyStdoutInterceptor::~TPyStdoutInterceptor() noexcept { + PySys_SetObject("stdout", RealStdout_); + Py_DECREF(RealStdout_); +} + +bool TPyStdoutInterceptor::SetupInterceptionSupport() noexcept { + return PyType_Ready(&TOStreamWrapperType) == 0; +} diff --git a/library/python/runtime_py3/test/subinterpreter/stdout_interceptor.h b/library/python/runtime_py3/test/subinterpreter/stdout_interceptor.h new file mode 100644 index 00000000000..a1e219953f0 --- /dev/null +++ b/library/python/runtime_py3/test/subinterpreter/stdout_interceptor.h @@ -0,0 +1,16 @@ +#pragma once + +#include <Python.h> + +class IOutputStream; + +class TPyStdoutInterceptor { +public: + TPyStdoutInterceptor(IOutputStream& redirectionStream) noexcept; + ~TPyStdoutInterceptor() noexcept; + + static bool SetupInterceptionSupport() noexcept; + +private: + PyObject* RealStdout_; +}; diff --git a/library/python/runtime_py3/test/subinterpreter/ya.make b/library/python/runtime_py3/test/subinterpreter/ya.make new file mode 100644 index 00000000000..78cc82304c8 --- /dev/null +++ b/library/python/runtime_py3/test/subinterpreter/ya.make @@ -0,0 +1,10 @@ +GTEST() + +USE_PYTHON3() + +SRCS( + py3_subinterpreters.cpp + stdout_interceptor.cpp +) + +END() diff --git a/library/python/runtime_py3/test/test_arcadia_source_finder.py b/library/python/runtime_py3/test/test_arcadia_source_finder.py index 9f794f03591..835e60c6710 100644 --- a/library/python/runtime_py3/test/test_arcadia_source_finder.py +++ b/library/python/runtime_py3/test/test_arcadia_source_finder.py @@ -18,7 +18,7 @@ class ImporterMocks: self._mock_resources = mock_resources self._patchers = [ patch("__res.iter_keys", wraps=self._iter_keys), - patch("__res.__resource.find", wraps=self._resource_find), + patch("__res.find", wraps=self._resource_find), patch("__res._path_isfile", wraps=self._path_isfile), patch("__res._os.listdir", wraps=self._os_listdir), patch("__res._os.lstat", wraps=self._os_lstat), diff --git a/library/python/runtime_py3/test/ya.make b/library/python/runtime_py3/test/ya.make index e0c4061ad2c..fde64236dca 100644 --- a/library/python/runtime_py3/test/ya.make +++ b/library/python/runtime_py3/test/ya.make @@ -34,4 +34,7 @@ RESOURCE_FILES( END() -RECURSE_FOR_TESTS(traceback) +RECURSE_FOR_TESTS( + subinterpreter + traceback +) diff --git a/library/python/runtime_py3/ya.make b/library/python/runtime_py3/ya.make index dc97c8e2e08..b2d0dbf51ee 100644 --- a/library/python/runtime_py3/ya.make +++ b/library/python/runtime_py3/ya.make @@ -8,21 +8,18 @@ PEERDIR( library/cpp/resource ) -CFLAGS(-DCYTHON_REGISTER_ABCS=0) - NO_PYTHON_INCLUDES() ENABLE(PYBUILD_NO_PYC) +SRCS( + __res.cpp + sitecustomize.cpp + GLOBAL runtime_reg_py3.cpp +) + PY_SRCS( entry_points.py - TOP_LEVEL - - CYTHON_DIRECTIVE - language_level=3 - - __res.pyx - sitecustomize.pyx ) IF (EXTERNAL_PY_FILES) @@ -31,17 +28,16 @@ IF (EXTERNAL_PY_FILES) ) ENDIF() -IF (CYTHON_COVERAGE) - # Let covarage support add all needed files to resources -ELSE() - RESOURCE_FILES( - DONT_COMPRESS - PREFIX ${MODDIR}/ - __res.pyx - importer.pxi - sitecustomize.pyx - ) -ENDIF() +RUN_PROGRAM( + library/python/runtime_py3/stage0pycc + mod=${MODDIR}/__res.py __res.py __res.pyc + mod=${MODDIR}/sitecustomize.py sitecustomize.py sitecustomize.pyc + IN __res.py sitecustomize.py + OUT_NOAUTO __res.pyc sitecustomize.pyc + ENV PYTHONHASHSEED=0 +) +ARCHIVE(NAME __res.pyc.inc DONTCOMPRESS __res.pyc) +ARCHIVE(NAME sitecustomize.pyc.inc DONTCOMPRESS sitecustomize.pyc) END() diff --git a/ydb/ci/rightlib.txt b/ydb/ci/rightlib.txt index 9ccf833a891..e0188e8eab3 100644 --- a/ydb/ci/rightlib.txt +++ b/ydb/ci/rightlib.txt @@ -1 +1 @@ -407f7c0bc156862b8263bccf3eaaf0687ba75f8d +8a6954f35eee99eef660e76e775774d720e111a9 diff --git a/yql/essentials/core/dq_expr_nodes/dq_expr_nodes.json b/yql/essentials/core/dq_expr_nodes/dq_expr_nodes.json index e9efb8c2042..9c2b4046521 100644 --- a/yql/essentials/core/dq_expr_nodes/dq_expr_nodes.json +++ b/yql/essentials/core/dq_expr_nodes/dq_expr_nodes.json @@ -44,7 +44,10 @@ "Match": {"Type": "Callable", "Name": "DqJoin"}, "Children": [ {"Index": 8, "Name": "JoinAlgo", "Type": "TCoAtom"}, - {"Index": 9, "Name": "Flags", "Type": "TCoAtomList", "Optional": true} + {"Index": 9, "Name": "ShuffleLeftSideBy", "Type": "TExprList", "Optional": true}, + {"Index": 10, "Name": "ShuffleRightSideBy", "Type": "TExprList", "Optional": true}, + {"Index": 11, "Name": "JoinAlgoOptions", "Type": "TCoNameValueTupleList", "Optional": true}, + {"Index": 12, "Name": "Flags", "Type": "TCoAtomList", "Optional": true} ] }, { diff --git a/yql/essentials/core/dqs_expr_nodes/dqs_expr_nodes.h b/yql/essentials/core/dqs_expr_nodes/dqs_expr_nodes.h index 8851127f831..b40dae973ba 100644 --- a/yql/essentials/core/dqs_expr_nodes/dqs_expr_nodes.h +++ b/yql/essentials/core/dqs_expr_nodes/dqs_expr_nodes.h @@ -6,5 +6,6 @@ namespace NYql::NNodes { #include <yql/essentials/core/dqs_expr_nodes/dqs_expr_nodes.decl.inl.h> + #include <yql/essentials/core/dqs_expr_nodes/dqs_expr_nodes.defs.inl.h> } diff --git a/yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp b/yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp index 2b64e4e0bd1..666f1c6183d 100644 --- a/yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp +++ b/yql/essentials/core/peephole_opt/yql_opt_peephole_physical.cpp @@ -2143,17 +2143,6 @@ TExprNode::TPtr ExpandSqlIn(const TExprNode::TPtr& input, TExprContext& ctx) { dictKeyType = collectionType->Cast<TListExprType>()->GetItemType(); } } else if (collectionType->GetKind() == ETypeAnnotationKind::Tuple) { - if (ansiIn && collectionType->Cast<TTupleExprType>()->GetSize()) { - return ctx.Builder(input->Pos()) - .Callable("SqlIn") - .Callable(0, "AsListStrict") - .Add(collection->ChildrenList()) - .Seal() - .Add(1, std::move(lookup)) - .Add(2, std::move(options)) - .Seal() - .Build(); - } YQL_CLOG(DEBUG, CorePeepHole) << "IN Tuple"; dict = BuildDictOverTuple(std::move(collection), dictKeyType, ctx); } else if (collectionType->GetKind() == ETypeAnnotationKind::EmptyDict) { diff --git a/yql/essentials/docs/.yfm b/yql/essentials/docs/.yfm index 4b468a87b1e..a19a96fc41f 100644 --- a/yql/essentials/docs/.yfm +++ b/yql/essentials/docs/.yfm @@ -12,8 +12,8 @@ docs-viewer: title: YQL core description: Язык запросов YQL auto-release-to: - testing: false - prod: false + testing: true + prod: true single-page: supported: true startrek: diff --git a/yql/essentials/docs/.yfmlint b/yql/essentials/docs/.yfmlint index fa6ba04a899..0c54658d91b 100644 --- a/yql/essentials/docs/.yfmlint +++ b/yql/essentials/docs/.yfmlint @@ -1,5 +1,9 @@ log-levels: YFM001: 'error' # Inline code length + MD009: 'error' # Trailing spaces + MD023: 'error' # Headings must start at the beginning of the line + MD024: 'error' # Multiple headings with the same content + MD026: 'error' # Trailing punctuation in heading # Inline code length YFM001: diff --git a/yql/essentials/docs/presets.yaml b/yql/essentials/docs/presets.yaml deleted file mode 100644 index e9427569bbe..00000000000 --- a/yql/essentials/docs/presets.yaml +++ /dev/null @@ -1,49 +0,0 @@ -default: - oss: true - ya_make: true - -rtmr: - backend_name: RTMR - backend_name_lower: rtmr - rtmr: true - example_cluster: rtmr_yql_alpha - feature_not_null: true - feature_column_container_type: true - feature_mapreduce: true - process_command: PROCESS STREAM - select_command: SELECT STREAM - feature_temp_table: true - -ydb: - backend_name: YDB - backend_name_lower: ydb - ydb: true - example_cluster: ydbtest - feature_secondary_index: true - feature_changefeed: true - feature_replace: true - feature_upsert: true - feature_join: true - feature_map_tables: true - feature_group_by_rollup_cube: true - feature_window_functions: true - feature_bulk_tables: false - -yt: - backend_name: YT - backend_name_lower: yt - yt: true - example_cluster: hahn - feature_mapreduce: true - feature_column_container_type: true - feature_subquery: true - feature_upsert: true - feature_join: true - feature_insert_with_truncate: true - feature_bulk_tables: true # CONCAT, RANGE, TablePath(), INSERT INTO details... - feature_group_by_rollup_cube: true - feature_window_functions: true - feature_codegen: true - feature_functional_tables: true - feature_udf_noncpp: true - feature_temp_table: true diff --git a/yql/essentials/docs/ru/syntax/window.md b/yql/essentials/docs/ru/syntax/window.md index 1c081d1d104..9806d8261b3 100644 --- a/yql/essentials/docs/ru/syntax/window.md +++ b/yql/essentials/docs/ru/syntax/window.md @@ -4,7 +4,7 @@ В отличие от [агрегатных функций](../builtins/aggregation.md) при этом не происходит группировка нескольких строк в одну – после применения оконных функций число строк в результирующей таблице всегда совпадает с числом строк в исходной. -При наличии в запросе агрегатных и оконных функций сначала производится группировка и вычисляются значения агрегатных функций. Вычисленные значения агрегатных функций могут использоваться в качестве аргументов оконных (но не наоборот). Порядок, в котором вычисляются оконные функции относительно других элементов запроса, описан в разеделе [SELECT](select/index.md). +При наличии в запросе агрегатных и оконных функций сначала производится группировка и вычисляются значения агрегатных функций. Вычисленные значения агрегатных функций могут использоваться в качестве аргументов оконных (но не наоборот). Порядок, в котором вычисляются оконные функции относительно других элементов запроса, описан в разделе [SELECT](select/index.md). ## Синтаксис {#syntax} @@ -75,7 +75,7 @@ WINDOW ### Рамка {#frame} -Определение рамки `frame_definition` задает множество строк раздела, попадающих в *рамку окна* связанную с текущей строкой. +Определение рамки `frame_definition` задает множество строк раздела, попадающих в *рамку окна*, связанную с текущей строкой. В режиме `ROWS` (в YQL пока поддерживается только он) в рамку окна попадают строки с указанными смещениями относительно текущей строки раздела. Например, для `ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING` в рамку окна попадут три строки перед текущей, текущая строка и пять строк после текущей строки. diff --git a/yql/essentials/minikql/mkql_program_builder.cpp b/yql/essentials/minikql/mkql_program_builder.cpp index 8bed557c701..b472f648575 100644 --- a/yql/essentials/minikql/mkql_program_builder.cpp +++ b/yql/essentials/minikql/mkql_program_builder.cpp @@ -1217,11 +1217,11 @@ TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list) { std::vector<std::conditional_t<OnStruct, std::string_view, ui32>> members; const bool multiOptional = CollectOptionalElements<IsFilter>(itemType, members, filteredItems); - const auto predicate = [=](TRuntimeNode item) { + const auto predicate = [=, this](TRuntimeNode item) { std::vector<TRuntimeNode> checkMembers; checkMembers.reserve(members.size()); std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers), - [=](const auto& i){ return Exists(Element(item, i)); }); + [=, this](const auto& i){ return Exists(Element(item, i)); }); return And(checkMembers); }; @@ -1263,11 +1263,11 @@ TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRe THROW yexception() << "Expected flow or list or stream or optional of struct."; } - const auto predicate = [=](TRuntimeNode item) { + const auto predicate = [=, this](TRuntimeNode item) { TRuntimeNode::TList checkMembers; checkMembers.reserve(members.size()); std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers), - [=](const auto& i){ return Exists(Element(item, i)); }); + [=, this](const auto& i){ return Exists(Element(item, i)); }); return And(checkMembers); }; @@ -1297,7 +1297,7 @@ TRuntimeNode TProgramBuilder::BuildFilterNulls(TRuntimeNode list, const TArrayRe TRuntimeNode::TList checkMembers; checkMembers.reserve(members.size()); std::transform(members.cbegin(), members.cend(), std::back_inserter(checkMembers), - [=](const auto& i){ return Element(item, i); }); + [=, this](const auto& i){ return this->Element(item, i); }); return IfPresent(checkMembers, [&](TRuntimeNode::TList items) { std::conditional_t<OnStruct, std::vector<std::pair<std::string_view, TRuntimeNode>>, TRuntimeNode::TList> row; diff --git a/yql/essentials/sql/v1/complete/name/object/schema_gateway.h b/yql/essentials/sql/v1/complete/name/object/schema_gateway.h index 37fff8571e0..f9307bf495d 100644 --- a/yql/essentials/sql/v1/complete/name/object/schema_gateway.h +++ b/yql/essentials/sql/v1/complete/name/object/schema_gateway.h @@ -10,6 +10,9 @@ namespace NSQLComplete { struct TFolderEntry { + static constexpr const char* Folder = "Folder"; + static constexpr const char* Table = "Table"; + TString Type; TString Name; diff --git a/yql/essentials/sql/v1/complete/name/object/simple/schema_gateway.cpp b/yql/essentials/sql/v1/complete/name/object/simple/schema_gateway.cpp new file mode 100644 index 00000000000..e8e7bf3ccd9 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/object/simple/schema_gateway.cpp @@ -0,0 +1,74 @@ +#include "schema_gateway.h" + +#include <util/charset/utf8.h> + +namespace NSQLComplete { + + namespace { + + class TSimpleSchemaGateway: public ISchemaGateway { + private: + static auto FilterByName(TString name) { + return [name = std::move(name)](auto f) { + TVector<TFolderEntry> entries = f.ExtractValue(); + EraseIf(entries, [prefix = ToLowerUTF8(name)](const TFolderEntry& entry) { + return !entry.Name.StartsWith(prefix); + }); + return entries; + }; + } + + static auto FilterByTypes(TMaybe<THashSet<TString>> types) { + return [types = std::move(types)](auto f) { + TVector<TFolderEntry> entries = f.ExtractValue(); + EraseIf(entries, [types = std::move(types)](const TFolderEntry& entry) { + return types && !types->contains(entry.Type); + }); + return entries; + }; + } + + static auto Crop(size_t limit) { + return [limit](auto f) { + TVector<TFolderEntry> entries = f.ExtractValue(); + entries.crop(limit); + return entries; + }; + } + + static auto ToResponse(TStringBuf name) { + const auto length = name.length(); + return [length](auto f) { + return TListResponse{ + .NameHintLength = length, + .Entries = f.ExtractValue(), + }; + }; + } + + public: + explicit TSimpleSchemaGateway(ISimpleSchemaGateway::TPtr simple) + : Simple_(std::move(simple)) + { + } + + NThreading::TFuture<TListResponse> List(const TListRequest& request) const override { + auto [path, name] = Simple_->Split(request.Path); + return Simple_->List(TString(path)) + .Apply(FilterByName(TString(name))) + .Apply(FilterByTypes(std::move(request.Filter.Types))) + .Apply(Crop(request.Limit)) + .Apply(ToResponse(name)); + } + + private: + ISimpleSchemaGateway::TPtr Simple_; + }; + + } // namespace + + ISchemaGateway::TPtr MakeSimpleSchemaGateway(ISimpleSchemaGateway::TPtr simple) { + return ISchemaGateway::TPtr(new TSimpleSchemaGateway(std::move(simple))); + } + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/object/simple/schema_gateway.h b/yql/essentials/sql/v1/complete/name/object/simple/schema_gateway.h new file mode 100644 index 00000000000..4b4671f1cca --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/object/simple/schema_gateway.h @@ -0,0 +1,23 @@ +#pragma once + +#include <yql/essentials/sql/v1/complete/name/object/schema_gateway.h> + +namespace NSQLComplete { + + struct TSplittedPath { + TStringBuf Path; + TStringBuf NameHint; + }; + + class ISimpleSchemaGateway: public TThrRefBase { + public: + using TPtr = TIntrusivePtr<ISimpleSchemaGateway>; + + virtual ~ISimpleSchemaGateway() = default; + virtual TSplittedPath Split(TStringBuf path) const = 0; + virtual NThreading::TFuture<TVector<TFolderEntry>> List(TString folder) const = 0; + }; + + ISchemaGateway::TPtr MakeSimpleSchemaGateway(ISimpleSchemaGateway::TPtr simple); + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/object/simple/ya.make b/yql/essentials/sql/v1/complete/name/object/simple/ya.make new file mode 100644 index 00000000000..d3668fdb1fc --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/object/simple/ya.make @@ -0,0 +1,11 @@ +LIBRARY() + +SRCS( + schema_gateway.cpp +) + +PEERDIR( + yql/essentials/sql/v1/complete/name/object +) + +END() diff --git a/yql/essentials/sql/v1/complete/name/object/static/schema_gateway.cpp b/yql/essentials/sql/v1/complete/name/object/static/schema_gateway.cpp index e325b908f7a..f43af57c752 100644 --- a/yql/essentials/sql/v1/complete/name/object/static/schema_gateway.cpp +++ b/yql/essentials/sql/v1/complete/name/object/static/schema_gateway.cpp @@ -1,6 +1,6 @@ #include "schema_gateway.h" -#include <yql/essentials/sql/v1/complete/text/case.h> +#include <yql/essentials/sql/v1/complete/name/object/simple/schema_gateway.h> #include <util/charset/utf8.h> @@ -8,11 +8,9 @@ namespace NSQLComplete { namespace { - class TSchemaGateway: public ISchemaGateway { - static constexpr size_t MaxLimit = 4 * 1024; - + class TSimpleSchemaGateway: public ISimpleSchemaGateway { public: - explicit TSchemaGateway(THashMap<TString, TVector<TFolderEntry>> data) + explicit TSimpleSchemaGateway(THashMap<TString, TVector<TFolderEntry>> data) : Data_(std::move(data)) { for (const auto& [k, _] : Data_) { @@ -21,35 +19,7 @@ namespace NSQLComplete { } } - NThreading::TFuture<TListResponse> List(const TListRequest& request) const override { - auto [path, prefix] = ParsePath(request.Path); - - TVector<TFolderEntry> entries; - if (const auto* data = Data_.FindPtr(path)) { - entries = *data; - } - - EraseIf(entries, [prefix = ToLowerUTF8(prefix)](const TFolderEntry& entry) { - return !entry.Name.StartsWith(prefix); - }); - - EraseIf(entries, [types = std::move(request.Filter.Types)](const TFolderEntry& entry) { - return types && !types->contains(entry.Type); - }); - - Y_ENSURE(request.Limit <= MaxLimit); - entries.crop(request.Limit); - - TListResponse response = { - .NameHintLength = prefix.size(), - .Entries = std::move(entries), - }; - - return NThreading::MakeFuture(std::move(response)); - } - - private: - static std::tuple<TStringBuf, TStringBuf> ParsePath(TString path Y_LIFETIME_BOUND) { + TSplittedPath Split(TStringBuf path) const override { size_t pos = path.find_last_of('/'); if (pos == TString::npos) { return {"", path}; @@ -60,13 +30,24 @@ namespace NSQLComplete { return {head, tail}; } + NThreading::TFuture<TVector<TFolderEntry>> List(TString folder) const override { + TVector<TFolderEntry> entries; + if (const auto* data = Data_.FindPtr(folder)) { + entries = *data; + } + return NThreading::MakeFuture(std::move(entries)); + } + + private: THashMap<TString, TVector<TFolderEntry>> Data_; }; } // namespace ISchemaGateway::TPtr MakeStaticSchemaGateway(THashMap<TString, TVector<TFolderEntry>> fs) { - return MakeIntrusive<TSchemaGateway>(std::move(fs)); + return MakeSimpleSchemaGateway( + ISimpleSchemaGateway::TPtr( + new TSimpleSchemaGateway(std::move(fs)))); } } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/object/static/ya.make b/yql/essentials/sql/v1/complete/name/object/static/ya.make index d3dd658468c..d37495f0700 100644 --- a/yql/essentials/sql/v1/complete/name/object/static/ya.make +++ b/yql/essentials/sql/v1/complete/name/object/static/ya.make @@ -6,6 +6,7 @@ SRCS( PEERDIR( yql/essentials/sql/v1/complete/name/object + yql/essentials/sql/v1/complete/name/object/simple ) END() diff --git a/yql/essentials/sql/v1/complete/name/object/ya.make b/yql/essentials/sql/v1/complete/name/object/ya.make index 1254e60dc6c..483f11c9a59 100644 --- a/yql/essentials/sql/v1/complete/name/object/ya.make +++ b/yql/essentials/sql/v1/complete/name/object/ya.make @@ -11,5 +11,6 @@ PEERDIR( END() RECURSE( + simple static ) diff --git a/yql/essentials/sql/v1/complete/name/service/name_service.cpp b/yql/essentials/sql/v1/complete/name/service/name_service.cpp new file mode 100644 index 00000000000..88473a60bff --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/service/name_service.cpp @@ -0,0 +1,81 @@ +#include "name_service.h" + +#include <yql/essentials/core/sql_types/normalize_name.h> + +#include <util/charset/utf8.h> + +namespace NSQLComplete { + + namespace { + + void SetPrefix(TString& name, const TStringBuf delimeter, const TNamespaced& namespaced) { + if (namespaced.Namespace.empty()) { + return; + } + + name.prepend(delimeter); + name.prepend(namespaced.Namespace); + } + + void FixPrefix(TString& name, const TStringBuf delimeter, const TNamespaced& namespaced) { + if (namespaced.Namespace.empty()) { + return; + } + + name.remove(0, namespaced.Namespace.size() + delimeter.size()); + } + + } // namespace + + TGenericName TNameConstraints::Qualified(TGenericName unqualified) const { + return std::visit([&](auto&& name) -> TGenericName { + using T = std::decay_t<decltype(name)>; + if constexpr (std::is_same_v<T, TPragmaName>) { + SetPrefix(name.Indentifier, ".", *Pragma); + } else if constexpr (std::is_same_v<T, TFunctionName>) { + SetPrefix(name.Indentifier, "::", *Function); + } + return name; + }, std::move(unqualified)); + } + + TGenericName TNameConstraints::Unqualified(TGenericName qualified) const { + return std::visit([&](auto&& name) -> TGenericName { + using T = std::decay_t<decltype(name)>; + if constexpr (std::is_same_v<T, TPragmaName>) { + FixPrefix(name.Indentifier, ".", *Pragma); + } else if constexpr (std::is_same_v<T, TFunctionName>) { + FixPrefix(name.Indentifier, "::", *Function); + } + return name; + }, std::move(qualified)); + } + + TVector<TGenericName> TNameConstraints::Qualified(TVector<TGenericName> unqualified) const { + for (auto& name : unqualified) { + name = Qualified(std::move(name)); + } + return unqualified; + } + + TVector<TGenericName> TNameConstraints::Unqualified(TVector<TGenericName> qualified) const { + for (auto& name : qualified) { + name = Unqualified(std::move(name)); + } + return qualified; + } + + TString LowerizeName(TStringBuf name) { + return ToLowerUTF8(name); + } + + TString NormalizeName(TStringBuf name) { + TString normalized(name); + TMaybe<NYql::TIssue> error = NYql::NormalizeName(NYql::TPosition(), normalized); + if (!error.Empty()) { + return LowerizeName(name); + } + return normalized; + } + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/service/name_service.h b/yql/essentials/sql/v1/complete/name/service/name_service.h index d71b2518bd0..7d773582b61 100644 --- a/yql/essentials/sql/v1/complete/name/service/name_service.h +++ b/yql/essentials/sql/v1/complete/name/service/name_service.h @@ -10,7 +10,7 @@ namespace NSQLComplete { - using NThreading::TFuture; + using NThreading::TFuture; // TODO(YQL-19747): remove struct TIndentifier { TString Indentifier; @@ -49,14 +49,21 @@ namespace NSQLComplete { TFunctionName, THintName>; + struct TNameConstraints { + TMaybe<TPragmaName::TConstraints> Pragma; + TMaybe<TTypeName::TConstraints> Type; + TMaybe<TFunctionName::TConstraints> Function; + TMaybe<THintName::TConstraints> Hint; + + TGenericName Qualified(TGenericName unqualified) const; + TGenericName Unqualified(TGenericName qualified) const; + TVector<TGenericName> Qualified(TVector<TGenericName> unqualified) const; + TVector<TGenericName> Unqualified(TVector<TGenericName> qualified) const; + }; + struct TNameRequest { TVector<TString> Keywords; - struct { - TMaybe<TPragmaName::TConstraints> Pragma; - TMaybe<TTypeName::TConstraints> Type; - TMaybe<TFunctionName::TConstraints> Function; - TMaybe<THintName::TConstraints> Hint; - } Constraints; + TNameConstraints Constraints; TString Prefix = ""; size_t Limit = 128; @@ -81,4 +88,6 @@ namespace NSQLComplete { virtual ~INameService() = default; }; + TString NormalizeName(TStringBuf name); + } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/service/ranking/frequency.cpp b/yql/essentials/sql/v1/complete/name/service/ranking/frequency.cpp index a9a99e27097..5b4d6088bf5 100644 --- a/yql/essentials/sql/v1/complete/name/service/ranking/frequency.cpp +++ b/yql/essentials/sql/v1/complete/name/service/ranking/frequency.cpp @@ -1,6 +1,6 @@ #include "frequency.h" -#include <yql/essentials/core/sql_types/normalize_name.h> +#include <yql/essentials/sql/v1/complete/name/service/name_service.h> #include <library/cpp/json/json_reader.h> #include <library/cpp/resource/resource.h> @@ -100,7 +100,7 @@ namespace NSQLComplete { TFrequencyData Pruned(const TFrequencyData& data) { return PrunedBy(data, [](TStringBuf s) { - return NYql::NormalizeName(s); + return NormalizeName(s); }); } diff --git a/yql/essentials/sql/v1/complete/name/service/ranking/ranking.cpp b/yql/essentials/sql/v1/complete/name/service/ranking/ranking.cpp index 901b841d1b5..3e2dd322522 100644 --- a/yql/essentials/sql/v1/complete/name/service/ranking/ranking.cpp +++ b/yql/essentials/sql/v1/complete/name/service/ranking/ranking.cpp @@ -2,8 +2,6 @@ #include <yql/essentials/sql/v1/complete/name/service/name_service.h> -#include <yql/essentials/core/sql_types/normalize_name.h> - #include <util/charset/utf8.h> namespace NSQLComplete { @@ -21,12 +19,16 @@ namespace NSQLComplete { { } - void CropToSortedPrefix(TVector<TGenericName>& names, size_t limit) const override { + void CropToSortedPrefix( + TVector<TGenericName>& names, + const TNameConstraints& constraints, + size_t limit) const override { limit = std::min(limit, names.size()); TVector<TRow> rows; rows.reserve(names.size()); for (TGenericName& name : names) { + name = constraints.Qualified(std::move(name)); size_t weight = Weight(name); rows.emplace_back(std::move(name), weight); } @@ -48,7 +50,7 @@ namespace NSQLComplete { rows.crop(limit); for (size_t i = 0; i < limit; ++i) { - names[i] = std::move(rows[i].Name); + names[i] = constraints.Unqualified(std::move(rows[i].Name)); } } @@ -57,7 +59,7 @@ namespace NSQLComplete { return std::visit([this](const auto& name) -> size_t { using T = std::decay_t<decltype(name)>; - auto content = NYql::NormalizeName(ContentView(name)); + auto content = NormalizeName(ContentView(name)); if constexpr (std::is_same_v<T, TKeyword>) { if (auto weight = Frequency_.Keywords.FindPtr(content)) { diff --git a/yql/essentials/sql/v1/complete/name/service/ranking/ranking.h b/yql/essentials/sql/v1/complete/name/service/ranking/ranking.h index 269f46d2028..ac6329bbaf9 100644 --- a/yql/essentials/sql/v1/complete/name/service/ranking/ranking.h +++ b/yql/essentials/sql/v1/complete/name/service/ranking/ranking.h @@ -11,7 +11,10 @@ namespace NSQLComplete { public: using TPtr = TIntrusivePtr<IRanking>; - virtual void CropToSortedPrefix(TVector<TGenericName>& names, size_t limit) const = 0; + virtual void CropToSortedPrefix( + TVector<TGenericName>& names, + const TNameConstraints& constraints, + size_t limit) const = 0; virtual ~IRanking() = default; }; diff --git a/yql/essentials/sql/v1/complete/name/service/ranking/ya.make b/yql/essentials/sql/v1/complete/name/service/ranking/ya.make index 4b139666fea..56e8e782128 100644 --- a/yql/essentials/sql/v1/complete/name/service/ranking/ya.make +++ b/yql/essentials/sql/v1/complete/name/service/ranking/ya.make @@ -6,7 +6,6 @@ SRCS( ) PEERDIR( - yql/essentials/core/sql_types yql/essentials/sql/v1/complete/name/service ) diff --git a/yql/essentials/sql/v1/complete/name/service/static/name_index.cpp b/yql/essentials/sql/v1/complete/name/service/static/name_index.cpp deleted file mode 100644 index bfbf6af7fb4..00000000000 --- a/yql/essentials/sql/v1/complete/name/service/static/name_index.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "name_index.h" - -#include <yql/essentials/core/sql_types/normalize_name.h> - -#include <util/charset/utf8.h> - -namespace NSQLComplete { - - TString NormalizeName(const TString& name) { - return NYql::NormalizeName(name); - } - - TString LowerizeName(const TString& name) { - return ToLowerUTF8(name); - } - - TString UnchangedName(const TString& name) { - return name; - } - -} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/service/static/name_index.h b/yql/essentials/sql/v1/complete/name/service/static/name_index.h index 77b50238846..840a4a5fe44 100644 --- a/yql/essentials/sql/v1/complete/name/service/static/name_index.h +++ b/yql/essentials/sql/v1/complete/name/service/static/name_index.h @@ -39,10 +39,4 @@ namespace NSQLComplete { return index; } - TString NormalizeName(const TString& name); - - TString LowerizeName(const TString& name); - - TString UnchangedName(const TString& name); - } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/service/static/name_service.cpp b/yql/essentials/sql/v1/complete/name/service/static/name_service.cpp index bb8d0840a93..976646dd7f4 100644 --- a/yql/essentials/sql/v1/complete/name/service/static/name_service.cpp +++ b/yql/essentials/sql/v1/complete/name/service/static/name_service.cpp @@ -3,6 +3,7 @@ #include "name_index.h" #include <yql/essentials/sql/v1/complete/name/service/ranking/ranking.h> +#include <yql/essentials/sql/v1/complete/name/service/union/name_service.h> #include <yql/essentials/sql/v1/complete/text/case.h> namespace NSQLComplete { @@ -40,113 +41,188 @@ namespace NSQLComplete { } } - TString Prefixed(const TStringBuf requestPrefix, const TStringBuf delimeter, const TNamespaced& namespaced) { - TString prefix; - if (!namespaced.Namespace.empty()) { - prefix += namespaced.Namespace; - prefix += delimeter; - } - prefix += requestPrefix; - return prefix; + template <class T> + void NameIndexScan( + const TNameIndex& index, + const TString& prefix, + const TNameConstraints& constraints, + TVector<TGenericName>& out) { + T name; + name.Indentifier = prefix; + name = std::get<T>(constraints.Qualified(std::move(name))); + + AppendAs<T>(out, FilteredByPrefix(name.Indentifier, index)); + out = constraints.Unqualified(std::move(out)); } - void FixPrefix(TString& name, const TStringBuf delimeter, const TNamespaced& namespaced) { - if (namespaced.Namespace.empty()) { - return; + class IRankingNameService: public INameService { + private: + auto Ranking(TNameRequest request) const { + return [request = std::move(request), this](auto f) { + TNameResponse response = f.ExtractValue(); + Ranking_->CropToSortedPrefix( + response.RankedNames, + request.Constraints, + request.Limit); + return response; + }; } - name.remove(0, namespaced.Namespace.size() + delimeter.size()); - } - void FixPrefix(TGenericName& name, const TNameRequest& request) { - std::visit([&](auto& name) -> size_t { - using T = std::decay_t<decltype(name)>; - if constexpr (std::is_same_v<T, TPragmaName>) { - FixPrefix(name.Indentifier, ".", *request.Constraints.Pragma); - } - if constexpr (std::is_same_v<T, TFunctionName>) { - FixPrefix(name.Indentifier, "::", *request.Constraints.Function); - } - return 0; - }, name); - } + public: + explicit IRankingNameService(IRanking::TPtr ranking) + : Ranking_(std::move(ranking)) + { + } + + NThreading::TFuture<TNameResponse> Lookup(TNameRequest request) const override { + return LookupAllUnranked(request).Apply(Ranking(request)); + } + + virtual NThreading::TFuture<TNameResponse> LookupAllUnranked(TNameRequest request) const = 0; + + private: + IRanking::TPtr Ranking_; + }; - class TStaticNameService: public INameService { + class TKeywordNameService: public IRankingNameService { public: - explicit TStaticNameService(TNameSet names, IRanking::TPtr ranking) - : Pragmas_(BuildNameIndex(std::move(names.Pragmas), NormalizeName)) - , Types_(BuildNameIndex(std::move(names.Types), NormalizeName)) - , Functions_(BuildNameIndex(std::move(names.Functions), NormalizeName)) - , Hints_([hints = std::move(names.Hints)] { - THashMap<EStatementKind, TNameIndex> index; - for (auto& [k, hints] : hints) { - index.emplace(k, BuildNameIndex(std::move(hints), NormalizeName)); - } - return index; - }()) - , Ranking_(std::move(ranking)) + explicit TKeywordNameService(IRanking::TPtr ranking) + : IRankingNameService(std::move(ranking)) { } - TFuture<TNameResponse> Lookup(TNameRequest request) const override { + NThreading::TFuture<TNameResponse> LookupAllUnranked(TNameRequest request) const override { TNameResponse response; - Sort(request.Keywords, NoCaseCompare); AppendAs<TKeyword>( response.RankedNames, FilteredByPrefix(request.Prefix, request.Keywords)); + return NThreading::MakeFuture<TNameResponse>(std::move(response)); + } + }; + class TPragmaNameService: public IRankingNameService { + public: + explicit TPragmaNameService(IRanking::TPtr ranking, TVector<TString> pragmas) + : IRankingNameService(std::move(ranking)) + , Pragmas_(BuildNameIndex(std::move(pragmas), NormalizeName)) + { + } + + NThreading::TFuture<TNameResponse> LookupAllUnranked(TNameRequest request) const override { + TNameResponse response; if (request.Constraints.Pragma) { - auto prefix = Prefixed(request.Prefix, ".", *request.Constraints.Pragma); - auto names = FilteredByPrefix(prefix, Pragmas_); - AppendAs<TPragmaName>(response.RankedNames, names); + NameIndexScan<TPragmaName>( + Pragmas_, + request.Prefix, + request.Constraints, + response.RankedNames); } + return NThreading::MakeFuture<TNameResponse>(std::move(response)); + } + + private: + TNameIndex Pragmas_; + }; + + class TTypeNameService: public IRankingNameService { + public: + explicit TTypeNameService(IRanking::TPtr ranking, TVector<TString> types) + : IRankingNameService(std::move(ranking)) + , Types_(BuildNameIndex(std::move(types), NormalizeName)) + { + } + NThreading::TFuture<TNameResponse> LookupAllUnranked(TNameRequest request) const override { + TNameResponse response; if (request.Constraints.Type) { - AppendAs<TTypeName>( - response.RankedNames, - FilteredByPrefix(request.Prefix, Types_)); + NameIndexScan<TTypeName>( + Types_, + request.Prefix, + request.Constraints, + response.RankedNames); } + return NThreading::MakeFuture<TNameResponse>(std::move(response)); + } + + private: + TNameIndex Types_; + }; + + class TFunctionNameService: public IRankingNameService { + public: + explicit TFunctionNameService(IRanking::TPtr ranking, TVector<TString> functions) + : IRankingNameService(std::move(ranking)) + , Functions_(BuildNameIndex(std::move(functions), NormalizeName)) + { + } + NThreading::TFuture<TNameResponse> LookupAllUnranked(TNameRequest request) const override { + TNameResponse response; if (request.Constraints.Function) { - auto prefix = Prefixed(request.Prefix, "::", *request.Constraints.Function); - auto names = FilteredByPrefix(prefix, Functions_); - AppendAs<TFunctionName>(response.RankedNames, names); + NameIndexScan<TFunctionName>( + Functions_, + request.Prefix, + request.Constraints, + response.RankedNames); } + return NThreading::MakeFuture<TNameResponse>(std::move(response)); + } + private: + TNameIndex Functions_; + }; + + class THintNameService: public IRankingNameService { + public: + explicit THintNameService( + IRanking::TPtr ranking, + THashMap<EStatementKind, TVector<TString>> hints) + : IRankingNameService(std::move(ranking)) + , Hints_([hints = std::move(hints)] { + THashMap<EStatementKind, TNameIndex> index; + for (auto& [k, hints] : hints) { + index.emplace(k, BuildNameIndex(std::move(hints), NormalizeName)); + } + return index; + }()) + { + } + + NThreading::TFuture<TNameResponse> LookupAllUnranked(TNameRequest request) const override { + TNameResponse response; if (request.Constraints.Hint) { const auto stmt = request.Constraints.Hint->Statement; if (const auto* hints = Hints_.FindPtr(stmt)) { - AppendAs<THintName>( - response.RankedNames, - FilteredByPrefix(request.Prefix, *hints)); + NameIndexScan<THintName>( + *hints, + request.Prefix, + request.Constraints, + response.RankedNames); } } - - Ranking_->CropToSortedPrefix(response.RankedNames, request.Limit); - - for (auto& name : response.RankedNames) { - FixPrefix(name, request); - } - - return NThreading::MakeFuture(std::move(response)); + return NThreading::MakeFuture<TNameResponse>(std::move(response)); } private: - TNameIndex Pragmas_; - TNameIndex Types_; - TNameIndex Functions_; THashMap<EStatementKind, TNameIndex> Hints_; - IRanking::TPtr Ranking_; }; INameService::TPtr MakeStaticNameService(TNameSet names, TFrequencyData frequency) { - return INameService::TPtr(new TStaticNameService( + return MakeStaticNameService( Pruned(std::move(names), frequency), - MakeDefaultRanking(std::move(frequency)))); + MakeDefaultRanking(std::move(frequency))); } INameService::TPtr MakeStaticNameService(TNameSet names, IRanking::TPtr ranking) { - return MakeIntrusive<TStaticNameService>(std::move(names), std::move(ranking)); + TVector<INameService::TPtr> children = { + new TKeywordNameService(ranking), + new TPragmaNameService(ranking, std::move(names.Pragmas)), + new TTypeNameService(ranking, std::move(names.Types)), + new TFunctionNameService(ranking, std::move(names.Functions)), + new THintNameService(ranking, std::move(names.Hints)), + }; + return MakeUnionNameService(std::move(children), ranking); } } // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/service/static/name_set.cpp b/yql/essentials/sql/v1/complete/name/service/static/name_set.cpp index eb01c78420c..41b2feef7e8 100644 --- a/yql/essentials/sql/v1/complete/name/service/static/name_set.cpp +++ b/yql/essentials/sql/v1/complete/name/service/static/name_set.cpp @@ -2,6 +2,8 @@ #include "name_index.h" +#include <yql/essentials/sql/v1/complete/name/service/name_service.h> + namespace NSQLComplete { TVector<TString> Pruned(TVector<TString> names, const THashMap<TString, size_t>& frequency) { diff --git a/yql/essentials/sql/v1/complete/name/service/static/name_set_json.cpp b/yql/essentials/sql/v1/complete/name/service/static/name_set_json.cpp index fc802b1d886..ba80d5d4c36 100644 --- a/yql/essentials/sql/v1/complete/name/service/static/name_set_json.cpp +++ b/yql/essentials/sql/v1/complete/name/service/static/name_set_json.cpp @@ -1,4 +1,5 @@ #include "name_set.h" +#include "name_set_json.h" #include <yql/essentials/sql/v1/complete/name/service/name_service.h> @@ -15,12 +16,6 @@ namespace NSQLComplete { return NJson::ReadJsonFastTree(text); } - template <class T, class U> - T Merge(T lhs, U rhs) { - std::copy(std::begin(rhs), std::end(rhs), std::back_inserter(lhs)); - return lhs; - } - TVector<TString> ParseNames(NJson::TJsonValue::TArray& json) { TVector<TString> keys; keys.reserve(json.size()); diff --git a/yql/essentials/sql/v1/complete/name/service/static/name_set_json.h b/yql/essentials/sql/v1/complete/name/service/static/name_set_json.h new file mode 100644 index 00000000000..03d5083c432 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/service/static/name_set_json.h @@ -0,0 +1,31 @@ +#pragma once + +#include <yql/essentials/sql/v1/complete/core/statement.h> + +#include <library/cpp/json/json_value.h> + +#include <util/generic/string.h> +#include <util/generic/vector.h> +#include <util/generic/hash.h> + +namespace NSQLComplete { + + NJson::TJsonValue LoadJsonResource(const TStringBuf filename); + + template <class T, class U> + T Merge(T lhs, U rhs) { + std::copy(std::begin(rhs), std::end(rhs), std::back_inserter(lhs)); + return lhs; + } + + TVector<TString> ParsePragmas(NJson::TJsonValue json); + + TVector<TString> ParseTypes(NJson::TJsonValue json); + + TVector<TString> ParseFunctions(NJson::TJsonValue json); + + TVector<TString> ParseUdfs(NJson::TJsonValue json); + + THashMap<EStatementKind, TVector<TString>> ParseHints(NJson::TJsonValue json); + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/service/static/ya.make b/yql/essentials/sql/v1/complete/name/service/static/ya.make index 303916474ba..317ef4ed84a 100644 --- a/yql/essentials/sql/v1/complete/name/service/static/ya.make +++ b/yql/essentials/sql/v1/complete/name/service/static/ya.make @@ -1,16 +1,15 @@ LIBRARY() SRCS( + name_service.cpp name_set_json.cpp name_set.cpp - name_index.cpp - name_service.cpp ) PEERDIR( - yql/essentials/core/sql_types yql/essentials/sql/v1/complete/name/service yql/essentials/sql/v1/complete/name/service/ranking + yql/essentials/sql/v1/complete/name/service/union yql/essentials/sql/v1/complete/text ) diff --git a/yql/essentials/sql/v1/complete/name/service/union/name_service.cpp b/yql/essentials/sql/v1/complete/name/service/union/name_service.cpp new file mode 100644 index 00000000000..c2373822f6f --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/service/union/name_service.cpp @@ -0,0 +1,60 @@ +#include "name_service.h" + +#include <library/cpp/threading/future/wait/wait.h> + +namespace NSQLComplete { + + namespace { + + class TNameService: public INameService { + public: + TNameService( + TVector<INameService::TPtr> children, + IRanking::TPtr ranking) + : Children_(std::move(children)) + , Ranking_(std::move(ranking)) + { + } + + NThreading::TFuture<TNameResponse> Lookup(TNameRequest request) const override { + TVector<NThreading::TFuture<TNameResponse>> fs; + for (const auto& c : Children_) { + fs.emplace_back(c->Lookup(request)); + } + return NThreading::WaitAll(fs) + .Apply([fs, this, request = std::move(request)](auto) { + return Union(fs, request.Constraints, request.Limit); + }); + } + + private: + TNameResponse Union( + TVector<NThreading::TFuture<TNameResponse>> fs, + const TNameConstraints& constraints, + size_t limit) const { + TNameResponse united; + for (auto f : fs) { + TNameResponse response = f.ExtractValue(); + std::ranges::move( + response.RankedNames, + std::back_inserter(united.RankedNames)); + } + Ranking_->CropToSortedPrefix(united.RankedNames, constraints, limit); + return united; + } + + TVector<INameService::TPtr> Children_; + IRanking::TPtr Ranking_; + }; + + } // namespace + + INameService::TPtr MakeUnionNameService( + TVector<INameService::TPtr> children, + IRanking::TPtr ranking) { + return INameService::TPtr(new TNameService( + std::move(children), + std::move(ranking))); + } + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/service/union/name_service.h b/yql/essentials/sql/v1/complete/name/service/union/name_service.h new file mode 100644 index 00000000000..a9cd292b39b --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/service/union/name_service.h @@ -0,0 +1,12 @@ +#pragma once + +#include <yql/essentials/sql/v1/complete/name/service/ranking/ranking.h> +#include <yql/essentials/sql/v1/complete/name/service/name_service.h> + +namespace NSQLComplete { + + INameService::TPtr MakeUnionNameService( + TVector<INameService::TPtr> children, + IRanking::TPtr ranking); + +} // namespace NSQLComplete diff --git a/yql/essentials/sql/v1/complete/name/service/union/ya.make b/yql/essentials/sql/v1/complete/name/service/union/ya.make new file mode 100644 index 00000000000..6716ee6ab20 --- /dev/null +++ b/yql/essentials/sql/v1/complete/name/service/union/ya.make @@ -0,0 +1,12 @@ +LIBRARY() + +SRCS( + name_service.cpp +) + +PEERDIR( + yql/essentials/sql/v1/complete/name/service + yql/essentials/sql/v1/complete/name/service/ranking +) + +END() diff --git a/yql/essentials/sql/v1/complete/name/service/ya.make b/yql/essentials/sql/v1/complete/name/service/ya.make index 473ee05f7d8..1f1af9055ae 100644 --- a/yql/essentials/sql/v1/complete/name/service/ya.make +++ b/yql/essentials/sql/v1/complete/name/service/ya.make @@ -1,6 +1,11 @@ LIBRARY() +SRCS( + name_service.cpp +) + PEERDIR( + yql/essentials/core/sql_types yql/essentials/sql/v1/complete/core ) @@ -9,4 +14,5 @@ END() RECURSE( ranking static + union ) diff --git a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp index 54bcc35233f..060dfd42add 100644 --- a/yql/essentials/sql/v1/complete/sql_complete_ut.cpp +++ b/yql/essentials/sql/v1/complete/sql_complete_ut.cpp @@ -52,7 +52,11 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { TNameSet names = { .Pragmas = {"yson.CastToString"}, .Types = {"Uint64"}, - .Functions = {"StartsWith", "DateTime::Split"}, + .Functions = { + "StartsWith", + "DateTime::Split", + "Python::__private", + }, .Hints = { {EStatementKind::Select, {"XLOCK"}}, {EStatementKind::Insert, {"EXPIRATION"}}, @@ -303,6 +307,13 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { { TVector<TCandidate> expected = { {PragmaName, "yson.CastToString"}}; + auto completion = engine->CompleteAsync({"PRAGMA ys"}).GetValueSync(); + UNIT_ASSERT_VALUES_EQUAL(completion.Candidates, expected); + UNIT_ASSERT_VALUES_EQUAL(completion.CompletedToken.Content, "ys"); + } + { + TVector<TCandidate> expected = { + {PragmaName, "yson.CastToString"}}; auto completion = engine->CompleteAsync({"PRAGMA yson"}).GetValueSync(); UNIT_ASSERT_VALUES_EQUAL(completion.Candidates, expected); UNIT_ASSERT_VALUES_EQUAL(completion.CompletedToken.Content, "yson"); @@ -348,6 +359,7 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { {Keyword, "NOT"}, {Keyword, "NULL"}, {Keyword, "OPTIONAL<"}, + {FunctionName, "Python::__private("}, {Keyword, "RESOURCE<"}, {Keyword, "SET<"}, {Keyword, "STREAM"}, @@ -407,6 +419,7 @@ Y_UNIT_TEST_SUITE(SqlCompleteTests) { {Keyword, "NOT"}, {Keyword, "NULL"}, {Keyword, "OPTIONAL<"}, + {FunctionName, "Python::__private("}, {Keyword, "RESOURCE<"}, {Keyword, "SET<"}, {Keyword, "STREAM<"}, diff --git a/yql/essentials/sql/v1/complete/ut/ya.make b/yql/essentials/sql/v1/complete/ut/ya.make index 4c50124cf7c..fbb84f56f25 100644 --- a/yql/essentials/sql/v1/complete/ut/ya.make +++ b/yql/essentials/sql/v1/complete/ut/ya.make @@ -7,6 +7,7 @@ SRCS( PEERDIR( yql/essentials/sql/v1/lexer/antlr4_pure yql/essentials/sql/v1/lexer/antlr4_pure_ansi + yql/essentials/sql/v1/complete/name/service/static ) END() diff --git a/yql/essentials/sql/v1/complete/ya.make b/yql/essentials/sql/v1/complete/ya.make index 4e2d02b7a97..010c8ac5243 100644 --- a/yql/essentials/sql/v1/complete/ya.make +++ b/yql/essentials/sql/v1/complete/ya.make @@ -8,9 +8,11 @@ PEERDIR( yql/essentials/sql/v1/lexer yql/essentials/sql/v1/complete/antlr4 yql/essentials/sql/v1/complete/name/service - yql/essentials/sql/v1/complete/name/service/static yql/essentials/sql/v1/complete/syntax yql/essentials/sql/v1/complete/text + + # TODO(YQL-19747): add it to YDB CLI PEERDIR + yql/essentials/sql/v1/complete/name/service/static ) END() diff --git a/yql/essentials/tests/sql/minirun/part8/canondata/result.json b/yql/essentials/tests/sql/minirun/part8/canondata/result.json index 594491392db..da62996a222 100644 --- a/yql/essentials/tests/sql/minirun/part8/canondata/result.json +++ b/yql/essentials/tests/sql/minirun/part8/canondata/result.json @@ -929,6 +929,20 @@ "uri": "https://{canondata_backend}/1600758/e19ffc8677d8f7ce11076554c3082ee5be112fdb/resource.tar.gz#test.test_flexible_types-group_by-default.txt-Results_/results.txt" } ], + "test.test[in-YQL-18950-default.txt-Debug]": [ + { + "checksum": "409bea545c24e6f2a706f68f61946b36", + "size": 392, + "uri": "https://{canondata_backend}/1871182/8488be1009a783ec149801679b1d381d33cbeb2f/resource.tar.gz#test.test_in-YQL-18950-default.txt-Debug_/opt.yql" + } + ], + "test.test[in-YQL-18950-default.txt-Results]": [ + { + "checksum": "cd3fe11c994d6fd2fd90ed01fed17699", + "size": 1198, + "uri": "https://{canondata_backend}/1871182/8488be1009a783ec149801679b1d381d33cbeb2f/resource.tar.gz#test.test_in-YQL-18950-default.txt-Results_/results.txt" + } + ], "test.test[join-yql-19731-default.txt-Debug]": [ { "checksum": "176315e3d36000d21b5d5b939996e7f4", diff --git a/yql/essentials/tests/sql/sql2yql/canondata/result.json b/yql/essentials/tests/sql/sql2yql/canondata/result.json index 7bb69811dcd..85753fb74b3 100644 --- a/yql/essentials/tests/sql/sql2yql/canondata/result.json +++ b/yql/essentials/tests/sql/sql2yql/canondata/result.json @@ -3688,6 +3688,13 @@ "uri": "https://{canondata_backend}/1942173/99e88108149e222741552e7e6cddef041d6a2846/resource.tar.gz#test_sql2yql.test_flexible_types-with_typeof_/sql.yql" } ], + "test_sql2yql.test[in-YQL-18950]": [ + { + "checksum": "d8bccb706313d9a7c54a46d66d93352b", + "size": 1248, + "uri": "https://{canondata_backend}/1899731/9a58f2b769ebea90a87ae7182a05e62331015e1f/resource.tar.gz#test_sql2yql.test_in-YQL-18950_/sql.yql" + } + ], "test_sql2yql.test[in-in_ansi]": [ { "checksum": "9107459fab676d3b103f131638426169", @@ -10103,6 +10110,11 @@ "uri": "file://test_sql_format.test_flexible_types-with_typeof_/formatted.sql" } ], + "test_sql_format.test[in-YQL-18950]": [ + { + "uri": "file://test_sql_format.test_in-YQL-18950_/formatted.sql" + } + ], "test_sql_format.test[in-in_ansi]": [ { "uri": "file://test_sql_format.test_in-in_ansi_/formatted.sql" diff --git a/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_in-YQL-18950_/formatted.sql b/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_in-YQL-18950_/formatted.sql new file mode 100644 index 00000000000..7ea9768d1c8 --- /dev/null +++ b/yql/essentials/tests/sql/sql2yql/canondata/test_sql_format.test_in-YQL-18950_/formatted.sql @@ -0,0 +1,15 @@ +/* postgres can not */ +PRAGMA AnsiInForEmptyOrNullableItemsCollections; + +$list = AsList( + AsTuple(Just('aa'), Just('aaa')), + AsTuple(Just('bb'), Just('bbb')), +); + +SELECT + ListMap( + $list, ($item) -> { + RETURN 'bb' IN $item; + } + ) +; diff --git a/yql/essentials/tests/sql/suites/in/YQL-18950.sql b/yql/essentials/tests/sql/suites/in/YQL-18950.sql new file mode 100644 index 00000000000..7bf333c1a22 --- /dev/null +++ b/yql/essentials/tests/sql/suites/in/YQL-18950.sql @@ -0,0 +1,10 @@ +/* postgres can not */ + +PRAGMA AnsiInForEmptyOrNullableItemsCollections; + +$list = AsList( + AsTuple(Just("aa"), Just("aaa")), + AsTuple(Just("bb"), Just("bbb")), +); + +SELECT ListMap($list, ($item) -> { RETURN 'bb' IN $item; }); diff --git a/yql/essentials/tools/yql_complete/ya.make b/yql/essentials/tools/yql_complete/ya.make index 21a98628b1e..36d845dadab 100644 --- a/yql/essentials/tools/yql_complete/ya.make +++ b/yql/essentials/tools/yql_complete/ya.make @@ -5,6 +5,8 @@ PROGRAM() PEERDIR( library/cpp/getopt yql/essentials/sql/v1/complete + yql/essentials/sql/v1/complete/name/service/ranking + yql/essentials/sql/v1/complete/name/service/static yql/essentials/sql/v1/lexer/antlr4_pure yql/essentials/sql/v1/lexer/antlr4_pure_ansi yql/essentials/utils diff --git a/yql/tools/yqlrun/http/www/js/mode-sql.js b/yql/tools/yqlrun/http/www/js/mode-sql.js index 8d0a0d12460..81448df47e0 100644 --- a/yql/tools/yqlrun/http/www/js/mode-sql.js +++ b/yql/tools/yqlrun/http/www/js/mode-sql.js @@ -25,7 +25,7 @@ var SqlHighlightRules = function() { "avg|cast|coalesce|likely|random|randomnumber|filecontent|filepath|length|max|median|count|count_if|" + "grouping|min|percentile|sum|min_by|max_by|min_of|max_of|stddev|variance|" + "stddev_sample|stddev_population|variance_sample|variance_population|" + - "bool_and|bool_or|bit_and|bit_or|bit_xor|some|list|unique|sakura|betula|banach|smith|hegel|aristotle|plato|quine|marx|freud|hahn|cedar" + "bool_and|bool_or|bit_and|bit_or|bit_xor|some|list|unique" ); var dataTypes = ( diff --git a/yt/cpp/mapreduce/interface/error_codes.h b/yt/cpp/mapreduce/interface/error_codes.h index a782cb80859..93c60b6da2f 100644 --- a/yt/cpp/mapreduce/interface/error_codes.h +++ b/yt/cpp/mapreduce/interface/error_codes.h @@ -392,9 +392,10 @@ namespace NChunkPools { //////////////////////////////////////////////////////////////////////////////// - constexpr int DataSliceLimitExceeded = 2000; - constexpr int MaxDataWeightPerJobExceeded = 2001; - constexpr int MaxPrimaryDataWeightPerJobExceeded = 2002; + constexpr int DataSliceLimitExceeded = 2000; + constexpr int MaxDataWeightPerJobExceeded = 2001; + constexpr int MaxPrimaryDataWeightPerJobExceeded = 2002; + constexpr int MaxCompressedDataSizePerJobExceeded = 2003; //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/client/table_client/columnar-inl.h b/yt/yt/client/table_client/columnar-inl.h index 64e7a91426e..ddb4f2c7058 100644 --- a/yt/yt/client/table_client/columnar-inl.h +++ b/yt/yt/client/table_client/columnar-inl.h @@ -4,6 +4,8 @@ #include "columnar.h" #endif +#include "helpers.h" + #include <library/cpp/yt/coding/zig_zag.h> namespace NYT::NTableClient { @@ -67,6 +69,7 @@ void DecodeVectorRleImpl( ui64 baseValue, TRange<ui32> dictionaryIndexes, TRange<ui64> rleIndexes, + TRef bitmap, TGetter getter, TConsumer consumer) { @@ -81,17 +84,30 @@ void DecodeVectorRleImpl( break; } decltype(getter(0)) currentValue; + bool isNull = false; if constexpr(WithDictionary) { auto dictionaryIndex = dictionaryIndexes[currentRleIndex]; if (dictionaryIndex == 0) { currentValue = {}; } else { - currentValue = getter(dictionaryIndex - 1); + if (bitmap && GetBit(bitmap, dictionaryIndex - 1)) { + isNull = true; + } else { + currentValue = getter(dictionaryIndex - 1); + } + } } else { - currentValue = getter(currentRleIndex); + if (bitmap && GetBit(bitmap, currentRleIndex)) { + isNull = true; + } else { + currentValue = getter(currentRleIndex); + } + } + currentDecodedValue = {}; + if (!isNull) { + currentDecodedValue = TValueDecoder<WithBaseValue, WithZigZag, T>::Run(currentValue, baseValue); } - currentDecodedValue = TValueDecoder<WithBaseValue, WithZigZag, T>::Run(currentValue, baseValue); ++currentRleIndex; thresholdIndex = currentRleIndex < std::ssize(rleIndexes) ? std::min(static_cast<i64>(rleIndexes[currentRleIndex]), endIndex) @@ -131,22 +147,36 @@ void DecodeVectorDirectImpl( i64 endIndex, ui64 baseValue, TRange<ui32> dictionaryIndexes, + TRef bitmap, TGetter getter, TConsumer consumer) { for (i64 index = startIndex; index < endIndex; ++index) { + bool isNull = false; decltype(getter(0)) value; if constexpr(WithDictionary) { auto dictionaryIndex = dictionaryIndexes[index]; if (dictionaryIndex == 0) { value = {}; } else { - value = getter(dictionaryIndex - 1); + if (bitmap && GetBit(bitmap, dictionaryIndex - 1)) { + isNull = true; + } else { + value = getter(dictionaryIndex - 1); + } } } else { - value = getter(index); + if (bitmap && GetBit(bitmap, index)) { + isNull = true; + } else { + value = getter(index); + } } - auto decodedValue = TValueDecoder<WithBaseValue, WithZigZag, T>::Run(value, baseValue); + decltype(getter(0)) decodedValue = {}; + if (!isNull) { + decodedValue = TValueDecoder<WithBaseValue, WithZigZag, T>::Run(value, baseValue); + } + consumer(decodedValue); } } @@ -172,6 +202,7 @@ void DecodeVectorImpl( ui64 baseValue, TRange<ui32> dictionaryIndexes, TRange<ui64> rleIndexes, + TRef bitmap, TGetter getter, TConsumer consumer) { @@ -182,6 +213,7 @@ void DecodeVectorImpl( baseValue, dictionaryIndexes, rleIndexes, + bitmap, std::forward<TGetter>(getter), std::forward<TConsumer>(consumer)); } else { @@ -190,6 +222,7 @@ void DecodeVectorImpl( endIndex, baseValue, dictionaryIndexes, + bitmap, std::forward<TGetter>(getter), std::forward<TConsumer>(consumer)); } @@ -258,6 +291,7 @@ void DecodeVector( bool zigZagEncoded, TRange<ui32> dictionaryIndexes, TRange<ui64> rleIndexes, + TRef bitmap, TGetter getter, TConsumer consumer) { @@ -271,6 +305,7 @@ void DecodeVector( baseValue, \ dictionaryIndexes, \ rleIndexes, \ + bitmap, \ std::forward<TGetter>(getter), \ std::forward<TConsumer>(consumer)); @@ -324,6 +359,7 @@ void DecodeIntegerVector( bool zigZagEncoded, TRange<ui32> dictionaryIndexes, TRange<ui64> rleIndexes, + TRef bitmap, TFetcher fetcher, TConsumer consumer) { @@ -334,6 +370,7 @@ void DecodeIntegerVector( zigZagEncoded, dictionaryIndexes, rleIndexes, + bitmap, std::forward<TFetcher>(fetcher), std::forward<TConsumer>(consumer)); } @@ -358,6 +395,7 @@ void DecodeRawVector( false, dictionaryIndexes, rleIndexes, + /*bitmap*/ {}, std::forward<TFetcher>(fetcher), std::forward<TConsumer>(consumer)); } diff --git a/yt/yt/client/table_client/columnar.cpp b/yt/yt/client/table_client/columnar.cpp index 5cfc89f80f5..dfa27d18a3d 100644 --- a/yt/yt/client/table_client/columnar.cpp +++ b/yt/yt/client/table_client/columnar.cpp @@ -1,4 +1,5 @@ #include "columnar.h" +#include "helpers.h" #include <yt/yt/library/numeric/algorithm_helpers.h> @@ -128,22 +129,6 @@ void CopyBitmapRangeToBitmapImpl( } } -bool GetBit(TRef bitmap, i64 index) -{ - return (bitmap[index >> 3] & (1U << (index & 7))) != 0; -} - -void SetBit(TMutableRef bitmap, i64 index, bool value) -{ - auto& byte = bitmap[index >> 3]; - auto mask = (1U << (index & 7)); - if (value) { - byte |= mask; - } else { - byte &= ~mask; - } -} - template <class F> void BuildBitmapFromRleImpl( TRange<ui64> rleIndexes, diff --git a/yt/yt/client/table_client/columnar.h b/yt/yt/client/table_client/columnar.h index 94958b5d49f..442284efc28 100644 --- a/yt/yt/client/table_client/columnar.h +++ b/yt/yt/client/table_client/columnar.h @@ -211,6 +211,7 @@ void DecodeIntegerVector( bool zigZagEncoded, TRange<ui32> dictionaryIndexes, TRange<ui64> rleIndexes, + TRef bitmap, TFetcher fetcher, TConsumer consumer); diff --git a/yt/yt/client/table_client/helpers.cpp b/yt/yt/client/table_client/helpers.cpp index b2eabd6d0aa..10755ecbd09 100644 --- a/yt/yt/client/table_client/helpers.cpp +++ b/yt/yt/client/table_client/helpers.cpp @@ -1669,4 +1669,22 @@ TUnversionedValueRangeTruncationResult TruncateUnversionedValues( //////////////////////////////////////////////////////////////////////////////// +bool GetBit(TRef bitmap, i64 index) +{ + return (bitmap[index >> 3] & (1U << (index & 7))) != 0; +} + +void SetBit(TMutableRef bitmap, i64 index, bool value) +{ + auto& byte = bitmap[index >> 3]; + auto mask = (1U << (index & 7)); + if (value) { + byte |= mask; + } else { + byte &= ~mask; + } +} + +//////////////////////////////////////////////////////////////////////////////// + } // namespace NYT::NTableClient diff --git a/yt/yt/client/table_client/helpers.h b/yt/yt/client/table_client/helpers.h index 3b85e1edfb0..6ea310ac9b5 100644 --- a/yt/yt/client/table_client/helpers.h +++ b/yt/yt/client/table_client/helpers.h @@ -392,6 +392,12 @@ TUnversionedValueRangeTruncationResult TruncateUnversionedValues(TUnversionedVal //////////////////////////////////////////////////////////////////////////////// +bool GetBit(TRef bitmap, i64 index); + +void SetBit(TMutableRef bitmap, i64 index, bool value); + +//////////////////////////////////////////////////////////////////////////////// + } // namespace NYT::NTableClient #define HELPERS_INL_H_ diff --git a/yt/yt/client/table_client/logical_type.cpp b/yt/yt/client/table_client/logical_type.cpp index 821b726e59e..d3c9d1850d9 100644 --- a/yt/yt/client/table_client/logical_type.cpp +++ b/yt/yt/client/table_client/logical_type.cpp @@ -201,18 +201,18 @@ TString ToString(const TLogicalType& logicalType) { switch (logicalType.GetMetatype()) { case ELogicalMetatype::Simple: - return CamelCaseToUnderscoreCase(ToString(logicalType.AsSimpleTypeRef().GetElement())); + return ToString(logicalType.AsSimpleTypeRef().GetElement()); case ELogicalMetatype::Decimal: - return Format("decimal(%v,%v)", + return Format("Decimal(%v,%v)", logicalType.AsDecimalTypeRef().GetPrecision(), logicalType.AsDecimalTypeRef().GetScale()); case ELogicalMetatype::Optional: - return Format("optional<%v>", *logicalType.AsOptionalTypeRef().GetElement()); + return Format("Optional<%v>", *logicalType.AsOptionalTypeRef().GetElement()); case ELogicalMetatype::List: - return Format("list<%v>", *logicalType.AsListTypeRef().GetElement()); + return Format("List<%v>", *logicalType.AsListTypeRef().GetElement()); case ELogicalMetatype::Struct: { TStringStream out; - out << "struct<"; + out << "Struct<"; bool first = true; for (const auto& structItem : logicalType.AsStructTypeRef().GetFields()) { if (first) { @@ -227,7 +227,7 @@ TString ToString(const TLogicalType& logicalType) } case ELogicalMetatype::Tuple: { TStringStream out; - out << "tuple<"; + out << "Tuple<"; bool first = true; for (const auto& element : logicalType.AsTupleTypeRef().GetElements()) { if (first) { @@ -242,7 +242,7 @@ TString ToString(const TLogicalType& logicalType) } case ELogicalMetatype::VariantTuple: { TStringStream out; - out << "variant<"; + out << "Variant<"; bool first = true; for (const auto& element : logicalType.AsVariantTupleTypeRef().GetElements()) { if (first) { @@ -257,7 +257,7 @@ TString ToString(const TLogicalType& logicalType) } case ELogicalMetatype::VariantStruct: { TStringStream out; - out << "named_variant<"; + out << "NamedVariant<"; bool first = true; for (const auto& field : logicalType.AsVariantStructTypeRef().GetFields()) { if (first) { @@ -274,13 +274,13 @@ TString ToString(const TLogicalType& logicalType) case ELogicalMetatype::Dict: { const auto& dictType = logicalType.AsDictTypeRef(); TStringStream out; - out << "dict<" << ToString(*dictType.GetKey()) << ';' << ToString(*dictType.GetValue()) << '>'; + out << "Dict<" << ToString(*dictType.GetKey()) << ';' << ToString(*dictType.GetValue()) << '>'; return out.Str(); } case ELogicalMetatype::Tagged: { const auto& taggedType = logicalType.AsTaggedTypeRef(); TStringStream out; - out << "tagged<\"" << ToString(taggedType.GetTag()) << "\";" << ToString(*taggedType.GetElement()) << '>'; + out << "Tagged<\"" << ToString(taggedType.GetTag()) << "\";" << ToString(*taggedType.GetElement()) << '>'; return out.Str(); } } diff --git a/yt/yt/core/actions/invoker_detail.cpp b/yt/yt/core/actions/invoker_detail.cpp index fc0fd7b3089..47d20524bd1 100644 --- a/yt/yt/core/actions/invoker_detail.cpp +++ b/yt/yt/core/actions/invoker_detail.cpp @@ -66,17 +66,26 @@ template struct NDetail::TMaybeVirtualInvokerBase<false>; //////////////////////////////////////////////////////////////////////////////// -TInvokerProfileWrapper::TInvokerProfileWrapper(NProfiling::IRegistryPtr registry, const TString& invokerFamily, const NProfiling::TTagSet& tagSet) +TInvokerProfilingWrapper::TInvokerProfilingWrapper( + NProfiling::IRegistryPtr registry, + const std::string& invokerFamily, + const NProfiling::TTagSet& tagSet) { - auto profiler = NProfiling::TProfiler("/invoker", NProfiling::TProfiler::DefaultNamespace, tagSet, registry).WithHot(); + auto profiler = NProfiling::TProfiler( + "/invoker", + NProfiling::TProfiler::DefaultNamespace, + tagSet, registry) + .WithHot(); WaitTimer_ = profiler.Timer(invokerFamily + "/wait"); } -TClosure TInvokerProfileWrapper::WrapCallback(TClosure callback) +TClosure TInvokerProfilingWrapper::WrapCallback(TClosure callback) { - auto invokedAt = GetCpuInstant(); + if (!WaitTimer_) { + return callback; + } - return BIND([invokedAt, waitTimer = WaitTimer_, callback = std::move(callback)] { + return BIND([invokedAt = GetCpuInstant(), waitTimer = WaitTimer_, callback = std::move(callback)] { // Measure the time from WrapCallback() to callback(). auto waitTime = CpuDurationToDuration(GetCpuInstant() - invokedAt); waitTimer.Record(waitTime); diff --git a/yt/yt/core/actions/invoker_detail.h b/yt/yt/core/actions/invoker_detail.h index 1f662dccd4c..72e11bd5103 100644 --- a/yt/yt/core/actions/invoker_detail.h +++ b/yt/yt/core/actions/invoker_detail.h @@ -50,15 +50,22 @@ protected: //////////////////////////////////////////////////////////////////////////////// //! A helper base which makes callbacks track their invocation time and profile their wait time. -class TInvokerProfileWrapper +class TInvokerProfilingWrapper { public: + //! Constructs a wrapper with profiling disabled. + TInvokerProfilingWrapper() = default; + + //! Constructs a wrapper with profiling enabled. /*! * #registry defines a profile registry where sensors data is stored. * #invokerFamily defines a family of invokers, e.g. "serialized" or "prioritized" and appears in sensor's name. * #tagSet defines a particular instance of the invoker and appears in sensor's tags. */ - TInvokerProfileWrapper(NProfiling::IRegistryPtr registry, const TString& invokerFamily, const NProfiling::TTagSet& tagSet); + TInvokerProfilingWrapper( + NProfiling::IRegistryPtr registry, + const std::string& invokerFamily, + const NProfiling::TTagSet& tagSet); protected: TClosure WrapCallback(TClosure callback); diff --git a/yt/yt/core/concurrency/action_queue.cpp b/yt/yt/core/concurrency/action_queue.cpp index 8b04a291c90..2e8ba17f668 100644 --- a/yt/yt/core/concurrency/action_queue.cpp +++ b/yt/yt/core/concurrency/action_queue.cpp @@ -110,15 +110,20 @@ const IInvokerPtr& TActionQueue::GetInvoker() class TSerializedInvoker : public TInvokerWrapper<false> - , public TInvokerProfileWrapper + , public TInvokerProfilingWrapper { public: + explicit TSerializedInvoker( + IInvokerPtr underlyingInvoker) + : TInvokerWrapper(std::move(underlyingInvoker)) + { } + TSerializedInvoker( IInvokerPtr underlyingInvoker, const NProfiling::TTagSet& tagSet, NProfiling::IRegistryPtr registry) : TInvokerWrapper(std::move(underlyingInvoker)) - , TInvokerProfileWrapper(std::move(registry), "/serialized", tagSet) + , TInvokerProfilingWrapper(std::move(registry), "/serialized", tagSet) { } using TInvokerWrapper::Invoke; @@ -244,33 +249,58 @@ private: } }; -IInvokerPtr CreateSerializedInvoker(IInvokerPtr underlyingInvoker, const NProfiling::TTagSet& tagSet, NProfiling::IRegistryPtr registry) +IInvokerPtr CreateSerializedInvoker( + IInvokerPtr underlyingInvoker) { if (underlyingInvoker->IsSerialized()) { return underlyingInvoker; } - return New<TSerializedInvoker>(std::move(underlyingInvoker), tagSet, registry); + return New<TSerializedInvoker>( + std::move(underlyingInvoker)); } -IInvokerPtr CreateSerializedInvoker(IInvokerPtr underlyingInvoker, const TString& invokerName, NProfiling::IRegistryPtr registry) +IInvokerPtr CreateSerializedInvoker( + IInvokerPtr underlyingInvoker, + const NProfiling::TTagSet& tagSet, + NProfiling::IRegistryPtr registry) { - NProfiling::TTagSet tagSet; - tagSet.AddTag(NProfiling::TTag("invoker", invokerName)); - return CreateSerializedInvoker(std::move(underlyingInvoker), std::move(tagSet), std::move(registry)); + if (underlyingInvoker->IsSerialized()) { + return underlyingInvoker; + } + + return New<TSerializedInvoker>( + std::move(underlyingInvoker), + tagSet, + std::move(registry)); +} + +IInvokerPtr CreateSerializedInvoker( + IInvokerPtr underlyingInvoker, + const std::string& invokerName, + NProfiling::IRegistryPtr registry) +{ + return CreateSerializedInvoker( + std::move(underlyingInvoker), + NProfiling::TTagSet({{"invoker", invokerName}}), + std::move(registry)); } //////////////////////////////////////////////////////////////////////////////// class TPrioritizedInvoker : public TInvokerWrapper<true> - , public TInvokerProfileWrapper + , public TInvokerProfilingWrapper , public virtual IPrioritizedInvoker { public: + explicit TPrioritizedInvoker(IInvokerPtr underlyingInvoker) + : TInvokerWrapper(std::move(underlyingInvoker)) + { } + TPrioritizedInvoker(IInvokerPtr underlyingInvoker, const NProfiling::TTagSet& tagSet, NProfiling::IRegistryPtr registry) : TInvokerWrapper(std::move(underlyingInvoker)) - , TInvokerProfileWrapper(std::move(registry), "/prioritized", tagSet) + , TInvokerProfilingWrapper(std::move(registry), "/prioritized", tagSet) { } using TInvokerWrapper::Invoke; @@ -318,19 +348,35 @@ private: guard.Release(); callback(); } - }; -IPrioritizedInvokerPtr CreatePrioritizedInvoker(IInvokerPtr underlyingInvoker, const NProfiling::TTagSet& tagSet, NProfiling::IRegistryPtr registry) +IPrioritizedInvokerPtr CreatePrioritizedInvoker( + IInvokerPtr underlyingInvoker) +{ + return New<TPrioritizedInvoker>( + std::move(underlyingInvoker)); +} + +IPrioritizedInvokerPtr CreatePrioritizedInvoker( + IInvokerPtr underlyingInvoker, + const NProfiling::TTagSet& tagSet, + NProfiling::IRegistryPtr registry) { - return New<TPrioritizedInvoker>(std::move(underlyingInvoker), std::move(tagSet), std::move(registry)); + return New<TPrioritizedInvoker>( + std::move(underlyingInvoker), + std::move(tagSet), + std::move(registry)); } -IPrioritizedInvokerPtr CreatePrioritizedInvoker(IInvokerPtr underlyingInvoker, const TString& invokerName, NProfiling::IRegistryPtr registry) +IPrioritizedInvokerPtr CreatePrioritizedInvoker( + IInvokerPtr underlyingInvoker, + const std::string& invokerName, + NProfiling::IRegistryPtr registry) { - NProfiling::TTagSet tagSet; - tagSet.AddTag(NProfiling::TTag("invoker", invokerName)); - return CreatePrioritizedInvoker(std::move(underlyingInvoker), std::move(tagSet), std::move(registry)); + return CreatePrioritizedInvoker( + std::move(underlyingInvoker), + NProfiling::TTagSet({{"invoker", invokerName}}), + std::move(registry)); } //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/concurrency/action_queue.h b/yt/yt/core/concurrency/action_queue.h index 2167b46dab7..638167fb578 100644 --- a/yt/yt/core/concurrency/action_queue.h +++ b/yt/yt/core/concurrency/action_queue.h @@ -43,8 +43,11 @@ DEFINE_REFCOUNTED_TYPE(TActionQueue) //! #invokerName is used as a profiling tag. //! #registry is needed for testing purposes only. IInvokerPtr CreateSerializedInvoker( + IInvokerPtr underlyingInvoker); + +IInvokerPtr CreateSerializedInvoker( IInvokerPtr underlyingInvoker, - const TString& invokerName = "default", + const std::string& invokerName, NProfiling::IRegistryPtr registry = nullptr); IInvokerPtr CreateSerializedInvoker( @@ -59,8 +62,11 @@ IInvokerPtr CreateSerializedInvoker( //! #invokerName is used as a profiling tag. //! #registry is needed for testing purposes only. IPrioritizedInvokerPtr CreatePrioritizedInvoker( + IInvokerPtr underlyingInvoker); + +IPrioritizedInvokerPtr CreatePrioritizedInvoker( IInvokerPtr underlyingInvoker, - const TString& invokerName = "default", + const std::string& invokerName, NProfiling::IRegistryPtr registry = nullptr); IPrioritizedInvokerPtr CreatePrioritizedInvoker( @@ -71,7 +77,8 @@ IPrioritizedInvokerPtr CreatePrioritizedInvoker( //! Creates a wrapper around IInvoker that implements IPrioritizedInvoker but //! does not perform any actual reordering. Priorities passed to #IPrioritizedInvoker::Invoke //! are ignored. -IPrioritizedInvokerPtr CreateFakePrioritizedInvoker(IInvokerPtr underlyingInvoker); +IPrioritizedInvokerPtr CreateFakePrioritizedInvoker( + IInvokerPtr underlyingInvoker); //! Creates a wrapper around IPrioritizedInvoker turning it into a regular IInvoker. //! All callbacks are propagated with a given fixed #priority. diff --git a/yt/yt/core/misc/async_slru_cache-inl.h b/yt/yt/core/misc/async_slru_cache-inl.h index 5520cbfb2b4..49247cae60d 100644 --- a/yt/yt/core/misc/async_slru_cache-inl.h +++ b/yt/yt/core/misc/async_slru_cache-inl.h @@ -31,13 +31,20 @@ auto TAsyncSlruCacheBase<TKey, TValue, THash>::TItem::GetValueFuture() const -> //////////////////////////////////////////////////////////////////////////////// +template <class TKey, class TValue, class THash> +TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostItem::TGhostItem(TKey key) + : Key(std::move(key)) +{ } + +//////////////////////////////////////////////////////////////////////////////// + template <class TItem, class TDerived> void TAsyncSlruCacheListManager<TItem, TDerived>::PushToYounger(TItem* item, i64 weight) { YT_ASSERT(item->Empty()); - YoungerLruList.PushFront(item); + YoungerLruList_.PushFront(item); item->CachedWeight = weight; - YoungerWeightCounter += weight; + YoungerWeightCounter_ += weight; AsDerived()->OnYoungerUpdated(1, weight); item->Younger = true; } @@ -47,12 +54,12 @@ void TAsyncSlruCacheListManager<TItem, TDerived>::MoveToYounger(TItem* item) { YT_ASSERT(!item->Empty()); item->Unlink(); - YoungerLruList.PushFront(item); + YoungerLruList_.PushFront(item); if (!item->Younger) { i64 weight = item->CachedWeight; - OlderWeightCounter -= weight; + OlderWeightCounter_ -= weight; AsDerived()->OnOlderUpdated(-1, -weight); - YoungerWeightCounter += weight; + YoungerWeightCounter_ += weight; AsDerived()->OnYoungerUpdated(1, weight); item->Younger = true; } @@ -63,12 +70,12 @@ void TAsyncSlruCacheListManager<TItem, TDerived>::MoveToOlder(TItem* item) { YT_ASSERT(!item->Empty()); item->Unlink(); - OlderLruList.PushFront(item); + OlderLruList_.PushFront(item); if (item->Younger) { i64 weight = item->CachedWeight; - YoungerWeightCounter -= weight; + YoungerWeightCounter_ -= weight; AsDerived()->OnYoungerUpdated(-1, -weight); - OlderWeightCounter += weight; + OlderWeightCounter_ += weight; AsDerived()->OnOlderUpdated(1, weight); item->Younger = false; } @@ -81,14 +88,14 @@ void TAsyncSlruCacheListManager<TItem, TDerived>::PopFromLists(TItem* item) return; } - YT_VERIFY(TouchBufferPosition.load() == 0); + YT_VERIFY(TouchBufferPosition_.load() == 0); i64 weight = item->CachedWeight; if (item->Younger) { - YoungerWeightCounter -= weight; + YoungerWeightCounter_ -= weight; AsDerived()->OnYoungerUpdated(-1, -weight); } else { - OlderWeightCounter -= weight; + OlderWeightCounter_ -= weight; AsDerived()->OnOlderUpdated(-1, -weight); } item->Unlink(); @@ -99,10 +106,10 @@ void TAsyncSlruCacheListManager<TItem, TDerived>::UpdateWeight(TItem* item, i64 { YT_VERIFY(!item->Empty()); if (item->Younger) { - YoungerWeightCounter += weightDelta; + YoungerWeightCounter_ += weightDelta; AsDerived()->OnYoungerUpdated(0, weightDelta); } else { - OlderWeightCounter += weightDelta; + OlderWeightCounter_ += weightDelta; AsDerived()->OnOlderUpdated(0, weightDelta); } item->CachedWeight += weightDelta; @@ -112,7 +119,7 @@ template <class TItem, class TDerived> void TAsyncSlruCacheListManager<TItem, TDerived>::UpdateCookie(TItem* item, i64 countDelta, i64 weightDelta) { YT_VERIFY(item->Empty()); - CookieWeightCounter += weightDelta; + CookieWeightCounter_ += weightDelta; item->CachedWeight += weightDelta; AsDerived()->OnCookieUpdated(countDelta, weightDelta); } @@ -121,17 +128,17 @@ template <class TItem, class TDerived> TIntrusiveListWithAutoDelete<TItem, TDelete> TAsyncSlruCacheListManager<TItem, TDerived>::TrimNoDelete() { // Move from older to younger. - auto capacity = Capacity.load(); - auto youngerSizeFraction = YoungerSizeFraction.load(); - while (!OlderLruList.Empty() && OlderWeightCounter > capacity * (1 - youngerSizeFraction)) { - auto* item = &*(--OlderLruList.End()); + auto capacity = Capacity_.load(); + auto youngerSizeFraction = YoungerSizeFraction_.load(); + while (!OlderLruList_.Empty() && OlderWeightCounter_ > capacity * (1 - youngerSizeFraction)) { + auto* item = &*(--OlderLruList_.End()); MoveToYounger(item); } // Evict from younger. TIntrusiveListWithAutoDelete<TItem, TDelete> evictedItems; - while (!YoungerLruList.Empty() && static_cast<i64>(YoungerWeightCounter + OlderWeightCounter + CookieWeightCounter) > capacity) { - auto* item = &*(--YoungerLruList.End()); + while (!YoungerLruList_.Empty() && static_cast<i64>(YoungerWeightCounter_ + OlderWeightCounter_ + CookieWeightCounter_) > capacity) { + auto* item = &*(--YoungerLruList_.End()); PopFromLists(item); evictedItems.PushBack(item); } @@ -146,8 +153,8 @@ bool TAsyncSlruCacheListManager<TItem, TDerived>::TouchItem(TItem* item) return false; } - int capacity = std::ssize(TouchBuffer); - int index = TouchBufferPosition++; + int capacity = std::ssize(TouchBuffer_); + int index = TouchBufferPosition_++; if (index >= capacity) { // Drop touch request due to buffer overflow. // NB: We still return false since the other thread is already responsible for @@ -155,31 +162,31 @@ bool TAsyncSlruCacheListManager<TItem, TDerived>::TouchItem(TItem* item) return false; } - TouchBuffer[index] = item; + TouchBuffer_[index] = item; return index == capacity - 1; } template <class TItem, class TDerived> void TAsyncSlruCacheListManager<TItem, TDerived>::DrainTouchBuffer() { - int count = std::min<int>(TouchBufferPosition.load(), std::ssize(TouchBuffer)); + int count = std::min<int>(TouchBufferPosition_.load(), std::ssize(TouchBuffer_)); for (int index = 0; index < count; ++index) { - MoveToOlder(TouchBuffer[index]); + MoveToOlder(TouchBuffer_[index]); } - TouchBufferPosition = 0; + TouchBufferPosition_ = 0; } template <class TItem, class TDerived> void TAsyncSlruCacheListManager<TItem, TDerived>::Reconfigure(i64 capacity, double youngerSizeFraction) { - Capacity = capacity; - YoungerSizeFraction = youngerSizeFraction; + Capacity_.store(capacity); + YoungerSizeFraction_.store(youngerSizeFraction); } template <class TItem, class TDerived> void TAsyncSlruCacheListManager<TItem, TDerived>::SetTouchBufferCapacity(i64 touchBufferCapacity) { - TouchBuffer.resize(touchBufferCapacity); + TouchBuffer_.resize(touchBufferCapacity); } template <class TItem, class TDerived> @@ -211,7 +218,7 @@ TIntrusivePtr<typename TAsyncCacheValueBase<TKey, TValue, THash>::TCache> TAsync template <class TKey, class TValue, class THash> void TAsyncCacheValueBase<TKey, TValue, THash>::SetCache(TWeakPtr<TCache> cache) { - Cache_.Store(cache); + Cache_.Store(std::move(cache)); } template <class TKey, class TValue, class THash> @@ -429,7 +436,7 @@ TAsyncSlruCacheBase<TKey, TValue, THash>::GetAll() auto readerGuard = ReaderGuard(shard.SpinLock); for (const auto& [key, rawValue] : shard.ValueMap) { if (auto value = DangerousGetPtr<TValue>(rawValue)) { - result.push_back(value); + result.push_back(std::move(value)); } } } @@ -553,7 +560,7 @@ TAsyncSlruCacheBase<TKey, TValue, THash>::DoLookup(TShard* shard, const TKey& ke auto valueFuture = item->GetValueFuture(); - YT_VERIFY(itemMap.emplace(key, item).second); + EmplaceOrCrash(itemMap, key, item); ++Size_; i64 weight = GetWeight(item->Value); @@ -601,9 +608,9 @@ auto TAsyncSlruCacheBase<TKey, TValue, THash>::BeginInsert(const TKey& key, i64 return TInsertCookie( key, - nullptr, + /*cache*/ nullptr, std::move(valueFuture), - false); + /*active*/ false); } while (true) { @@ -651,7 +658,7 @@ auto TAsyncSlruCacheBase<TKey, TValue, THash>::BeginInsert(const TKey& key, i64 key, nullptr, std::move(valueFuture), - false); + /*active*/ false); } auto valueIt = valueMap.find(key); @@ -659,7 +666,7 @@ auto TAsyncSlruCacheBase<TKey, TValue, THash>::BeginInsert(const TKey& key, i64 auto* item = new TItem(); auto valueFuture = item->GetValueFuture(); - YT_VERIFY(itemMap.emplace(key, item).second); + EmplaceOrCrash(itemMap, key, item); ++Size_; Counters_.MissedCounter.Increment(); @@ -674,9 +681,9 @@ auto TAsyncSlruCacheBase<TKey, TValue, THash>::BeginInsert(const TKey& key, i64 auto insertCookie = TInsertCookie( key, - this, + /*cache*/ this, std::move(valueFuture), - true); + /*active*/ true); if (GhostCachesEnabled_.load()) { insertCookie.InsertedIntoSmallGhost_ = shard->SmallGhost.BeginInsert(key, cookieWeight); @@ -690,7 +697,7 @@ auto TAsyncSlruCacheBase<TKey, TValue, THash>::BeginInsert(const TKey& key, i64 auto* item = new TItem(value); value->Item_ = item; - YT_VERIFY(itemMap.emplace(key, item).second); + EmplaceOrCrash(itemMap, key, item); ++Size_; i64 weight = GetWeight(item->Value); @@ -710,9 +717,9 @@ auto TAsyncSlruCacheBase<TKey, TValue, THash>::BeginInsert(const TKey& key, i64 return TInsertCookie( key, - nullptr, + /*cache*/ nullptr, MakeFuture(value), - false); + /*active*/ false); } // Back off. @@ -727,7 +734,7 @@ template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::UpdateCookieWeight(const TInsertCookie& insertCookie, i64 newWeight) { YT_VERIFY(newWeight >= 0); - auto key = insertCookie.GetKey(); + const auto& key = insertCookie.GetKey(); auto* shard = GetShardByKey(key); @@ -765,7 +772,7 @@ template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::EndInsert(const TInsertCookie& insertCookie, TValuePtr value) { YT_VERIFY(value); - auto key = value->GetKey(); + const auto& key = value->GetKey(); auto* shard = GetShardByKey(key); @@ -780,7 +787,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::EndInsert(const TInsertCookie& in value->Item_ = item; auto promise = item->ValuePromise; - YT_VERIFY(shard->ValueMap.emplace(key, value.Get()).second); + EmplaceOrCrash(shard->ValueMap, key, value.Get()); auto cookieWeight = item->CachedWeight; shard->UpdateCookie(item, /*countDelta*/ -1, -cookieWeight); @@ -828,11 +835,9 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::CancelInsert(const TInsertCookie& shard->DrainTouchBuffer(); auto& itemMap = shard->ItemMap; - auto itemIt = itemMap.find(key); - YT_VERIFY(itemIt != itemMap.end()); + auto itemIt = GetIteratorOrCrash(itemMap, key); auto* item = itemIt->second; - auto promise = item->ValuePromise; itemMap.erase(itemIt); --Size_; @@ -842,6 +847,8 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::CancelInsert(const TInsertCookie& auto cookieWeight = item->CachedWeight; shard->UpdateCookie(item, /*countDelta*/ -1, -cookieWeight); + auto promise = std::move(item->ValuePromise); + delete item; guard.Release(); @@ -870,7 +877,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::Unregister(const TKey& key) template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TryRemove(const TKey& key, bool forbidResurrection) { - DoTryRemove(key, nullptr, forbidResurrection); + DoTryRemove(key, /*value*/ nullptr, forbidResurrection); } template <class TKey, class TValue, class THash> @@ -919,8 +926,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::DoTryRemove( } auto* item = itemIt->second; - auto actualValue = item->Value; - if (!actualValue) { + if (!item->Value) { return; } @@ -929,16 +935,45 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::DoTryRemove( shard->PopFromLists(item); - YT_VERIFY(actualValue->Item_ == item); - actualValue->Item_ = nullptr; + YT_VERIFY(item->Value->Item_ == item); + item->Value->Item_ = nullptr; - delete item; - - OnRemoved(actualValue); + OnRemoved(item->Value); // It is necessary to remove the guard before the actual value is destroyed. // Otherwise, it will lead to a deadlock in unregister. guard.Release(); + + delete item; +} + +template <class TKey, class TValue, class THash> +std::vector<typename TAsyncSlruCacheBase<TKey, TValue, THash>::TValuePtr> +TAsyncSlruCacheBase<TKey, TValue, THash>::TrimWithNotify( + TShard* shard, + NThreading::TWriterGuard<NThreading::TReaderWriterSpinLock>& guard, + const TValuePtr& insertedValue, + i64 weightDelta) +{ + YT_ASSERT_WRITER_SPINLOCK_AFFINITY(shard->SpinLock); + + auto evictedItems = shard->TrimNoDelete(); + auto evictedValues = shard->Trim(std::move(evictedItems)); + + if (weightDelta != 0) { + OnWeightUpdated(weightDelta); + } + if (insertedValue) { + OnAdded(insertedValue); + } + for (const auto& value : evictedValues) { + OnRemoved(value); + } + + // NB. Evicted items must die outside of critical section. + guard.Release(); + + return evictedValues; } template <class TKey, class TValue, class THash> @@ -1032,7 +1067,7 @@ auto TAsyncSlruCacheBase<TKey, TValue, THash>::GetLargeGhostCounters() const -> template <class TKey, class TValue, class THash> bool TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::DoLookup(const TKey& key, bool allowAsyncHits) { - auto readerGuard = ReaderGuard(SpinLock); + auto readerGuard = ReaderGuard(SpinLock_); auto itemIt = ItemMap_.find(key); if (itemIt == ItemMap_.end()) { @@ -1057,7 +1092,7 @@ bool TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::DoLookup(const TKey& readerGuard.Release(); if (needToDrain) { - auto writerGuard = WriterGuard(SpinLock); + auto writerGuard = WriterGuard(SpinLock_); this->DrainTouchBuffer(); } @@ -1067,7 +1102,7 @@ bool TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::DoLookup(const TKey& template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Find(const TKey& key) { - if (!DoLookup(key, false)) { + if (!DoLookup(key, /*allowAsyncHits*/ false)) { Counters_->MissedCounter.Increment(); } } @@ -1075,7 +1110,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Find(const TKey& key template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Lookup(const TKey& key) { - if (!DoLookup(key, true)) { + if (!DoLookup(key, /*allowAsyncHits*/ true)) { Counters_->MissedCounter.Increment(); } } @@ -1083,7 +1118,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Lookup(const TKey& k template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Touch(const TValuePtr& value) { - auto readerGuard = ReaderGuard(SpinLock); + auto readerGuard = ReaderGuard(SpinLock_); if (!value) { return; @@ -1100,7 +1135,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Touch(const TValuePt readerGuard.Release(); if (needToDrain) { - auto writerGuard = WriterGuard(SpinLock); + auto writerGuard = WriterGuard(SpinLock_); this->DrainTouchBuffer(); } } @@ -1112,7 +1147,7 @@ bool TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::BeginInsert(const TK return false; } - auto guard = WriterGuard(SpinLock); + auto guard = WriterGuard(SpinLock_); this->DrainTouchBuffer(); @@ -1134,7 +1169,7 @@ bool TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::BeginInsert(const TK auto* item = new TGhostItem(key); Counters_->MissedCounter.Increment(); - YT_VERIFY(ItemMap_.emplace(key, item).second); + EmplaceOrCrash(ItemMap_, key, item); this->UpdateCookie(item, /*countDelta*/ 1, cookieWeight); if (cookieWeight > 0) { @@ -1148,12 +1183,11 @@ bool TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::BeginInsert(const TK template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::CancelInsert(const TKey& key) { - auto guard = WriterGuard(SpinLock); + auto guard = WriterGuard(SpinLock_); this->DrainTouchBuffer(); - auto itemIt = ItemMap_.find(key); - YT_VERIFY(itemIt != ItemMap_.end()); + auto itemIt = GetIteratorOrCrash(ItemMap_, key); auto* item = itemIt->second; YT_VERIFY(!item->Inserted); @@ -1162,16 +1196,18 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::CancelInsert(const T this->UpdateCookie(item, /*countDelta*/ -1, -item->CachedWeight); + guard.Release(); + delete item; } template <class TKey, class TValue, class THash> -void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::EndInsert(const TValuePtr& value, i64 weight) +void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::EndInsert(TValuePtr value, i64 weight) { YT_VERIFY(value); - auto key = value->GetKey(); + const auto& key = value->GetKey(); - auto guard = WriterGuard(SpinLock); + auto guard = WriterGuard(SpinLock_); this->DrainTouchBuffer(); @@ -1180,7 +1216,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::EndInsert(const TVal YT_VERIFY(!item->Inserted); this->UpdateCookie(item, /*countDelta*/ -1, -item->CachedWeight); - item->Value = value; + item->Value = std::move(value); item->Inserted = true; this->PushToYounger(item, weight); @@ -1196,9 +1232,9 @@ template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Resurrect(const TValuePtr& value, i64 weight) { YT_VERIFY(value); - auto key = value->GetKey(); + const auto& key = value->GetKey(); - auto guard = WriterGuard(SpinLock); + auto guard = WriterGuard(SpinLock_); this->DrainTouchBuffer(); @@ -1211,7 +1247,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Resurrect(const TVal item->Value = value; item->Inserted = true; - YT_VERIFY(ItemMap_.emplace(key, item).second); + EmplaceOrCrash(ItemMap_, key, item); this->PushToYounger(item, weight); @@ -1225,7 +1261,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Resurrect(const TVal template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::TryRemove(const TKey& key, const TValuePtr& value) { - auto guard = WriterGuard(SpinLock); + auto guard = WriterGuard(SpinLock_); this->DrainTouchBuffer(); @@ -1238,28 +1274,32 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::TryRemove(const TKey if (!item->Inserted) { return; } - auto actualValue = item->Value.Lock(); - // If value is null, it means that we don't care about the removed value and remove just by key. - // If actualValue is null, then it refers to the value removed from the main cache, and always - // doesn't match our provided value. Otherwise, just compare the values. Note that the condition - // can be simplified just to (value && value != actualValue), but is retained as-is to make the - // intention more clear. - if (value && (!actualValue || value != actualValue)) { - return; + + { + auto actualValue = item->Value.Lock(); + // If value is null, it means that we don't care about the removed value and remove just by key. + // If actualValue is null, then it refers to the value removed from the main cache, and always + // doesn't match our provided value. Otherwise, just compare the values. Note that the condition + // can be simplified just to (value && value != actualValue), but is retained as-is to make the + // intention more clear. + if (value && (!actualValue || value != actualValue)) { + return; + } } - actualValue.Reset(); ItemMap_.erase(itemIt); this->PopFromLists(item); + guard.Release(); + delete item; } template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::UpdateWeight(const TKey& key, i64 newWeight) { - auto guard = WriterGuard(SpinLock); + auto guard = WriterGuard(SpinLock_); this->DrainTouchBuffer(); @@ -1290,7 +1330,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::UpdateWeight(const T template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::UpdateCookieWeight(const TKey& key, i64 newWeight) { - auto guard = WriterGuard(SpinLock); + auto guard = WriterGuard(SpinLock_); auto itemIt = ItemMap_.find(key); if (itemIt == ItemMap_.end()) { @@ -1314,7 +1354,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::UpdateCookieWeight(c template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Reconfigure(i64 capacity, double youngerSizeFraction) { - auto writerGuard = WriterGuard(SpinLock); + auto writerGuard = WriterGuard(SpinLock_); TAsyncSlruCacheListManager<TGhostItem, TGhostShard>::Reconfigure(capacity, youngerSizeFraction); this->DrainTouchBuffer(); Trim(writerGuard); @@ -1325,7 +1365,7 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Trim(NThreading::TWr { auto evictedItems = this->TrimNoDelete(); for (const auto& item : evictedItems) { - YT_VERIFY(ItemMap_.erase(item.Key) == 1); + EraseOrCrash(ItemMap_, item.Key); } // NB. Evicted items must die outside of critical section. @@ -1336,13 +1376,15 @@ void TAsyncSlruCacheBase<TKey, TValue, THash>::TGhostShard::Trim(NThreading::TWr template <class TKey, class TValue, class THash> std::vector<typename TAsyncSlruCacheBase<TKey, TValue, THash>::TValuePtr> -TAsyncSlruCacheBase<TKey, TValue, THash>::TShard::Trim(const TIntrusiveListWithAutoDelete<TItem, TDelete>& evictedItems) +TAsyncSlruCacheBase<TKey, TValue, THash>::TShard::Trim(TIntrusiveListWithAutoDelete<TItem, TDelete>&& evictedItems) { Parent->Size_ -= static_cast<int>(evictedItems.Size()); std::vector<TValuePtr> evictedValues; - for (const auto& item : evictedItems) { - auto value = item.Value; + evictedValues.reserve(evictedItems.Size()); + + for (auto& item : evictedItems) { + auto& value = item.Value; EraseOrCrash(ItemMap, value->GetKey()); @@ -1361,35 +1403,6 @@ TAsyncSlruCacheBase<TKey, TValue, THash>::TShard::Trim(const TIntrusiveListWithA } template <class TKey, class TValue, class THash> -std::vector<typename TAsyncSlruCacheBase<TKey, TValue, THash>::TValuePtr> -TAsyncSlruCacheBase<TKey, TValue, THash>::TrimWithNotify( - TShard* shard, - NThreading::TWriterGuard<NThreading::TReaderWriterSpinLock>& guard, - const TValuePtr& insertedValue, - i64 weightDelta) -{ - YT_ASSERT_SPINLOCK_AFFINITY(shard->SpinLock); - - auto evictedItems = shard->TrimNoDelete(); - auto evictedValues = shard->Trim(evictedItems); - - if (weightDelta != 0) { - OnWeightUpdated(weightDelta); - } - if (insertedValue) { - OnAdded(insertedValue); - } - for (const auto& value : evictedValues) { - OnRemoved(value); - } - - // NB. Evicted items must die outside of critical section. - guard.Release(); - - return evictedValues; -} - -template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TShard::OnYoungerUpdated(i64 deltaCount, i64 deltaWeight) { Parent->YoungerSizeCounter_ += deltaCount; @@ -1434,7 +1447,8 @@ TAsyncSlruCacheBase<TKey, TValue, THash>::TInsertCookie::~TInsertCookie() } template <class TKey, class TValue, class THash> -typename TAsyncSlruCacheBase<TKey, TValue, THash>::TInsertCookie& TAsyncSlruCacheBase<TKey, TValue, THash>::TInsertCookie::operator =(TInsertCookie&& other) +typename TAsyncSlruCacheBase<TKey, TValue, THash>::TInsertCookie& +TAsyncSlruCacheBase<TKey, TValue, THash>::TInsertCookie::operator = (TInsertCookie&& other) { if (this != &other) { Abort(); @@ -1488,7 +1502,7 @@ template <class TKey, class TValue, class THash> void TAsyncSlruCacheBase<TKey, TValue, THash>::TInsertCookie::EndInsert(TValuePtr value) { if (Active_.exchange(false)) { - Cache_->EndInsert(*this, value); + Cache_->EndInsert(*this, std::move(value)); } } diff --git a/yt/yt/core/misc/async_slru_cache.h b/yt/yt/core/misc/async_slru_cache.h index ac95f4259d3..889d063d201 100644 --- a/yt/yt/core/misc/async_slru_cache.h +++ b/yt/yt/core/misc/async_slru_cache.h @@ -117,18 +117,18 @@ protected: void OnCookieUpdated(i64 deltaCount, i64 deltaWeight); private: - TIntrusiveListWithAutoDelete<TItem, TDelete> YoungerLruList; - TIntrusiveListWithAutoDelete<TItem, TDelete> OlderLruList; + TIntrusiveListWithAutoDelete<TItem, TDelete> YoungerLruList_; + TIntrusiveListWithAutoDelete<TItem, TDelete> OlderLruList_; - std::vector<TItem*> TouchBuffer; - std::atomic<int> TouchBufferPosition = 0; + std::vector<TItem*> TouchBuffer_; + std::atomic<int> TouchBufferPosition_ = 0; - i64 YoungerWeightCounter = 0; - i64 OlderWeightCounter = 0; - i64 CookieWeightCounter = 0; + i64 YoungerWeightCounter_ = 0; + i64 OlderWeightCounter_ = 0; + i64 CookieWeightCounter_ = 0; - std::atomic<i64> Capacity; - std::atomic<double> YoungerSizeFraction; + std::atomic<i64> Capacity_; + std::atomic<double> YoungerSizeFraction_; }; //////////////////////////////////////////////////////////////////////////////// @@ -201,8 +201,8 @@ public: // NB: Shards store reference to the cache, so the cache cannot be simply copied or moved. TAsyncSlruCacheBase(const TAsyncSlruCacheBase&) = delete; TAsyncSlruCacheBase(TAsyncSlruCacheBase&&) = delete; - TAsyncSlruCacheBase& operator=(const TAsyncSlruCacheBase&) = delete; - TAsyncSlruCacheBase& operator=(TAsyncSlruCacheBase&&) = delete; + TAsyncSlruCacheBase& operator = (const TAsyncSlruCacheBase&) = delete; + TAsyncSlruCacheBase& operator = (TAsyncSlruCacheBase&&) = delete; int GetSize() const; i64 GetCapacity() const; @@ -298,9 +298,7 @@ private: struct TGhostItem : public TIntrusiveListItem<TGhostItem> { - explicit TGhostItem(TKey key) - : Key(std::move(key)) - { } + explicit TGhostItem(TKey key); TKey Key; //! The value associated with this item. If Inserted == true and Value is null, then we refer to some @@ -348,7 +346,7 @@ private: //! called with the same key. Do not call CancelInsert() or EndInsert() without matching BeginInsert(). bool BeginInsert(const TKey& key, i64 cookieWeight); void CancelInsert(const TKey& key); - void EndInsert(const TValuePtr& value, i64 weight); + void EndInsert(TValuePtr value, i64 weight); //! Inserts the value back to the cache immediately. Called when the value is resurected in the //! main cache. @@ -371,7 +369,7 @@ private: private: friend class TAsyncSlruCacheListManager<TGhostItem, TGhostShard>; - YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, SpinLock); + YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, SpinLock_); THashMap<TKey, TGhostItem*, THash> ItemMap_; @@ -421,7 +419,7 @@ private: TGhostShard LargeGhost; //! Returns the list of evicted items. - std::vector<TValuePtr> Trim(const TIntrusiveListWithAutoDelete<TItem, TDelete>& evictedItems); + std::vector<TValuePtr> Trim(TIntrusiveListWithAutoDelete<TItem, TDelete>&& evictedItems); protected: void OnYoungerUpdated(i64 deltaCount, i64 deltaWeight); diff --git a/yt/yt/library/formats/arrow_writer.cpp b/yt/yt/library/formats/arrow_writer.cpp index 377ecc8634a..5c5483491a2 100644 --- a/yt/yt/library/formats/arrow_writer.cpp +++ b/yt/yt/library/formats/arrow_writer.cpp @@ -171,6 +171,10 @@ int ExtractTableIndexFromColumn(const TBatchColumn* column) const auto* valueColumn = column->Rle->ValueColumn; auto values = valueColumn->GetTypedValues<ui64>(); + TRef nullBitmap; + if (valueColumn->NullBitmap) { + nullBitmap = valueColumn->NullBitmap->Data; + } // Expecting only one element. YT_VERIFY(values.size() == 1); @@ -187,12 +191,14 @@ int ExtractTableIndexFromColumn(const TBatchColumn* column) valueColumn->Values->ZigZagEncoded, TRange<ui32>(), rleIndexes, + nullBitmap, [&] (auto index) { return values[index]; }, [&] (auto value) { tableIndex = value; }); + return tableIndex; } @@ -497,6 +503,11 @@ void SerializeIntegerColumn( auto startIndex = column->StartIndex; + TRef nullBitmap; + if (valueColumn->NullBitmap) { + nullBitmap = valueColumn->NullBitmap->Data; + } + switch (simpleType) { #define XX(cppType, ytType) \ case ESimpleLogicalValueType::ytType: { \ @@ -509,6 +520,7 @@ void SerializeIntegerColumn( valueColumn->Values->ZigZagEncoded, \ TRange<ui32>(), \ rleIndexes, \ + nullBitmap, \ [&] (auto index) { \ return values[index]; \ }, \ @@ -565,6 +577,11 @@ void SerializeDateColumn( ? column->GetTypedValues<ui64>() : TRange<ui64>(); + TRef nullBitmap; + if (valueColumn->NullBitmap) { + nullBitmap = valueColumn->NullBitmap->Data; + } + auto startIndex = column->StartIndex; auto dstValues = GetTypedValues<i32>(dstRef); @@ -576,6 +593,7 @@ void SerializeDateColumn( valueColumn->Values->ZigZagEncoded, TRange<ui32>(), rleIndexes, + nullBitmap, [&] (auto index) { return values[index]; }, @@ -616,6 +634,11 @@ void SerializeDatetimeColumn( ? column->GetTypedValues<ui64>() : TRange<ui64>(); + TRef nullBitmap; + if (valueColumn->NullBitmap) { + nullBitmap = valueColumn->NullBitmap->Data; + } + auto startIndex = column->StartIndex; auto dstValues = GetTypedValues<i64>(dstRef); @@ -627,6 +650,7 @@ void SerializeDatetimeColumn( valueColumn->Values->ZigZagEncoded, TRange<ui32>(), rleIndexes, + nullBitmap, [&] (auto index) { return values[index]; }, @@ -666,6 +690,11 @@ void SerializeTimestampColumn( ? column->GetTypedValues<ui64>() : TRange<ui64>(); + TRef nullBitmap; + if (valueColumn->NullBitmap) { + nullBitmap = valueColumn->NullBitmap->Data; + } + auto startIndex = column->StartIndex; auto dstValues = GetTypedValues<i64>(dstRef); @@ -677,6 +706,7 @@ void SerializeTimestampColumn( valueColumn->Values->ZigZagEncoded, TRange<ui32>(), rleIndexes, + nullBitmap, [&] (auto index) { return values[index]; }, |