aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/jackc/pgx/v5/named_args_test.go
blob: bd54faa1bfd41fa87cc77ff53ede8f2ef78faf17 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package pgx_test

import (
	"context"
	"testing"

	"github.com/jackc/pgx/v5"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNamedArgsRewriteQuery is a table-driven test of pgx.NamedArgs.RewriteQuery.
// Each case supplies a SQL string containing (or deliberately not containing)
// @name placeholders together with a NamedArgs map, and checks that the
// rewritten SQL uses positional $N parameters and that the returned argument
// slice holds the referenced values in positional order.
func TestNamedArgsRewriteQuery(t *testing.T) {
	t.Parallel()

	for i, tt := range []struct {
		sql          string
		args         []any
		namedArgs    pgx.NamedArgs
		expectedSQL  string
		expectedArgs []any
	}{
		// A single named argument becomes $1.
		{
			sql:          "select * from users where id = @id",
			namedArgs:    pgx.NamedArgs{"id": int32(42)},
			expectedSQL:  "select * from users where id = $1",
			expectedArgs: []any{int32(42)},
		},
		// A name used more than once maps to the same positional parameter.
		{
			sql:          "select * from t where foo < @abc and baz = @def and bar < @abc",
			namedArgs:    pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
			expectedSQL:  "select * from t where foo < $1 and baz = $2 and bar < $1",
			expectedArgs: []any{int32(42), int32(1)},
		},
		// A "::type" cast directly after the name is left in place.
		{
			sql:          "select @a::int, @b::text",
			namedArgs:    pgx.NamedArgs{"a": int32(42), "b": "foo"},
			expectedSQL:  "select $1::int, $2::text",
			expectedArgs: []any{int32(42), "foo"},
		},
		// Names may contain upper-case letters, digits and underscores.
		{
			sql:          "select @Abc::int, @b_4::text",
			namedArgs:    pgx.NamedArgs{"Abc": int32(42), "b_4": "foo"},
			expectedSQL:  "select $1::int, $2::text",
			expectedArgs: []any{int32(42), "foo"},
		},
		// A trailing '@' with nothing after it is not a placeholder.
		{
			sql:          "at end @",
			namedArgs:    pgx.NamedArgs{"a": int32(42), "b": "foo"},
			expectedSQL:  "at end @",
			expectedArgs: []any{},
		},
		// An '@' followed by whitespace is not a placeholder.
		{
			sql:          "ignores without letter after @ foo bar",
			namedArgs:    pgx.NamedArgs{"a": int32(42), "b": "foo"},
			expectedSQL:  "ignores without letter after @ foo bar",
			expectedArgs: []any{},
		},
		// An '@' followed by a digit is not a placeholder.
		{
			sql:          "name must start with letter @1 foo bar",
			namedArgs:    pgx.NamedArgs{"a": int32(42), "b": "foo"},
			expectedSQL:  "name must start with letter @1 foo bar",
			expectedArgs: []any{},
		},
		// '@' inside a single-quoted literal or a double-quoted identifier is
		// left untouched; only the bare @id is rewritten.
		{
			sql:          `select *, '@foo' as "@bar" from users where id = @id`,
			namedArgs:    pgx.NamedArgs{"id": int32(42)},
			expectedSQL:  `select *, '@foo' as "@bar" from users where id = $1`,
			expectedArgs: []any{int32(42)},
		},
		// '@' inside "--" single-line comments is left untouched.
		{
			sql: `select * -- @foo
			from users -- @single line comments
			where id = @id;`,
			namedArgs: pgx.NamedArgs{"id": int32(42)},
			expectedSQL: `select * -- @foo
			from users -- @single line comments
			where id = $1;`,
			expectedArgs: []any{int32(42)},
		},
		// '@' inside "/* */" block comments, including nested ones, is left
		// untouched.
		{
			sql: `select * /* @multi line
			@comment
			*/
			/* /* with @nesting */ */
			from users
			where id = @id;`,
			namedArgs: pgx.NamedArgs{"id": int32(42)},
			expectedSQL: `select * /* @multi line
			@comment
			*/
			/* /* with @nesting */ */
			from users
			where id = $1;`,
			expectedArgs: []any{int32(42)},
		},
	} {
		sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
		// require aborts the whole test on failure, so include the case index
		// to identify which table entry produced the unexpected error.
		require.NoErrorf(t, err, "%d", i)
		assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
		assert.Equalf(t, tt.expectedArgs, args, "%d", i)
	}
}