// Vendored copy of google.golang.org/grpc/xds/internal/balancer/wrrlocality/balancer.go
// (path: vendor/google.golang.org/grpc/xds/internal/balancer/wrrlocality/balancer.go,
// blob: 4df2e4ed0086af75a1725fb606a425de1fb7885c).
/*
 *
 * Copyright 2023 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package wrrlocality provides an implementation of the wrr locality LB policy,
// as defined in [A52 - xDS Custom LB Policies].
//
// [A52 - xDS Custom LB Policies]: https://github.com/grpc/proposal/blob/master/A52-xds-custom-lb-policies.md
package wrrlocality

import (
	"encoding/json"
	"errors"
	"fmt"

	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/weightedtarget"
	"google.golang.org/grpc/internal/grpclog"
	internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/serviceconfig"
	"google.golang.org/grpc/xds/internal"
)

// Name is the name under which the wrr_locality balancer builder is
// registered (see init and bb.Name).
const (
	Name = "xds_wrr_locality_experimental"
)

func init() {
	balancer.Register(bb{})
}

// bb is the (stateless) builder for the wrr_locality balancer. It builds the
// balancer via Build and parses its configuration via ParseConfig.
type bb struct{}

// Name returns the name this builder is registered under.
func (bb) Name() string { return Name }

// LBConfig is the config for the wrr locality balancer.
type LBConfig struct {
	// Embedding this interface makes LBConfig satisfy
	// serviceconfig.LoadBalancingConfig (the type ParseConfig returns); it is
	// excluded from JSON encoding.
	serviceconfig.LoadBalancingConfig `json:"-"`
	// ChildPolicy is the config for the child policy. It is used as the child
	// policy of every weighted target built in UpdateClientConnState, and
	// ParseConfig rejects configs that leave it unset.
	ChildPolicy *internalserviceconfig.BalancerConfig `json:"childPolicy,omitempty"`
}

// weightedTargetName is the name of the registered balancer that Build looks
// up and wraps. It is a variable rather than a constant so that tests can plumb
// in a different child.
var weightedTargetName = weightedtarget.Name

// Build creates a wrrLocalityBalancer that wraps a freshly built child
// balancer, looked up by weightedTargetName. cc is handed directly to the
// child because this balancer does not intercept any ClientConn operations.
//
// Build returns nil in three cases: the child builder is not registered, it
// does not implement balancer.ConfigParser, or it fails to build a balancer.
func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
	builder := balancer.Get(weightedTargetName)
	if builder == nil {
		// Shouldn't happen, registered through imported weighted target,
		// defensive programming.
		return nil
	}

	// Check for config-parsing support *before* building the child. If it
	// were checked afterwards, a failure here would drop an already-built
	// child balancer without ever calling Close on it.
	wtbCfgParser, ok := builder.(balancer.ConfigParser)
	if !ok {
		// Shouldn't happen, imported weighted target builder has this method.
		return nil
	}

	// Doesn't need to intercept any balancer.ClientConn operations; pass
	// through by just giving cc to child balancer.
	wtb := builder.Build(cc, bOpts)
	if wtb == nil {
		// Shouldn't happen, defensive programming.
		return nil
	}

	wrrL := &wrrLocalityBalancer{
		child:       wtb,
		childParser: wtbCfgParser,
	}
	wrrL.logger = prefixLogger(wrrL)
	wrrL.logger.Infof("Created")
	return wrrL
}

// ParseConfig decodes s into an *LBConfig. It returns an error when:
//   - s is not valid JSON for LBConfig;
//   - s is the JSON literal null;
//   - the childPolicy field is absent or null.
func (bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
	var cfg *LBConfig
	if err := json.Unmarshal(s, &cfg); err != nil {
		return nil, fmt.Errorf("xds_wrr_locality: invalid LBConfig: %s, error: %v", string(s), err)
	}
	// JSON null leaves cfg as a nil pointer, so guard before dereferencing.
	if cfg != nil && cfg.ChildPolicy != nil {
		return cfg, nil
	}
	return nil, errors.New("xds_wrr_locality: invalid LBConfig: child policy field must be set")
}

// attributeKey is the key under which SetAddrInfo stores an AddrInfo in an
// address's BalancerAttributes (and getAddrInfo reads it back). Because the
// type is unexported, no other package can construct a colliding key.
type attributeKey struct{}

// AddrInfo is the locality weight of the locality an address is a part of.
type AddrInfo struct {
	LocalityWeight uint32
}

// Equal allows the values to be compared by Attributes.Equal. It reports
// whether o is an AddrInfo value (not a pointer to one) carrying the same
// LocalityWeight as a.
func (a AddrInfo) Equal(o interface{}) bool {
	other, isAddrInfo := o.(AddrInfo)
	if !isAddrInfo {
		return false
	}
	return a.LocalityWeight == other.LocalityWeight
}

// SetAddrInfo attaches addrInfo to the BalancerAttributes of addr and returns
// the result. addr is received by value, so the caller's copy is not mutated.
func SetAddrInfo(addr resolver.Address, addrInfo AddrInfo) resolver.Address {
	withInfo := addr.BalancerAttributes.WithValue(attributeKey{}, addrInfo)
	addr.BalancerAttributes = withInfo
	return addr
}

// String renders the locality weight as "Locality Weight: <weight>".
func (a AddrInfo) String() string {
	// Sprint inserts no separator here because one operand is a string.
	return fmt.Sprint("Locality Weight: ", a.LocalityWeight)
}

// getAddrInfo looks up the AddrInfo that SetAddrInfo stored in addr's
// BalancerAttributes. The boolean result is false (with a zero AddrInfo) when
// no AddrInfo value is stored under attributeKey.
func getAddrInfo(addr resolver.Address) (AddrInfo, bool) {
	info, found := addr.BalancerAttributes.Value(attributeKey{}).(AddrInfo)
	return info, found
}

// wrrLocalityBalancer wraps a weighted target balancer, and builds
// configuration for the weighted target once it receives configuration
// specifying the weighted target child balancer and locality weight
// information.
type wrrLocalityBalancer struct {
	// child will be a weighted target balancer, and will be built at
	// wrrLocalityBalancer build time. Other than preparing configuration, other
	// balancer operations are simply pass through.
	child balancer.Balancer

	// childParser parses the weighted target JSON config that
	// UpdateClientConnState generates before it is handed to child.
	childParser balancer.ConfigParser

	// logger is set in Build via prefixLogger.
	logger *grpclog.PrefixLogger
}

// UpdateClientConnState builds a weighted target config from the wrr_locality
// config and the per-address locality information. It then forwards that
// config, together with the unchanged resolver state, to the child balancer.
//
// The resulting config has one target per distinct locality. Each target's
// weight comes from the address's AddrInfo, and its child policy is the
// configured ChildPolicy.
//
// Error and skip behavior:
//   - An unexpected config type yields balancer.ErrBadResolverState.
//   - An address without AddrInfo yields an error.
//   - An address whose LocalityID cannot be stringified is logged and skipped.
func (b *wrrLocalityBalancer) UpdateClientConnState(s balancer.ClientConnState) error {
	lbCfg, ok := s.BalancerConfig.(*LBConfig)
	if !ok {
		b.logger.Errorf("Received config with unexpected type %T: %v", s.BalancerConfig, s.BalancerConfig)
		return balancer.ErrBadResolverState
	}

	weightedTargets := make(map[string]weightedtarget.Target)
	for _, addr := range s.ResolverState.Addresses {
		// This get of LocalityID could potentially return a zero value. This
		// shouldn't happen though (this attribute that is set actually gets
		// used to build localities in the first place), and thus don't error
		// out, and just build a weighted target with undefined behavior.
		localityID := internal.GetLocalityID(addr)
		locality, err := localityID.ToString()
		if err != nil {
			// Should never happen. Actually skip the address, as the message
			// says, rather than inserting a target keyed by whatever string
			// ToString returned alongside the error.
			b.logger.Errorf("Failed to marshal LocalityID %+v: %v, skipping this locality in weighted target", localityID, err)
			continue
		}
		ai, ok := getAddrInfo(addr)
		if !ok {
			return fmt.Errorf("xds_wrr_locality: missing locality weight information in address %q", addr)
		}
		weightedTargets[locality] = weightedtarget.Target{Weight: ai.LocalityWeight, ChildPolicy: lbCfg.ChildPolicy}
	}

	wtCfg := &weightedtarget.LBConfig{Targets: weightedTargets}
	wtCfgJSON, err := json.Marshal(wtCfg)
	if err != nil {
		// Shouldn't happen.
		return fmt.Errorf("xds_wrr_locality: error marshalling prepared config %v: %v", wtCfg, err)
	}
	sc, err := b.childParser.ParseConfig(wtCfgJSON)
	if err != nil {
		// %s renders the JSON text; %v on a []byte would print raw byte values.
		return fmt.Errorf("xds_wrr_locality: config generated %s is invalid: %v", wtCfgJSON, err)
	}

	return b.child.UpdateClientConnState(balancer.ClientConnState{
		ResolverState:  s.ResolverState,
		BalancerConfig: sc,
	})
}

// ResolverError passes err straight through to the child balancer.
func (b *wrrLocalityBalancer) ResolverError(err error) { b.child.ResolverError(err) }

// UpdateSubConnState passes the SubConn state change straight through to the
// child balancer.
func (b *wrrLocalityBalancer) UpdateSubConnState(sc balancer.SubConn, scState balancer.SubConnState) {
	child := b.child
	child.UpdateSubConnState(sc, scState)
}

// Close closes the wrapped child balancer.
func (b *wrrLocalityBalancer) Close() { b.child.Close() }