diff options
author | galaxycrab <UgnineSirdis@ydb.tech> | 2023-11-16 14:17:24 +0300 |
---|---|---|
committer | galaxycrab <UgnineSirdis@ydb.tech> | 2023-11-16 15:28:47 +0300 |
commit | 8f96a1df4534e534a309273c897a91eb8f010343 (patch) | |
tree | f46210999a49696f9f1fb93b6b1cf9e347141e12 | |
parent | c4cee9c5360349a737210e2904e11dd8ffe48059 (diff) | |
download | ydb-8f96a1df4534e534a309273c897a91eb8f010343.tar.gz |
Support SQL rendering in Go service
23 files changed, 1409 insertions, 197 deletions
diff --git a/ydb/library/yql/providers/generic/connector/app/server/clickhouse/sql_formatter.go b/ydb/library/yql/providers/generic/connector/app/server/clickhouse/sql_formatter.go new file mode 100644 index 0000000000..4f87ce1c3a --- /dev/null +++ b/ydb/library/yql/providers/generic/connector/app/server/clickhouse/sql_formatter.go @@ -0,0 +1,103 @@ +package clickhouse + +import ( + "fmt" + "strings" + + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/app/server/utils" + api_service_protos "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/libgo/service/protos" +) + +type sqlFormatter struct { +} + +type predicateBuilderFeatures struct { +} + +func (f predicateBuilderFeatures) SupportsType(typeID Ydb.Type_PrimitiveTypeId) bool { + switch typeID { + case Ydb.Type_BOOL: + return true + case Ydb.Type_INT8: + return true + case Ydb.Type_UINT8: + return true + case Ydb.Type_INT16: + return true + case Ydb.Type_UINT16: + return true + case Ydb.Type_INT32: + return true + case Ydb.Type_UINT32: + return true + case Ydb.Type_INT64: + return true + case Ydb.Type_UINT64: + return true + case Ydb.Type_FLOAT: + return true + case Ydb.Type_DOUBLE: + return true + default: + return false + } +} + +func (f predicateBuilderFeatures) SupportsConstantValueExpression(t *Ydb.Type) bool { + switch v := t.Type.(type) { + case *Ydb.Type_TypeId: + return f.SupportsType(v.TypeId) + case *Ydb.Type_OptionalType: + return f.SupportsConstantValueExpression(v.OptionalType.Item) + default: + return false + } +} + +func (f predicateBuilderFeatures) SupportsExpression(expression *api_service_protos.TExpression) bool { + switch e := expression.Payload.(type) { + case *api_service_protos.TExpression_Column: + return true + case *api_service_protos.TExpression_TypedValue: + return f.SupportsConstantValueExpression(e.TypedValue.Type) + case *api_service_protos.TExpression_ArithmeticalExpression: + return false + case *api_service_protos.TExpression_Null: + return true + default: + return false + } +} + +func (formatter sqlFormatter) FormatRead(logger log.Logger, selectReq *api_service_protos.TSelect) (string, error) { + var sb strings.Builder + + selectPart, err := utils.FormatSelectColumns(selectReq.What, selectReq.GetFrom().GetTable(), true) + if err != nil { + return "", fmt.Errorf("failed to format select statement: %w", err) + } + + sb.WriteString(selectPart) + + if selectReq.Where != nil { + var features predicateBuilderFeatures + + clause, err := utils.FormatWhereClause(selectReq.Where, features) + if err != nil { + logger.Error("Failed to format WHERE clause", log.Error(err), log.String("where", selectReq.Where.String())) + } else { + sb.WriteString(" ") + sb.WriteString(clause) + } + } + + query := sb.String() + + return query, nil +} + +func NewSQLFormatter() utils.SQLFormatter { + return sqlFormatter{} +} diff --git a/ydb/library/yql/providers/generic/connector/app/server/clickhouse/sql_formatter_test.go b/ydb/library/yql/providers/generic/connector/app/server/clickhouse/sql_formatter_test.go new file mode 100644 index 0000000000..5b6f264d98 --- /dev/null +++ b/ydb/library/yql/providers/generic/connector/app/server/clickhouse/sql_formatter_test.go @@ -0,0 +1,343 @@ +package clickhouse + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + ydb "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" + "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/app/server/utils" + api "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/libgo/service/protos" +) + +func TestSQLFormatter(t *testing.T) { + type testCase struct { + testName string + selectReq *api.TSelect + output string + err error + } + + logger := utils.NewTestLogger(t) + formatter := NewSQLFormatter() + + tcs := []testCase{ + { + testName: "empty_table_name", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "", + }, + What: &api.TSelect_TWhat{}, + }, + output: "", + err: utils.ErrEmptyTableName, + }, + { + testName: "empty_no columns", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: &api.TSelect_TWhat{}, + }, + output: "SELECT 0 FROM tab", + err: nil, + }, + { + testName: "select_col", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: &api.TSelect_TWhat{ + Items: []*api.TSelect_TWhat_TItem{ + &api.TSelect_TWhat_TItem{ + Payload: &api.TSelect_TWhat_TItem_Column{ + Column: &ydb.Column{ + Name: "col", + Type: utils.NewPrimitiveType(ydb.Type_INT32), + }, + }, + }, + }, + }, + }, + output: "SELECT col FROM tab", + err: nil, + }, + { + testName: "is_null", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_IsNull{ + IsNull: &api.TPredicate_TIsNull{ + Value: utils.NewColumnExpression("col1"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE (col1 IS NULL)", + err: nil, + }, + { + testName: "is_not_null", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_IsNotNull{ + IsNotNull: &api.TPredicate_TIsNotNull{ + Value: utils.NewColumnExpression("col2"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE (col2 IS NOT NULL)", + err: nil, + }, + { + testName: "bool_column", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_BoolExpression{ + BoolExpression: &api.TPredicate_TBoolExpression{ + Value: utils.NewColumnExpression("col2"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE col2", + err: nil, + }, + { + testName: "complex_filter", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Disjunction{ + Disjunction: &api.TPredicate_TDisjunction{ + Operands: []*api.TPredicate{ + &api.TPredicate{ + Payload: &api.TPredicate_Negation{ + Negation: &api.TPredicate_TNegation{ + Operand: &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_LE, + LeftValue: utils.NewColumnExpression("col2"), + RightValue: utils.NewInt32ValueExpression(42), + }, + }, + }, + }, + }, + }, + &api.TPredicate{ + Payload: &api.TPredicate_Conjunction{ + Conjunction: &api.TPredicate_TConjunction{ + Operands: []*api.TPredicate{ + &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_NE, + LeftValue: utils.NewColumnExpression("col1"), + RightValue: utils.NewUint64ValueExpression(0), + }, + }, + }, + &api.TPredicate{ + Payload: &api.TPredicate_IsNull{ + IsNull: &api.TPredicate_TIsNull{ + Value: utils.NewColumnExpression("col3"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE ((NOT (col2 <= 42)) OR ((col1 <> 0) AND (col3 IS NULL)))", + err: nil, + }, + { + testName: "unsupported_predicate", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Between{ + Between: &api.TPredicate_TBetween{ + Value: utils.NewColumnExpression("col2"), + Least: utils.NewColumnExpression("col1"), + Greatest: utils.NewColumnExpression("col3"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab", + err: nil, + }, + { + testName: "unsupported_type", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col2"), + RightValue: utils.NewTextValueExpression("text"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab", + err: nil, + }, + { + testName: "partial_filter_removes_and", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Conjunction{ + Conjunction: &api.TPredicate_TConjunction{ + Operands: []*api.TPredicate{ + &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col1"), + RightValue: utils.NewInt32ValueExpression(32), + }, + }, + }, + &api.TPredicate{ + // Not supported + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col2"), + RightValue: utils.NewTextValueExpression("text"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE (col1 = 32)", + err: nil, + }, + { + testName: "partial_filter", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Conjunction{ + Conjunction: &api.TPredicate_TConjunction{ + Operands: []*api.TPredicate{ + &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col1"), + RightValue: utils.NewInt32ValueExpression(32), + }, + }, + }, + &api.TPredicate{ + // Not supported + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col2"), + RightValue: utils.NewTextValueExpression("text"), + }, + }, + }, + &api.TPredicate{ + Payload: &api.TPredicate_IsNull{ + IsNull: &api.TPredicate_TIsNull{ + Value: utils.NewColumnExpression("col3"), + }, + }, + }, + &api.TPredicate{ + Payload: &api.TPredicate_IsNotNull{ + IsNotNull: &api.TPredicate_TIsNotNull{ + Value: utils.NewColumnExpression("col4"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE ((col1 = 32) AND (col3 IS NULL) AND (col4 IS NOT NULL))", + err: nil, + }, + } + + for _, tc := range tcs { + tc := tc + + t.Run(tc.testName, func(t *testing.T) { + output, err := formatter.FormatRead(logger, tc.selectReq) + require.Equal(t, tc.output, output) + + if tc.err != nil { + require.True(t, errors.Is(err, tc.err)) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/ydb/library/yql/providers/generic/connector/app/server/clickhouse/ut/ya.make b/ydb/library/yql/providers/generic/connector/app/server/clickhouse/ut/ya.make new file mode 100644 index 0000000000..154f25de98 --- /dev/null +++ b/ydb/library/yql/providers/generic/connector/app/server/clickhouse/ut/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(ydb/library/yql/providers/generic/connector/app/server/clickhouse) + +SIZE(SMALL) + +END() diff --git a/ydb/library/yql/providers/generic/connector/app/server/clickhouse/ya.make b/ydb/library/yql/providers/generic/connector/app/server/clickhouse/ya.make index 6925db2cf7..a08bd57be7 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/clickhouse/ya.make +++ b/ydb/library/yql/providers/generic/connector/app/server/clickhouse/ya.make @@ -4,7 +4,16 @@ SRCS( connection_manager.go doc.go query_executor.go + sql_formatter.go type_mapper.go ) +GO_TEST_SRCS( + sql_formatter_test.go +) + END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/ydb/library/yql/providers/generic/connector/app/server/postgresql/sql_formatter.go b/ydb/library/yql/providers/generic/connector/app/server/postgresql/sql_formatter.go new file mode 100644 index 0000000000..fcaad0520b --- /dev/null +++ b/ydb/library/yql/providers/generic/connector/app/server/postgresql/sql_formatter.go @@ -0,0 +1,103 @@ +package postgresql + +import ( + "fmt" + "strings" + + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/app/server/utils" + api_service_protos "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/libgo/service/protos" +) + +type sqlFormatter struct { +} + +type predicateBuilderFeatures struct { +} + +func (f predicateBuilderFeatures) SupportsType(typeID Ydb.Type_PrimitiveTypeId) bool { + switch typeID { + case Ydb.Type_BOOL: + return true + case Ydb.Type_INT8: + return true + case Ydb.Type_UINT8: + return true + case Ydb.Type_INT16: + return true + case Ydb.Type_UINT16: + return true + case Ydb.Type_INT32: + return true + case Ydb.Type_UINT32: + return true + case Ydb.Type_INT64: + return true + case Ydb.Type_UINT64: + return true + case Ydb.Type_FLOAT: + return true + case Ydb.Type_DOUBLE: + return true + default: + return false + } +} + +func (f predicateBuilderFeatures) SupportsConstantValueExpression(t *Ydb.Type) bool { + switch v := t.Type.(type) { + case *Ydb.Type_TypeId: + return f.SupportsType(v.TypeId) + case *Ydb.Type_OptionalType: + return f.SupportsConstantValueExpression(v.OptionalType.Item) + default: + return false + } +} + +func (f predicateBuilderFeatures) SupportsExpression(expression *api_service_protos.TExpression) bool { + switch e := expression.Payload.(type) { + case *api_service_protos.TExpression_Column: + return true + case *api_service_protos.TExpression_TypedValue: + return f.SupportsConstantValueExpression(e.TypedValue.Type) + case *api_service_protos.TExpression_ArithmeticalExpression: + return false + case *api_service_protos.TExpression_Null: + return true + default: + return false + } +} + +func (formatter sqlFormatter) FormatRead(logger log.Logger, selectReq *api_service_protos.TSelect) (string, error) { + var sb strings.Builder + + selectPart, err := utils.FormatSelectColumns(selectReq.What, selectReq.GetFrom().GetTable(), true) + if err != nil { + return "", fmt.Errorf("failed to format select statement: %w", err) + } + + sb.WriteString(selectPart) + + if selectReq.Where != nil { + var features predicateBuilderFeatures + + clause, err := utils.FormatWhereClause(selectReq.Where, features) + if err != nil { + logger.Error("Failed to format WHERE clause", log.Error(err), log.String("where", selectReq.Where.String())) + } else { + sb.WriteString(" ") + sb.WriteString(clause) + } + } + + query := sb.String() + + return query, nil +} + +func NewSQLFormatter() utils.SQLFormatter { + return sqlFormatter{} +} diff --git a/ydb/library/yql/providers/generic/connector/app/server/postgresql/sql_formatter_test.go b/ydb/library/yql/providers/generic/connector/app/server/postgresql/sql_formatter_test.go new file mode 100644 index 0000000000..16cd042fc4 --- /dev/null +++ b/ydb/library/yql/providers/generic/connector/app/server/postgresql/sql_formatter_test.go @@ -0,0 +1,343 @@ +package postgresql + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + ydb "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" + "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/app/server/utils" + api "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/libgo/service/protos" +) + +func TestSQLFormatter(t *testing.T) { + type testCase struct { + testName string + selectReq *api.TSelect + output string + err error + } + + logger := utils.NewTestLogger(t) + formatter := NewSQLFormatter() + + tcs := []testCase{ + { + testName: "empty_table_name", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "", + }, + What: &api.TSelect_TWhat{}, + }, + output: "", + err: utils.ErrEmptyTableName, + }, + { + testName: "empty_no columns", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: &api.TSelect_TWhat{}, + }, + output: "SELECT 0 FROM tab", + err: nil, + }, + { + testName: "select_col", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: &api.TSelect_TWhat{ + Items: []*api.TSelect_TWhat_TItem{ + &api.TSelect_TWhat_TItem{ + Payload: &api.TSelect_TWhat_TItem_Column{ + Column: &ydb.Column{ + Name: "col", + Type: utils.NewPrimitiveType(ydb.Type_INT32), + }, + }, + }, + }, + }, + }, + output: "SELECT col FROM tab", + err: nil, + }, + { + testName: "is_null", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_IsNull{ + IsNull: &api.TPredicate_TIsNull{ + Value: utils.NewColumnExpression("col1"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE (col1 IS NULL)", + err: nil, + }, + { + testName: "is_not_null", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_IsNotNull{ + IsNotNull: &api.TPredicate_TIsNotNull{ + Value: utils.NewColumnExpression("col2"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE (col2 IS NOT NULL)", + err: nil, + }, + { + testName: "bool_column", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_BoolExpression{ + BoolExpression: &api.TPredicate_TBoolExpression{ + Value: utils.NewColumnExpression("col2"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE col2", + err: nil, + }, + { + testName: "complex_filter", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Disjunction{ + Disjunction: &api.TPredicate_TDisjunction{ + Operands: []*api.TPredicate{ + &api.TPredicate{ + Payload: &api.TPredicate_Negation{ + Negation: &api.TPredicate_TNegation{ + Operand: &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_LE, + LeftValue: utils.NewColumnExpression("col2"), + RightValue: utils.NewInt32ValueExpression(42), + }, + }, + }, + }, + }, + }, + &api.TPredicate{ + Payload: &api.TPredicate_Conjunction{ + Conjunction: &api.TPredicate_TConjunction{ + Operands: []*api.TPredicate{ + &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_NE, + LeftValue: utils.NewColumnExpression("col1"), + RightValue: utils.NewUint64ValueExpression(0), + }, + }, + }, + &api.TPredicate{ + Payload: &api.TPredicate_IsNull{ + IsNull: &api.TPredicate_TIsNull{ + Value: utils.NewColumnExpression("col3"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE ((NOT (col2 <= 42)) OR ((col1 <> 0) AND (col3 IS NULL)))", + err: nil, + }, + { + testName: "unsupported_predicate", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Between{ + Between: &api.TPredicate_TBetween{ + Value: utils.NewColumnExpression("col2"), + Least: utils.NewColumnExpression("col1"), + Greatest: utils.NewColumnExpression("col3"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab", + err: nil, + }, + { + testName: "unsupported_type", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col2"), + RightValue: utils.NewTextValueExpression("text"), + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab", + err: nil, + }, + { + testName: "partial_filter_removes_and", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Conjunction{ + Conjunction: &api.TPredicate_TConjunction{ + Operands: []*api.TPredicate{ + &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col1"), + RightValue: utils.NewInt32ValueExpression(32), + }, + }, + }, + &api.TPredicate{ + // Not supported + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col2"), + RightValue: utils.NewTextValueExpression("text"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE (col1 = 32)", + err: nil, + }, + { + testName: "partial_filter", + selectReq: &api.TSelect{ + From: &api.TSelect_TFrom{ + Table: "tab", + }, + What: utils.NewDefaultWhat(), + Where: &api.TSelect_TWhere{ + FilterTyped: &api.TPredicate{ + Payload: &api.TPredicate_Conjunction{ + Conjunction: &api.TPredicate_TConjunction{ + Operands: []*api.TPredicate{ + &api.TPredicate{ + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col1"), + RightValue: utils.NewInt32ValueExpression(32), + }, + }, + }, + &api.TPredicate{ + // Not supported + Payload: &api.TPredicate_Comparison{ + Comparison: &api.TPredicate_TComparison{ + Operation: api.TPredicate_TComparison_EQ, + LeftValue: utils.NewColumnExpression("col2"), + RightValue: utils.NewTextValueExpression("text"), + }, + }, + }, + &api.TPredicate{ + Payload: &api.TPredicate_IsNull{ + IsNull: &api.TPredicate_TIsNull{ + Value: utils.NewColumnExpression("col3"), + }, + }, + }, + &api.TPredicate{ + Payload: &api.TPredicate_IsNotNull{ + IsNotNull: &api.TPredicate_TIsNotNull{ + Value: utils.NewColumnExpression("col4"), + }, + }, + }, + }, + }, + }, + }, + }, + }, + output: "SELECT col0, col1 FROM tab WHERE ((col1 = 32) AND (col3 IS NULL) AND (col4 IS NOT NULL))", + err: nil, + }, + } + + for _, tc := range tcs { + tc := tc + + t.Run(tc.testName, func(t *testing.T) { + output, err := formatter.FormatRead(logger, tc.selectReq) + require.Equal(t, tc.output, output) + + if tc.err != nil { + require.True(t, errors.Is(err, tc.err)) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/ydb/library/yql/providers/generic/connector/app/server/postgresql/ut/ya.make b/ydb/library/yql/providers/generic/connector/app/server/postgresql/ut/ya.make new file mode 100644 index 0000000000..93069164e0 --- /dev/null +++ b/ydb/library/yql/providers/generic/connector/app/server/postgresql/ut/ya.make @@ -0,0 +1,5 @@ +GO_TEST_FOR(ydb/library/yql/providers/generic/connector/app/server/postgresql) + +SIZE(SMALL) + +END() diff --git a/ydb/library/yql/providers/generic/connector/app/server/postgresql/ya.make b/ydb/library/yql/providers/generic/connector/app/server/postgresql/ya.make index 6925db2cf7..a08bd57be7 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/postgresql/ya.make +++ b/ydb/library/yql/providers/generic/connector/app/server/postgresql/ya.make @@ -4,7 +4,16 @@ SRCS( connection_manager.go doc.go query_executor.go + sql_formatter.go type_mapper.go ) +GO_TEST_SRCS( + sql_formatter_test.go +) + END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/ydb/library/yql/providers/generic/connector/app/server/rdbms/handler.go b/ydb/library/yql/providers/generic/connector/app/server/rdbms/handler.go index 66c8268591..914a418133 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/rdbms/handler.go +++ b/ydb/library/yql/providers/generic/connector/app/server/rdbms/handler.go @@ -3,7 +3,6 @@ package rdbms import ( "context" "fmt" - "strings" "github.com/ydb-platform/ydb/library/go/core/log" "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/app/server/paging" @@ -13,6 +12,7 @@ import ( type handlerImpl struct { typeMapper utils.TypeMapper + sqlFormatter utils.SQLFormatter queryBuilder utils.QueryExecutor connectionManager utils.ConnectionManager logger log.Logger @@ -66,65 +66,13 @@ func (h *handlerImpl) DescribeTable( return &api_service_protos.TDescribeTableResponse{Schema: schema}, nil } -func (h *handlerImpl) makeReadSplitQuery( - logger log.Logger, - split *api_service_protos.TSplit, -) (string, error) { - // SELECT $columns - // interpolate request - var sb strings.Builder - - sb.WriteString("SELECT ") - - columns, err := utils.SelectWhatToYDBColumns(split.Select.What) - if err != nil { - return "", fmt.Errorf("convert Select.What.Items to Ydb.Columns: %w", err) - } - - // for the case of empty column set select some constant for constructing a valid sql statement - if len(columns) == 0 { - sb.WriteString("0") - } else { - for i, column := range columns { - sb.WriteString(column.GetName()) - - if i != len(columns)-1 { - sb.WriteString(", ") - } - } - } - - // SELECT $columns FROM $from - tableName := split.GetSelect().GetFrom().GetTable() - if tableName == "" { - return "", fmt.Errorf("empty table name") - } - - sb.WriteString(" FROM ") - sb.WriteString(tableName) - - if split.Select.Where != nil { - clause, err := FormatWhereClause(split.Select.Where) - if err != nil { - logger.Error("Failed to format WHERE clause", log.Error(err), log.String("where", split.Select.Where.String())) - } else { - sb.WriteString(" ") - sb.WriteString(clause) - } - } - - // execute query - - return sb.String(), nil -} - func (h *handlerImpl) doReadSplit( ctx context.Context, logger log.Logger, split *api_service_protos.TSplit, sink paging.Sink, ) error { - query, err := h.makeReadSplitQuery(logger, split) + query, err := h.sqlFormatter.FormatRead(logger, split.Select) if err != nil { return fmt.Errorf("make read split query: %w", err) } @@ -187,6 +135,7 @@ func newHandler( ) Handler { return &handlerImpl{ logger: logger, + sqlFormatter: preset.sqlFormatter, queryBuilder: preset.queryExecutor, connectionManager: preset.connectionManager, typeMapper: preset.typeMapper, diff --git a/ydb/library/yql/providers/generic/connector/app/server/rdbms/handler_factory.go b/ydb/library/yql/providers/generic/connector/app/server/rdbms/handler_factory.go index 2b4aa105ba..33560cebea 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/rdbms/handler_factory.go +++ b/ydb/library/yql/providers/generic/connector/app/server/rdbms/handler_factory.go @@ -11,6 +11,7 @@ import ( ) type handlerPreset struct { + sqlFormatter utils.SQLFormatter queryExecutor utils.QueryExecutor connectionManager utils.ConnectionManager typeMapper utils.TypeMapper @@ -42,11 +43,13 @@ func NewHandlerFactory(qlf utils.QueryLoggerFactory) HandlerFactory { return &handlerFactoryImpl{ clickhouse: handlerPreset{ + sqlFormatter: clickhouse.NewSQLFormatter(), queryExecutor: clickhouse.NewQueryExecutor(), connectionManager: clickhouse.NewConnectionManager(connManagerCfg), typeMapper: clickhouse.NewTypeMapper(), }, postgresql: handlerPreset{ + sqlFormatter: postgresql.NewSQLFormatter(), queryExecutor: postgresql.NewQueryExecutor(), connectionManager: postgresql.NewConnectionManager(connManagerCfg), typeMapper: postgresql.NewTypeMapper(), diff --git a/ydb/library/yql/providers/generic/connector/app/server/rdbms/mock.go b/ydb/library/yql/providers/generic/connector/app/server/rdbms/mock.go index 28dbe58fe6..dbe05948dd 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/rdbms/mock.go +++ b/ydb/library/yql/providers/generic/connector/app/server/rdbms/mock.go @@ -41,6 +41,7 @@ func (m *HandlerMock) TypeMapper() utils.TypeMapper { var _ HandlerFactory = (*HandlerFactoryMock)(nil) type HandlerFactoryMock struct { + SQLFormatter utils.SQLFormatter QueryExecutor utils.QueryExecutor ConnectionManager utils.ConnectionManager TypeMapper utils.TypeMapper @@ -50,6 +51,7 @@ func (m *HandlerFactoryMock) Make(logger log.Logger, dataSourceType api_common.E handler := newHandler( logger, &handlerPreset{ + sqlFormatter: m.SQLFormatter, queryExecutor: m.QueryExecutor, connectionManager: m.ConnectionManager, typeMapper: m.TypeMapper, diff --git a/ydb/library/yql/providers/generic/connector/app/server/rdbms/predicate_builder.go b/ydb/library/yql/providers/generic/connector/app/server/rdbms/predicate_builder.go deleted file mode 100644 index 383873ba85..0000000000 --- a/ydb/library/yql/providers/generic/connector/app/server/rdbms/predicate_builder.go +++ /dev/null @@ -1,86 +0,0 @@ -package rdbms - -import ( - "fmt" - - "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" - "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/app/server/utils" - api_service_protos "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/libgo/service/protos" -) - -func FormatValue(value *Ydb.TypedValue) (string, error) { - switch v := value.Value.Value.(type) { - case *Ydb.Value_BoolValue: - return fmt.Sprintf("%t", v.BoolValue), nil - case *Ydb.Value_Int32Value: - return fmt.Sprintf("%d", v.Int32Value), nil - case *Ydb.Value_Uint32Value: - return fmt.Sprintf("%d", v.Uint32Value), nil - default: - return "", fmt.Errorf("%w, type: %T", utils.ErrUnimplementedTypedValue, v) - } -} - -func FormatExpression(expression *api_service_protos.TExpression) (string, error) { - switch e := expression.Payload.(type) { - case *api_service_protos.TExpression_Column: - return e.Column, nil - case *api_service_protos.TExpression_TypedValue: - return FormatValue(e.TypedValue) - default: - return "", fmt.Errorf("%w, type: %T", utils.ErrUnimplementedExpression, e) - } -} - -func FormatComparison(comparison *api_service_protos.TPredicate_TComparison) (string, error) { - var operation string - - switch op := comparison.Operation; op { - case api_service_protos.TPredicate_TComparison_EQ: - operation = " = " - default: - return "", fmt.Errorf("%w, op: %d", utils.ErrUnimplementedOperation, op) - } - - var ( - left string - right string - err error - ) - - left, err = FormatExpression(comparison.LeftValue) - if err != nil { - return "", fmt.Errorf("failed to format left argument: %w", err) - } - - right, err = FormatExpression(comparison.RightValue) - if err != nil { - return "", fmt.Errorf("failed to format right argument: %w", err) - } - - return fmt.Sprintf("(%s%s%s)", left, operation, right), nil -} - -func FormatPredicate(predicate *api_service_protos.TPredicate) (string, error) { - switch p := predicate.Payload.(type) { - case *api_service_protos.TPredicate_Comparison: - return FormatComparison(p.Comparison) - default: - return "", fmt.Errorf("%w, type: %T", utils.ErrUnimplementedPredicateType, p) - } -} - -func FormatWhereClause(where *api_service_protos.TSelect_TWhere) (string, error) { - if where.FilterTyped == nil { - return "", utils.ErrUnimplemented - } - - formatted, err := FormatPredicate(where.FilterTyped) - if err != nil { - return "", err - } - - result := "WHERE " + formatted - - return result, nil -} diff --git a/ydb/library/yql/providers/generic/connector/app/server/rdbms/ya.make b/ydb/library/yql/providers/generic/connector/app/server/rdbms/ya.make index b0a2a078f2..1f8ee91ca6 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/rdbms/ya.make +++ b/ydb/library/yql/providers/generic/connector/app/server/rdbms/ya.make @@ -6,7 +6,6 @@ SRCS( handler_factory.go interface.go mock.go - predicate_builder.go schema_builder.go ) diff --git a/ydb/library/yql/providers/generic/connector/app/server/streaming/streamer_test.go b/ydb/library/yql/providers/generic/connector/app/server/streaming/streamer_test.go index 197abfc6f1..cb50b4b220 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/streaming/streamer_test.go +++ b/ydb/library/yql/providers/generic/connector/app/server/streaming/streamer_test.go @@ -198,6 +198,7 @@ func (tc testCaseStreaming) execute(t *testing.T) { typeMapper := clickhouse.NewTypeMapper() handlerFactory := &rdbms.HandlerFactoryMock{ + SQLFormatter: clickhouse.NewSQLFormatter(), ConnectionManager: connectionManager, TypeMapper: typeMapper, } diff --git a/ydb/library/yql/providers/generic/connector/app/server/utils/errors.go b/ydb/library/yql/providers/generic/connector/app/server/utils/errors.go index a7517a193e..ba6f27e34c 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/utils/errors.go +++ b/ydb/library/yql/providers/generic/connector/app/server/utils/errors.go @@ -10,18 +10,21 @@ import ( ) var ( - ErrTableDoesNotExist = fmt.Errorf("table does not exist") - ErrDataSourceNotSupported = fmt.Errorf("data source not supported") - ErrDataTypeNotSupported = fmt.Errorf("data type not supported") - ErrReadLimitExceeded = fmt.Errorf("read limit exceeded") - ErrInvalidRequest = fmt.Errorf("invalid request") - ErrValueOutOfTypeBounds = fmt.Errorf("value is out of possible range of values for the type") - ErrUnimplemented = fmt.Errorf("unimplemented") - ErrUnimplementedTypedValue = fmt.Errorf("unimplemented typed value") - ErrUnimplementedExpression = fmt.Errorf("unimplemented expression") - ErrUnimplementedOperation = fmt.Errorf("unimplemented operation") - ErrUnimplementedPredicateType = fmt.Errorf("unimplemented predicate type") - ErrInvariantViolation = fmt.Errorf("implementation error (invariant violation)") + ErrTableDoesNotExist = fmt.Errorf("table does not exist") + ErrDataSourceNotSupported = fmt.Errorf("data source not supported") + ErrDataTypeNotSupported = fmt.Errorf("data type not supported") + ErrReadLimitExceeded = fmt.Errorf("read limit exceeded") + ErrInvalidRequest = fmt.Errorf("invalid request") + ErrValueOutOfTypeBounds = fmt.Errorf("value is out of possible range of values for the type") + ErrUnimplemented = fmt.Errorf("unimplemented") + ErrUnimplementedTypedValue = fmt.Errorf("unimplemented typed value") + ErrUnimplementedExpression = fmt.Errorf("unimplemented expression") + ErrUnsupportedExpression = fmt.Errorf("expression is not supported") + ErrUnimplementedOperation = fmt.Errorf("unimplemented operation") + ErrUnimplementedPredicateType = fmt.Errorf("unimplemented predicate type") + ErrInvariantViolation = fmt.Errorf("implementation error (invariant violation)") + ErrUnimplementedArithmeticalExpression = fmt.Errorf("unimplemented arithmetical expression") + ErrEmptyTableName = fmt.Errorf("empty table name") ) func NewSuccess() *api_service_protos.TError { @@ -66,6 +69,10 @@ func NewAPIErrorFromStdError(err error) *api_service_protos.TError { status = Ydb.StatusIds_UNSUPPORTED case errors.Is(err, ErrUnimplemented): status = Ydb.StatusIds_UNSUPPORTED + case errors.Is(err, ErrUnimplementedArithmeticalExpression): + status = Ydb.StatusIds_UNSUPPORTED + case errors.Is(err, ErrEmptyTableName): + status = Ydb.StatusIds_BAD_REQUEST default: status = Ydb.StatusIds_INTERNAL_ERROR } diff --git a/ydb/library/yql/providers/generic/connector/app/server/utils/predicate_builder.go b/ydb/library/yql/providers/generic/connector/app/server/utils/predicate_builder.go new file mode 100644 index 0000000000..b763a9a265 --- /dev/null +++ b/ydb/library/yql/providers/generic/connector/app/server/utils/predicate_builder.go @@ -0,0 +1,291 @@ +package utils + +import ( + "fmt" + "strings" + + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" + api_service_protos "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/libgo/service/protos" +) + +type PredicateBuilderFeatures interface { + // Support for high level expression (without subexpressions, they are checked separately) + SupportsExpression(expression *api_service_protos.TExpression) bool +} + +func FormatValue(value *Ydb.TypedValue) (string, error) { + switch v := value.Value.Value.(type) { + case *Ydb.Value_BoolValue: + return fmt.Sprintf("%t", v.BoolValue), nil + case *Ydb.Value_Int32Value: + return fmt.Sprintf("%d", v.Int32Value), nil + case *Ydb.Value_Uint32Value: + return fmt.Sprintf("%d", v.Uint32Value), nil + case *Ydb.Value_Int64Value: + return fmt.Sprintf("%d", v.Int64Value), nil + case *Ydb.Value_Uint64Value: + return fmt.Sprintf("%d", v.Uint64Value), nil + case *Ydb.Value_FloatValue: + return fmt.Sprintf("%e", v.FloatValue), nil // TODO: support query parameters + case *Ydb.Value_DoubleValue: + return fmt.Sprintf("%e", v.DoubleValue), nil // TODO: support query parameters + default: + return "", fmt.Errorf("%w, type: %T", ErrUnimplementedTypedValue, v) + } +} + +func FormatColumn(col string) (string, error) { + return col, nil +} + +func FormatNull(n *api_service_protos.TExpression_TNull) (string, error) { + return "NULL", nil +} + +func FormatArithmeticalExpression(expression *api_service_protos.TExpression_TArithmeticalExpression, features PredicateBuilderFeatures) (string, error) { + var operation string + + switch op := expression.Operation; op { + case api_service_protos.TExpression_TArithmeticalExpression_MUL: + operation = " * " + case api_service_protos.TExpression_TArithmeticalExpression_ADD: + operation = " + " + case api_service_protos.TExpression_TArithmeticalExpression_SUB: + operation = " - " + case api_service_protos.TExpression_TArithmeticalExpression_BIT_AND: + operation = " & " + case api_service_protos.TExpression_TArithmeticalExpression_BIT_OR: + operation = " | " + case api_service_protos.TExpression_TArithmeticalExpression_BIT_XOR: + operation = " ^ " + default: + return "", fmt.Errorf("%w, op: %d", ErrUnimplementedArithmeticalExpression, op) + } + + var ( + left string + right string + err error + ) + + left, err = FormatExpression(expression.LeftValue, features) + if err != nil { + return "", fmt.Errorf("failed to format left argument: %w", err) + } + + right, err = FormatExpression(expression.RightValue, features) + if err != nil { + return "", fmt.Errorf("failed to format right argument: %w", err) + } + + return fmt.Sprintf("(%s%s%s)", left, operation, right), nil +} + +func FormatExpression(expression *api_service_protos.TExpression, features PredicateBuilderFeatures) (string, error) { + if !features.SupportsExpression(expression) { + return "", ErrUnsupportedExpression + } + + switch e := expression.Payload.(type) { + case *api_service_protos.TExpression_Column: + return FormatColumn(e.Column) + case *api_service_protos.TExpression_TypedValue: + return FormatValue(e.TypedValue) + case *api_service_protos.TExpression_ArithmeticalExpression: + return FormatArithmeticalExpression(e.ArithmeticalExpression, features) + case *api_service_protos.TExpression_Null: + return FormatNull(e.Null) + default: + return "", fmt.Errorf("%w, type: %T", ErrUnimplementedExpression, e) + } +} + +func FormatComparison(comparison *api_service_protos.TPredicate_TComparison, features PredicateBuilderFeatures) (string, error) { + var operation string + + switch op := comparison.Operation; op { + case api_service_protos.TPredicate_TComparison_L: + operation = " < " + case api_service_protos.TPredicate_TComparison_LE: + operation = " <= " + case api_service_protos.TPredicate_TComparison_EQ: + operation = " = " + case api_service_protos.TPredicate_TComparison_NE: + operation = " <> " + case api_service_protos.TPredicate_TComparison_GE: + operation = " >= " + case api_service_protos.TPredicate_TComparison_G: + operation = " > " + default: + return "", fmt.Errorf("%w, op: %d", ErrUnimplementedOperation, op) + } + + var ( + left string + right string + err error + ) + + left, err = FormatExpression(comparison.LeftValue, features) + if err != nil { + return "", fmt.Errorf("failed to format left argument: %w", err) + } + + right, err = FormatExpression(comparison.RightValue, features) + if err != nil { + return "", fmt.Errorf("failed to format right argument: %w", err) + } + + return fmt.Sprintf("(%s%s%s)", left, operation, right), nil +} + +func FormatNegation(negation *api_service_protos.TPredicate_TNegation, features PredicateBuilderFeatures) (string, error) { + pred, err := FormatPredicate(negation.Operand, features, false) + if err != nil { + return "", fmt.Errorf("failed to format NOT statement: %w", err) + } + + return fmt.Sprintf("(NOT %s)", pred), nil +} + +func FormatConjunction(conjunction *api_service_protos.TPredicate_TConjunction, features PredicateBuilderFeatures, topLevel bool) (string, error) { + var ( + sb strings.Builder + succeeded int32 = 0 + statement string + err error + first string + ) + + for _, predicate := range conjunction.Operands { + statement, err = FormatPredicate(predicate, features, false) + if err != nil { + if !topLevel { + return "", fmt.Errorf("failed to format AND statement: %w", err) + } + } else { + if succeeded > 0 { + if succeeded == 1 { + sb.WriteString("(") + sb.WriteString(first) + } + + sb.WriteString(" AND ") + sb.WriteString(statement) + } else { + first = statement + } + + succeeded++ + } + } + + if succeeded == 0 { + return "", fmt.Errorf("failed to format AND statement: %w", err) + } + + if succeeded == 1 { + sb.WriteString(first) + } else { + sb.WriteString(")") + } + + return sb.String(), nil +} + +func FormatDisjunction(disjunction *api_service_protos.TPredicate_TDisjunction, features PredicateBuilderFeatures) (string, error) { + var ( + sb strings.Builder + cnt int32 = 0 + statement string + err error + first string + ) + + for _, predicate := range disjunction.Operands { + statement, err = FormatPredicate(predicate, features, false) + if err != nil { + return "", fmt.Errorf("failed to format OR statement: %w", err) + } else { + if cnt > 0 { + if cnt == 1 { + sb.WriteString("(") + sb.WriteString(first) + } + + sb.WriteString(" OR ") + sb.WriteString(statement) + } else { + first = statement + } + + cnt++ + } + } + + if cnt == 0 { + return "", fmt.Errorf("failed to format OR statement: no operands") + } + + if cnt == 1 { + sb.WriteString(first) + } else { + sb.WriteString(")") + } + + return sb.String(), nil +} + +func FormatIsNull(isNull *api_service_protos.TPredicate_TIsNull, features PredicateBuilderFeatures) (string, error) { + statement, err := FormatExpression(isNull.Value, features) + if err != nil { + return "", fmt.Errorf("failed to format IS NULL statement: %w", err) + } + + return fmt.Sprintf("(%s IS NULL)", statement), nil +} + +func FormatIsNotNull(isNotNull *api_service_protos.TPredicate_TIsNotNull, features PredicateBuilderFeatures) (string, error) { + statement, err := FormatExpression(isNotNull.Value, features) + if err != nil { + return "", fmt.Errorf("failed to format IS NOT NULL statement: %w", err) + } + + return fmt.Sprintf("(%s IS NOT NULL)", statement), nil +} + +func FormatPredicate(predicate *api_service_protos.TPredicate, features PredicateBuilderFeatures, topLevel bool) (string, error) { + switch p := predicate.Payload.(type) { + case *api_service_protos.TPredicate_Negation: + return FormatNegation(p.Negation, features) + case *api_service_protos.TPredicate_Conjunction: + return FormatConjunction(p.Conjunction, features, topLevel) + case *api_service_protos.TPredicate_Disjunction: + return FormatDisjunction(p.Disjunction, features) + case *api_service_protos.TPredicate_IsNull: + return FormatIsNull(p.IsNull, features) + case *api_service_protos.TPredicate_IsNotNull: + return FormatIsNotNull(p.IsNotNull, features) + case *api_service_protos.TPredicate_Comparison: + return FormatComparison(p.Comparison, features) + case *api_service_protos.TPredicate_BoolExpression: + return FormatExpression(p.BoolExpression.Value, features) + default: + return "", fmt.Errorf("%w, type: %T", ErrUnimplementedPredicateType, p) + } +} + +func FormatWhereClause(where *api_service_protos.TSelect_TWhere, features PredicateBuilderFeatures) (string, error) { + if where.FilterTyped == nil { + return "", ErrUnimplemented + } + + formatted, err := FormatPredicate(where.FilterTyped, features, true) + if err != nil { + return "", err + } + + result := "WHERE " + formatted + + return result, nil +} diff --git a/ydb/library/yql/providers/generic/connector/app/server/utils/select_helpers.go b/ydb/library/yql/providers/generic/connector/app/server/utils/select_helpers.go index 520bfcbbee..ebf5e7d0c1 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/utils/select_helpers.go +++ b/ydb/library/yql/providers/generic/connector/app/server/utils/select_helpers.go @@ -2,6 +2,7 @@ package utils import ( "fmt" + "strings" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" api_service_protos "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/libgo/service/protos" @@ -36,3 +37,41 @@ func SelectWhatToYDBColumns(selectWhat *api_service_protos.TSelect_TWhat) ([]*Yd return columns, nil } + +func FormatSelectColumns(selectWhat *api_service_protos.TSelect_TWhat, tableName string, fakeZeroOnEmptyColumnsSet bool) (string, error) { + // SELECT $columns FROM $from + if tableName == "" { + return "", ErrEmptyTableName + } + + var sb strings.Builder + + sb.WriteString("SELECT ") + + columns, err := SelectWhatToYDBColumns(selectWhat) + if err != nil { + return "", fmt.Errorf("convert Select.What.Items to Ydb.Columns: %w", err) + } + + // for the case of empty column set select some constant for constructing a valid sql statement + if len(columns) == 0 { + if fakeZeroOnEmptyColumnsSet { + sb.WriteString("0") + } else { + return "", fmt.Errorf("empty columns set") + } + } else { + for i, column := range columns { + sb.WriteString(column.GetName()) + + if i != len(columns)-1 { + sb.WriteString(", ") + } + } + } + + sb.WriteString(" FROM ") + sb.WriteString(tableName) + + return sb.String(), nil +} diff --git a/ydb/library/yql/providers/generic/connector/app/server/utils/sql.go b/ydb/library/yql/providers/generic/connector/app/server/utils/sql.go index 578602a650..51d7c2da76 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/utils/sql.go +++ b/ydb/library/yql/providers/generic/connector/app/server/utils/sql.go @@ -33,3 +33,7 @@ type ConnectionManagerBase struct { type QueryExecutor interface { DescribeTable(ctx context.Context, conn Connection, request *api_service_protos.TDescribeTableRequest) (Rows, error) } + +type SQLFormatter interface { + FormatRead(logger log.Logger, selectReq *api_service_protos.TSelect) (string, error) +} diff --git a/ydb/library/yql/providers/generic/connector/app/server/utils/unit_test_helpers.go b/ydb/library/yql/providers/generic/connector/app/server/utils/unit_test_helpers.go index ef47473088..f8b3fe408e 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/utils/unit_test_helpers.go +++ b/ydb/library/yql/providers/generic/connector/app/server/utils/unit_test_helpers.go @@ -6,30 +6,96 @@ import ( api_service_protos "github.com/ydb-platform/ydb/ydb/library/yql/providers/generic/connector/libgo/service/protos" ) -func MakeTestSplit() *api_service_protos.TSplit { - return &api_service_protos.TSplit{ - Select: &api_service_protos.TSelect{ - DataSourceInstance: &api_common.TDataSourceInstance{}, - What: &api_service_protos.TSelect_TWhat{ - Items: []*api_service_protos.TSelect_TWhat_TItem{ - { - Payload: &api_service_protos.TSelect_TWhat_TItem_Column{ - Column: &Ydb.Column{ - Name: "col0", - Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_INT32}}, - }, - }, +func NewPrimitiveType(t Ydb.Type_PrimitiveTypeId) *Ydb.Type { + return &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: t, + }, + } +} + +// NewDefaultWhat generates default What field with a pair of columns +func NewDefaultWhat() *api_service_protos.TSelect_TWhat { + return &api_service_protos.TSelect_TWhat{ + Items: []*api_service_protos.TSelect_TWhat_TItem{ + &api_service_protos.TSelect_TWhat_TItem{ + Payload: &api_service_protos.TSelect_TWhat_TItem_Column{ + Column: &Ydb.Column{ + Name: "col0", + Type: NewPrimitiveType(Ydb.Type_INT32), + }, + }, + }, + &api_service_protos.TSelect_TWhat_TItem{ + Payload: &api_service_protos.TSelect_TWhat_TItem_Column{ + Column: &Ydb.Column{ + Name: "col1", + Type: NewPrimitiveType(Ydb.Type_STRING), }, - { - Payload: &api_service_protos.TSelect_TWhat_TItem_Column{ - Column: &Ydb.Column{ - Name: "col1", - Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_STRING}}, - }, - }, + }, + }, + }, + } +} + +func NewColumnExpression(name string) *api_service_protos.TExpression { + return &api_service_protos.TExpression{ + Payload: &api_service_protos.TExpression_Column{ + Column: name, + }, + } +} + +func NewInt32ValueExpression(val int32) *api_service_protos.TExpression { + return &api_service_protos.TExpression{ + Payload: &api_service_protos.TExpression_TypedValue{ + TypedValue: &Ydb.TypedValue{ + Type: NewPrimitiveType(Ydb.Type_INT32), + Value: &Ydb.Value{ + Value: &Ydb.Value_Int32Value{ + Int32Value: val, }, }, }, + }, + } +} + +func NewUint64ValueExpression(val uint64) *api_service_protos.TExpression { + return &api_service_protos.TExpression{ + Payload: &api_service_protos.TExpression_TypedValue{ + TypedValue: &Ydb.TypedValue{ + Type: NewPrimitiveType(Ydb.Type_UINT64), + Value: &Ydb.Value{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: val, + }, + }, + }, + }, + } +} + +func NewTextValueExpression(val string) *api_service_protos.TExpression { + return &api_service_protos.TExpression{ + Payload: &api_service_protos.TExpression_TypedValue{ + TypedValue: &Ydb.TypedValue{ + Type: NewPrimitiveType(Ydb.Type_UTF8), + Value: &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: val, + }, + }, + }, + }, + } +} + +func MakeTestSplit() *api_service_protos.TSplit { + return &api_service_protos.TSplit{ + Select: &api_service_protos.TSelect{ + DataSourceInstance: &api_common.TDataSourceInstance{}, + What: NewDefaultWhat(), From: &api_service_protos.TSelect_TFrom{ Table: "example_1", }, diff --git a/ydb/library/yql/providers/generic/connector/app/server/utils/ya.make b/ydb/library/yql/providers/generic/connector/app/server/utils/ya.make index ccf51cffe0..f07e74b159 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/utils/ya.make +++ b/ydb/library/yql/providers/generic/connector/app/server/utils/ya.make @@ -7,6 +7,7 @@ SRCS( endpoint.go errors.go logger.go + predicate_builder.go protobuf.go select_helpers.go sql.go diff --git a/ydb/library/yql/providers/generic/connector/app/server/ya.make b/ydb/library/yql/providers/generic/connector/app/server/ya.make index 063cb98f35..9fcde7a038 100644 --- a/ydb/library/yql/providers/generic/connector/app/server/ya.make +++ b/ydb/library/yql/providers/generic/connector/app/server/ya.make @@ -2,7 +2,7 @@ GO_LIBRARY() SRCS( cmd.go - config.go + config.go doc.go grpc_metrics.go httppuller.go diff --git a/ydb/library/yql/providers/generic/connector/tests/test_cases/select_pushdown.py b/ydb/library/yql/providers/generic/connector/tests/test_cases/select_pushdown.py index 289d165135..c5fad9f6aa 100644 --- a/ydb/library/yql/providers/generic/connector/tests/test_cases/select_pushdown.py +++ b/ydb/library/yql/providers/generic/connector/tests/test_cases/select_pushdown.py @@ -83,46 +83,62 @@ class Factory: data_source_type=DataSourceType(pg=postgresql.Int4()), ), Column( + name='col_int64', + ydb_type=Type.INT64, + data_source_type=DataSourceType(pg=postgresql.Int8()), + ), + Column( name='col_string', ydb_type=Type.UTF8, data_source_type=DataSourceType(pg=postgresql.Text()), ), + Column( + name='col_float', + ydb_type=Type.FLOAT, + data_source_type=DataSourceType(pg=postgresql.Float4()), + ), ), ) data_in = [ - [ - 1, - 'one', - ], - [ - 2, - 'two', - ], - [ - 3, - 'three', - ], + [1, 2, 'one', 1.1], + [2, 2, 'two', 1.23456789], + [3, 5, 'three', 0.00000012], ] - data_out = [ + data_out_1 = [ ['one'], ] + data_out_2 = [ + ['two'], + ] + data_source_kind = EDataSourceKind.POSTGRESQL return [ TestCase( name=f'{self._name}_{data_source_kind}', data_in=data_in, - data_out_=data_out, + data_out_=data_out_1, pragmas=dict({'generic.UsePredicatePushdown': 'true'}), select_what=SelectWhat(SelectWhat.Item(name='col_string')), select_where=SelectWhere('col_int32 = 1'), data_source_kind=data_source_kind, schema=schema, database=Database.make_for_data_source_kind(data_source_kind), - ) + ), + TestCase( + name=f'{self._name}_{data_source_kind}', + data_in=data_in, + data_out_=data_out_2, + pragmas=dict({'generic.UsePredicatePushdown': 'true'}), + select_what=SelectWhat(SelectWhat.Item(name='col_string')), + select_where=SelectWhere('col_int32 = col_int64'), + data_source_kind=data_source_kind, + schema=schema, + database=Database.make_for_data_source_kind(data_source_kind), + ), ] def make_test_cases(self) -> Sequence[TestCase]: diff --git a/ydb/library/yql/providers/generic/connector/tests/utils/dqrun.py b/ydb/library/yql/providers/generic/connector/tests/utils/dqrun.py index 5f88098172..c99d50869b 100644 --- a/ydb/library/yql/providers/generic/connector/tests/utils/dqrun.py +++ b/ydb/library/yql/providers/generic/connector/tests/utils/dqrun.py @@ -244,14 +244,14 @@ class Runner: for line in out.stderr.decode('utf-8').splitlines(): LOGGER.error(line) - unique_suffix = test_dir.name - err_file = yatest.common.output_path(f'dqrun-{unique_suffix}.err') - with open(err_file, "w") as f: - f.write(out.stderr.decode('utf-8')) - - out_file = yatest.common.output_path(f'dqrun-{unique_suffix}.out') - with open(out_file, "w") as f: - f.write(out.stdout.decode('utf-8')) + unique_suffix = test_dir.name + err_file = yatest.common.output_path(f'dqrun-{unique_suffix}.err') + with open(err_file, "w") as f: + f.write(out.stderr.decode('utf-8')) + + out_file = yatest.common.output_path(f'dqrun-{unique_suffix}.out') + with open(out_file, "w") as f: + f.write(out.stdout.decode('utf-8')) return Result( data_out=data_out, |