aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/ClickHouse/clickhouse-go/helpers.go
blob: 7e10d0620414207a3d042a2ae1bf51d52e5d9559 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package clickhouse

import (
	"bytes"
	"database/sql/driver"
	"fmt"
	"reflect"
	"regexp"
	"strings"
	"time"
)

// numInput reports how many bind arguments query expects.
//
// The query is scanned rune by rune and two kinds of placeholder are counted:
//   - every '?' that is read while the keyword flag is set, and
//   - every distinct, non-empty @name parameter (a repeated name counts once).
//
// Runes inside single quotes or backticks are skipped, so placeholders that
// appear in quoted text are not counted.
func numInput(query string) int {

	var (
		// count is the running total that is returned.
		count int
		// args remembers the @names that have already been counted.
		args   = make(map[string]struct{})
		reader = bytes.NewReader([]byte(query))
		// quote / gravis are true while inside '...' / `...` respectively.
		quote, gravis bool
		// escape is set by a backslash inside quoted text; the next rune is dropped.
		escape bool
		// keyword is true when the scanner is right after something that may be
		// followed by a '?' placeholder (see the second switch below).
		keyword bool
		// inBetween is true after "between" matched and until the following
		// "and" matched.
		inBetween bool
		// newMatcher is defined outside this view. Presumably matchRune is fed
		// one rune at a time and returns true when the rune completes the
		// keyword - TODO confirm against the matcher implementation.
		like      = newMatcher("like")
		limit     = newMatcher("limit")
		offset    = newMatcher("offset")
		between   = newMatcher("between")
		in        = newMatcher("in")
		and       = newMatcher("and")
		from      = newMatcher("from")
		join      = newMatcher("join")
		subSelect = newMatcher("select")
	)
	for {
		if char, _, err := reader.ReadRune(); err == nil {
			// The rune following a backslash inside quoted text is consumed
			// without being interpreted, so an escaped quote does not close
			// the quoted section.
			if escape {
				escape = false
				continue
			}
			// First switch: track whether we are inside quoted text.
			switch char {
			case '\\':
				if gravis || quote {
					escape = true
				}
			case '\'':
				if !gravis {
					quote = !quote
				}
			case '`':
				if !quote {
					gravis = !gravis
				}
			}
			// Everything inside quotes/backticks (including the opening
			// delimiter, which has just flipped the flag) is ignored.
			if quote || gravis {
				continue
			}
			// Second switch: count placeholders and maintain the keyword flag.
			switch {
			case char == '?' && keyword:
				count++
			case char == '@':
				// paramParser consumes the name that follows '@' from reader.
				if param := paramParser(reader); len(param) != 0 {
					if _, found := args[param]; !found {
						args[param] = struct{}{}
						count++
					}
				}
			case
				char == '=',
				char == '<',
				char == '>',
				char == '(',
				char == ',',
				char == '[',
				char == '%':
				// Any of these punctuation runes allows a following '?'.
				keyword = true
			default:
				// Note that || short-circuits: once one matcher returns true,
				// the matchers to its right are not fed this rune.
				if limit.matchRune(char) || offset.matchRune(char) || like.matchRune(char) ||
					in.matchRune(char) || from.matchRune(char) || join.matchRune(char) || subSelect.matchRune(char) {
					keyword = true
				} else if between.matchRune(char) {
					keyword = true
					inBetween = true
				} else if inBetween && and.matchRune(char) {
					// "and" is only consulted while a "between" is open.
					keyword = true
					inBetween = false
				} else {
					// Whitespace keeps the flag as it is; any other rune clears it.
					keyword = keyword && (char == ' ' || char == '\t' || char == '\n')
				}
			}
		} else {
			// Any read error (in practice end of input) ends the scan.
			break
		}
	}
	return count
}

// paramParser reads a parameter name from reader and returns it.
//
// It consumes runes for as long as they are ASCII letters, ASCII digits or
// '_'. The first rune that is anything else is pushed back onto reader so the
// caller sees it next. The result is empty when the very first rune is not
// part of a name or when reader is already exhausted.
func paramParser(reader *bytes.Reader) string {
	var ident []byte
	for {
		r, _, err := reader.ReadRune()
		if err != nil {
			// Nothing more to read: whatever was collected is the name.
			return string(ident)
		}
		switch {
		case r == '_',
			'0' <= r && r <= '9',
			'a' <= r && r <= 'z',
			'A' <= r && r <= 'Z':
			// Every accepted rune is ASCII, so it fits in a single byte.
			ident = append(ident, byte(r))
		default:
			// Give the terminating rune back to the caller.
			reader.UnreadRune()
			return string(ident)
		}
	}
}

// selectRe finds the word SELECT surrounded by whitespace. It is applied to
// upper-cased query text.
var selectRe = regexp.MustCompile(`\s+SELECT\s+`)

// isInsert reports whether query has more than two whitespace-separated
// fields, starts with the words INSERT INTO (compared case-insensitively),
// and does not contain a whitespace-delimited SELECT anywhere in its text.
func isInsert(query string) bool {
	words := strings.Fields(query)
	if len(words) <= 2 {
		return false
	}
	if !strings.EqualFold("INSERT", words[0]) {
		return false
	}
	if !strings.EqualFold("INTO", words[1]) {
		return false
	}
	// An "INSERT INTO ... SELECT ..." statement is deliberately rejected.
	hasSelect := selectRe.MatchString(strings.ToUpper(query))
	return !hasSelect
}

// quote renders v as text suitable for splicing into a query.
//
//   - A slice (of any element type) becomes its elements, each quoted
//     recursively, joined with ", ".
//   - A string is wrapped in single quotes, with every backslash and single
//     quote prefixed by a backslash.
//   - A time.Time is rendered by formatTime.
//   - nil becomes the literal null.
//   - Anything else is rendered with fmt.Sprint.
func quote(v driver.Value) string {
	// Slices are detected through reflection so that every element type is
	// handled, not just the ones a type switch could enumerate.
	if rv := reflect.ValueOf(v); rv.Kind() == reflect.Slice {
		parts := make([]string, rv.Len())
		for i := range parts {
			parts[i] = quote(rv.Index(i).Interface())
		}
		return strings.Join(parts, ", ")
	}

	switch val := v.(type) {
	case nil:
		return "null"
	case time.Time:
		return formatTime(val)
	case string:
		escaped := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(val)
		return "'" + escaped + "'"
	default:
		return fmt.Sprint(v)
	}
}

// formatTime renders v as a toDateTime('<timestamp>', '<location>') expression,
// where the timestamp is v's wall-clock time formatted as
// "2006-01-02 15:04:05" and the location is the name of v's time.Location.
//
// Only the timestamp is passed through time.Format. The location name is
// appended verbatim: if it were embedded in the layout string, any
// reference-time tokens it happens to contain (the "5" in "Etc/GMT+5" or
// "EST5EDT", the "Jan" in "Atlantic/Jan_Mayen") would be replaced with parts
// of the timestamp and the zone name would be corrupted.
func formatTime(v time.Time) string {
	return "toDateTime('" + v.Format("2006-01-02 15:04:05") + "', '" + v.Location().String() + "')"
}