aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/google.golang.org/grpc/internal/testutils/fakegrpclb/server.go
blob: 82be2c1af1a40e898d625cd01f44c38800c041e9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
/*
 *
 * Copyright 2022 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package fakegrpclb provides a fake implementation of the grpclb server.
package fakegrpclb

import (
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync"
	"time"

	"google.golang.org/grpc"
	lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
	lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/internal/pretty"
	"google.golang.org/grpc/status"
)

var logger = grpclog.Component("fake_grpclb")

// ServerParams wraps options passed while creating a Server.
type ServerParams struct {
	ListenPort    int                 // Listening port for the balancer server.
	ServerOptions []grpc.ServerOption // gRPC options for the balancer server.

	LoadBalancedServiceName string   // Service name being load balanced for.
	LoadBalancedServicePort int      // Service port being load balanced for.
	BackendAddresses        []string // Service backends to balance load across.
	ShortStream             bool     // End balancer stream after sending server list.
}

// Server is a fake implementation of the grpclb LoadBalancer service. It does
// not support stats reporting from clients, and always sends back a static list
// of backends to the client to balance load across.
//
// It is safe for concurrent access.
type Server struct {
	lbgrpc.UnimplementedLoadBalancerServer

	// Options copied over from ServerParams passed to NewServer.
	sOpts       []grpc.ServerOption // gRPC server options.
	serviceName string              // Service name being load balanced for.
	servicePort int                 // Service port being load balanced for.
	shortStream bool                // End balancer stream after sending server list.

	// Values initialized using ServerParams passed to NewServer.
	backends []*lbpb.Server // Service backends to balance load across.
	lis      net.Listener   // Listener for grpc connections to the LoadBalancer service.

	// mu guards access to below fields.
	mu         sync.Mutex
	grpcServer *grpc.Server // Underlying grpc server.
	address    string       // Actual listening address.

	stopped chan struct{} // Closed when Stop() is called.
}

// NewServer creates a new Server with passed in params. Returns a non-nil error
// if the params are invalid.
func NewServer(params ServerParams) (*Server, error) {
	var servers []*lbpb.Server
	for _, addr := range params.BackendAddresses {
		ipStr, portStr, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, fmt.Errorf("failed to parse list of backend address %q: %v", addr, err)
		}
		ip := net.ParseIP(ipStr)
		if ip == nil {
			return nil, fmt.Errorf("failed to parse ip: %q", ipStr)
		}
		port, err := strconv.Atoi(portStr)
		if err != nil {
			return nil, fmt.Errorf("failed to convert port %q to int", portStr)
		}
		logger.Infof("Adding backend ip: %q, port: %d to server list", ip.String(), port)
		servers = append(servers, &lbpb.Server{
			IpAddress: ip,
			Port:      int32(port),
		})
	}

	lis, err := net.Listen("tcp", "localhost:"+strconv.Itoa(params.ListenPort))
	if err != nil {
		return nil, fmt.Errorf("failed to listen on port %q: %v", params.ListenPort, err)
	}

	return &Server{
		sOpts:       params.ServerOptions,
		serviceName: params.LoadBalancedServiceName,
		servicePort: params.LoadBalancedServicePort,
		shortStream: params.ShortStream,
		backends:    servers,
		lis:         lis,
		address:     lis.Addr().String(),
		stopped:     make(chan struct{}),
	}, nil
}

// Serve starts serving the LoadBalancer service on a gRPC server.
//
// It returns early with a non-nil error if it is unable to start serving.
// Otherwise, it blocks until Stop() is called, at which point it returns the
// error returned by the underlying grpc.Server's Serve() method.
func (s *Server) Serve() error {
	s.mu.Lock()
	if s.grpcServer != nil {
		s.mu.Unlock()
		return errors.New("Serve() called multiple times")
	}

	server := grpc.NewServer(s.sOpts...)
	s.grpcServer = server
	s.mu.Unlock()

	logger.Infof("Begin listening on %s", s.lis.Addr().String())
	lbgrpc.RegisterLoadBalancerServer(server, s)
	return server.Serve(s.lis) // This call will block.
}

// Stop stops serving the LoadBalancer service and unblocks the preceding call
// to Serve().
func (s *Server) Stop() {
	defer close(s.stopped)
	s.mu.Lock()
	if s.grpcServer != nil {
		s.grpcServer.Stop()
		s.grpcServer = nil
	}
	s.mu.Unlock()
}

// Address returns the host:port on which the LoadBalancer service is serving.
func (s *Server) Address() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.address
}

// BalanceLoad provides a fake implementation of the LoadBalancer service.
func (s *Server) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
	logger.Info("New BalancerLoad stream started")

	req, err := stream.Recv()
	if err == io.EOF {
		logger.Warning("Received EOF when reading from the stream")
		return nil
	}
	if err != nil {
		logger.Warning("Failed to read LoadBalanceRequest from stream: %v", err)
		return err
	}
	logger.Infof("Received LoadBalancerRequest:\n%s", pretty.ToJSON(req))

	// Initial request contains the service being load balanced for.
	initialReq := req.GetInitialRequest()
	if initialReq == nil {
		logger.Info("First message on the stream does not contain an InitialLoadBalanceRequest")
		return status.Error(codes.Unknown, "First request not an InitialLoadBalanceRequest")
	}

	// Basic validation of the service name and port from the incoming request.
	//
	// Clients targeting service:port can sometimes include the ":port" suffix in
	// their requested names; handle this case.
	serviceName, port, err := net.SplitHostPort(initialReq.Name)
	if err != nil {
		// Requested name did not contain a port. So, use the name as is.
		serviceName = initialReq.Name
	} else {
		p, err := strconv.Atoi(port)
		if err != nil {
			logger.Info("Failed to parse requested service port %q to integer", port)
			return status.Error(codes.Unknown, "Bad requested service port number")
		}
		if p != s.servicePort {
			logger.Info("Requested service port number %q does not match expected", port, s.servicePort)
			return status.Error(codes.Unknown, "Bad requested service port number")
		}
	}
	if serviceName != s.serviceName {
		logger.Info("Requested service name %q does not match expected %q", serviceName, s.serviceName)
		return status.Error(codes.NotFound, "Bad requested service name")
	}

	// Empty initial response disables stats reporting from the client. Stats
	// reporting from the client is used to determine backend load and is not
	// required for the purposes of this fake.
	initResp := &lbpb.LoadBalanceResponse{
		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
			InitialResponse: &lbpb.InitialLoadBalanceResponse{},
		},
	}
	if err := stream.Send(initResp); err != nil {
		logger.Warningf("Failed to send InitialLoadBalanceResponse on the stream: %v", err)
		return err
	}

	resp := &lbpb.LoadBalanceResponse{
		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
			ServerList: &lbpb.ServerList{Servers: s.backends},
		},
	}
	logger.Infof("Sending response with server list: %s", pretty.ToJSON(resp))
	if err := stream.Send(resp); err != nil {
		logger.Warningf("Failed to send InitialLoadBalanceResponse on the stream: %v", err)
		return err
	}

	if s.shortStream {
		logger.Info("Ending stream early as the short stream option was set")
		return nil
	}

	for {
		select {
		case <-stream.Context().Done():
			return nil
		case <-s.stopped:
			return nil
		case <-time.After(10 * time.Second):
			logger.Infof("Sending response with server list: %s", pretty.ToJSON(resp))
			if err := stream.Send(resp); err != nil {
				logger.Warningf("Failed to send InitialLoadBalanceResponse on the stream: %v", err)
				return err
			}
		}
	}
}