diff options
author | vvvv <vvvv@ydb.tech> | 2023-09-14 14:45:19 +0300 |
---|---|---|
committer | vvvv <vvvv@ydb.tech> | 2023-09-14 15:01:39 +0300 |
commit | 13d1416652bad5fab73a207b2beeb6b7dfde33f0 (patch) | |
tree | 85538520b6979bb0fb098e3744a98e5abacd7844 | |
parent | c3a576c8d16fbd457f297937ca83c88280bbffda (diff) | |
download | ydb-13d1416652bad5fab73a207b2beeb6b7dfde33f0.tar.gz |
avoid hardcode for bit, pass typmod to registered casts
-rw-r--r-- | ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp index a8c02a9292..11d5c6842c 100644 --- a/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp +++ b/ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp @@ -2486,9 +2486,7 @@ TMkqlCommonCallableCompiler::TShared::TShared() { auto typeMod1 = typeMod; if (node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "interval" && - node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "_interval" && - node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "bit" && - node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "_bit" ) { + node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "_interval") { typeMod1 = TRuntimeNode(); } @@ -2570,12 +2568,24 @@ TMkqlCommonCallableCompiler::TShared::TShared() { auto typeMod1 = typeMod; if (node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "interval" && - node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "_interval" && - node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "bit" && - node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "_bit") { + node.GetTypeAnn()->Cast<TPgExprType>()->GetName() != "_interval") { typeMod1 = TRuntimeNode(); } + if (node.Head().GetTypeAnn()->GetKind() != ETypeAnnotationKind::Null) { + auto sourceTypeId = node.Head().GetTypeAnn()->Cast<TPgExprType>()->GetId(); + auto targetTypeId = node.GetTypeAnn()->Cast<TPgExprType>()->GetId(); + const auto& sourceTypeDesc = NPg::LookupType(sourceTypeId); + const auto& targetTypeDesc = NPg::LookupType(targetTypeId); + const bool isSourceArray = sourceTypeDesc.TypeId == sourceTypeDesc.ArrayTypeId; + const bool isTargetArray = targetTypeDesc.TypeId == targetTypeDesc.ArrayTypeId; + if (isSourceArray == isTargetArray && NPg::HasCast( + isSourceArray ? sourceTypeDesc.ElementTypeId : sourceTypeId, + isTargetArray ? targetTypeDesc.ElementTypeId : targetTypeId)) { + typeMod1 = typeMod; + } + } + auto cast = ctx.ProgramBuilder.PgCast(input, returnType, typeMod1); if (node.ChildrenSize() >= 3) { return ctx.ProgramBuilder.PgCast(cast, returnType, typeMod); |