aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/ClickHouse/clickhouse-go/connect.go
blob: 14cb1c81ce5b9f2c1b2bf2ffb27200a9c31a5453 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
package clickhouse

import (
	"bufio"
	"crypto/tls"
	"database/sql/driver"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// tick is a process-wide counter bumped atomically by every call to dial.
// Its value becomes the connection ident and the rotation offset used by the
// random open strategy.
var tick int32

// openStrategy selects the order in which dial walks the configured hosts.
type openStrategy int8

// String names the strategy: "in_order", "time_random", or "random" for
// connOpenRandom and for any value that is not a known strategy.
func (s openStrategy) String() string {
	name := "random"
	switch s {
	case connOpenTimeRandom:
		name = "time_random"
	case connOpenInOrder:
		name = "in_order"
	}
	return name
}

// Known host-selection strategies. Numbering starts at 1, so the zero value
// of openStrategy is not one of them.
const (
	connOpenRandom openStrategy = iota + 1
	connOpenInOrder
	connOpenTimeRandom
)

// connOptions bundles everything dial needs to open and configure a
// connection.
type connOptions struct {
	secure     bool        // dial over TLS
	skipVerify bool        // copied into InsecureSkipVerify for secure dials
	tlsConfig  *tls.Config // base TLS configuration; nil means an empty one

	hosts []string // candidate server addresses handed to the dialer

	connTimeout  time.Duration // timeout for establishing the connection
	readTimeout  time.Duration // read deadline applied by connect; 0 disables it
	writeTimeout time.Duration // write deadline applied by connect; 0 disables it

	noDelay      bool                         // value passed to (*net.TCPConn).SetNoDelay
	openStrategy openStrategy                 // order in which hosts are tried
	logf         func(string, ...interface{}) // logging function for dial/connect events
}

// DialFunc opens the network connection to a single host. When one is
// installed with RegisterDial, dial calls it instead of the built-in dialers,
// passing network "tcp", the host address, the configured connect timeout
// and, for secure connections, the TLS configuration to use (nil otherwise).
type DialFunc func(network, address string, timeout time.Duration, config *tls.Config) (net.Conn, error)

// customDial is the function installed by RegisterDial (nil when none is
// installed); customDialLock guards every read and write of it.
var (
	customDialLock sync.RWMutex
	customDial     DialFunc
)

// RegisterDial installs dial as the custom dial function for all subsequent
// connection attempts, replacing any function registered earlier.
func RegisterDial(dial DialFunc) {
	customDialLock.Lock()
	defer customDialLock.Unlock()
	customDial = dial
}

// DeregisterDial removes the custom dial function so that later connection
// attempts fall back to the built-in dialers.
func DeregisterDial() {
	RegisterDial(nil)
}
// dial opens a connection to one of options.hosts and wraps it in a
// *connect. Hosts are tried one at a time, each at most once, in the order
// chosen by options.openStrategy; the first successful connection wins.
//
// For secure connections the TLS configuration is a copy of
// options.tlsConfig (or an empty config when it is nil) with
// InsecureSkipVerify taken from options.skipVerify, so the caller's config
// is never modified. A function installed with RegisterDial, when present,
// replaces the built-in dialers and receives that TLS config (nil for plain
// TCP connections).
//
// dial returns the error from the last failed attempt when no host could be
// reached, and a non-nil error when options.hosts is empty.
func dial(options connOptions) (*connect, error) {
	hostCount := len(options.hosts)
	if hostCount == 0 {
		// Without this guard the loop below never runs and dial would
		// return (nil, nil), handing callers a nil *connect.
		return nil, &net.AddrError{Err: "clickhouse: no hosts to dial"}
	}
	var (
		err   error
		conn  net.Conn
		ident = int(atomic.AddInt32(&tick, 1))
	)
	if ident < 0 {
		ident = -ident
	}

	var tlsConfig *tls.Config
	if options.secure {
		// Work on a copy: the caller's config may be shared by concurrent
		// dials, and writing InsecureSkipVerify into it would be a data race
		// as well as a permanent change to the caller's value.
		if options.tlsConfig != nil {
			tlsConfig = options.tlsConfig.Clone()
		} else {
			tlsConfig = &tls.Config{}
		}
		tlsConfig.InsecureSkipVerify = options.skipVerify
	}

	checkedHosts := make(map[int]struct{}, hostCount)
	for i := 0; i < hostCount; i++ {
		var num int
		switch options.openStrategy {
		case connOpenInOrder:
			num = i
		case connOpenTimeRandom:
			// Start at a host derived from the current microsecond and probe
			// forward past hosts already tried. Fewer than hostCount hosts
			// have been tried at this point, so probing always terminates and
			// never spins on the clock waiting for an untried index.
			num = int((time.Now().UnixNano()/1000)%1000) % hostCount
			for _, tried := checkedHosts[num]; tried; _, tried = checkedHosts[num] {
				num = (num + 1) % hostCount
			}
			checkedHosts[num] = struct{}{}
		default:
			// connOpenRandom. Unrecognized values are handled the same way,
			// matching openStrategy.String. Rotating from ident spreads
			// successive connections across hosts.
			num = (ident + i) % hostCount
			if num < 0 { // ident+i (or -MinInt32) can wrap where int is 32 bits
				num += hostCount
			}
		}
		host := options.hosts[num]

		customDialLock.RLock()
		cd := customDial
		customDialLock.RUnlock()

		switch {
		case options.secure && cd != nil:
			conn, err = cd("tcp", host, options.connTimeout, tlsConfig)
		case options.secure:
			conn, err = tls.DialWithDialer(&net.Dialer{Timeout: options.connTimeout}, "tcp", host, tlsConfig)
		case cd != nil:
			conn, err = cd("tcp", host, options.connTimeout, nil)
		default:
			conn, err = net.DialTimeout("tcp", host, options.connTimeout)
		}
		if err != nil {
			options.logf(
				"[dial err] secure=%t, skip_verify=%t, strategy=%s, ident=%d, addr=%s\n%#v",
				options.secure,
				options.skipVerify,
				options.openStrategy,
				ident,
				host,
				err,
			)
			continue
		}

		options.logf(
			"[dial] secure=%t, skip_verify=%t, strategy=%s, ident=%d, server=%d -> %s",
			options.secure,
			options.skipVerify,
			options.openStrategy,
			ident,
			num,
			conn.RemoteAddr(),
		)
		if tcp, ok := conn.(*net.TCPConn); ok {
			// Disable or enable the Nagle algorithm for this TCP socket.
			if err = tcp.SetNoDelay(options.noDelay); err != nil {
				conn.Close() // don't leak the socket; the SetNoDelay error is what we report
				return nil, err
			}
		}
		return &connect{
			Conn:         conn,
			logf:         options.logf,
			ident:        ident,
			buffer:       bufio.NewReader(conn),
			readTimeout:  options.readTimeout,
			writeTimeout: options.writeTimeout,
		}, nil
	}
	return nil, err
}

// connect is a single server connection produced by dial. It embeds the raw
// net.Conn and adds buffered reads, periodically refreshed deadlines and
// close-once behaviour on top of it.
type connect struct {
	net.Conn // underlying network (or TLS) connection

	logf   func(string, ...interface{}) // logging function supplied via connOptions
	ident  int                          // dial sequence number taken from tick
	buffer *bufio.Reader                // buffered reader over Conn, consumed by Read
	closed bool                         // set once Close has run

	// Deadlines applied by Read and Write; zero disables the corresponding one.
	readTimeout, writeTimeout time.Duration

	// Time (as reported by now()) at which each deadline was last refreshed.
	lastReadDeadlineTime, lastWriteDeadlineTime time.Time
}

// Read fills b completely from the buffered connection. When a read timeout
// is configured it first refreshes the read deadline. It keeps reading until
// len(b) bytes have arrived. On any underlying error it logs the error,
// closes the connection and returns driver.ErrBadConn together with the
// number of bytes actually stored in b before the failure.
func (conn *connect) Read(b []byte) (int, error) {
	var (
		n      int
		err    error
		total  int
		dstLen = len(b)
	)
	// Re-arm the deadline only when more than a quarter of readTimeout has
	// passed (per now()) since it was last set, rather than on every call.
	if currentTime := now(); conn.readTimeout != 0 && currentTime.Sub(conn.lastReadDeadlineTime) > (conn.readTimeout>>2) {
		conn.SetReadDeadline(time.Now().Add(conn.readTimeout))
		conn.lastReadDeadlineTime = currentTime
	}
	for total < dstLen {
		if n, err = conn.buffer.Read(b[total:]); err != nil {
			conn.logf("[connect] read error: %v", err)
			conn.Close()
			// Count every byte already placed in b, including this call's
			// partial read; returning only n would under-report it.
			return total + n, driver.ErrBadConn
		}
		total += n
	}
	return total, nil
}

// Write writes all of b to the underlying connection. When a write timeout
// is configured it first refreshes the write deadline. It loops until every
// byte is accepted and returns len(b) on success. On the first write error
// it logs the error, closes the connection and returns driver.ErrBadConn
// together with the number of bytes written before the failure.
func (conn *connect) Write(b []byte) (int, error) {
	var (
		n      int
		err    error
		total  int
		srcLen = len(b)
	)
	// Re-arm the deadline only when more than a quarter of writeTimeout has
	// passed (per now()) since it was last set, rather than on every call.
	if currentTime := now(); conn.writeTimeout != 0 && currentTime.Sub(conn.lastWriteDeadlineTime) > (conn.writeTimeout>>2) {
		conn.SetWriteDeadline(time.Now().Add(conn.writeTimeout))
		conn.lastWriteDeadlineTime = currentTime
	}
	for total < srcLen {
		if n, err = conn.Conn.Write(b[total:]); err != nil {
			conn.logf("[connect] write error: %v", err)
			conn.Close()
			return total + n, driver.ErrBadConn
		}
		total += n
	}
	// io.Writer must report the bytes written from b. Returning only the
	// last chunk's n under-reports partial writes, and wrappers such as
	// bufio.Writer turn n < len(b) with a nil error into io.ErrShortWrite.
	return total, nil
}

// Close closes the underlying connection the first time it is called and
// returns nil without doing anything on later calls. This lets the error
// paths in Read and Write close eagerly without risking a double close.
func (conn *connect) Close() error {
	if conn.closed {
		return nil
	}
	conn.closed = true
	return conn.Conn.Close()
}