aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/google.golang.org/grpc/internal/wrr/random.go
blob: 6d5eb7d462099fc88eb9900396ddbb213f8fef5e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
/*
 *
 * Copyright 2019 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package wrr

import (
	"fmt"
	"sort"
	"sync"

	"google.golang.org/grpc/internal/grpcrand"
)

// weightedItem pairs a value handed to Add with its weight and with the
// running total of weights up to and including it, which Next searches.
type weightedItem struct {
	item              interface{} // value returned by Next when selected
	weight            int64       // weight supplied to Add
	accumulatedWeight int64       // this weight plus all earlier weights
}

// String formats the item's fields (item, weight, accumulatedWeight) in order.
func (w *weightedItem) String() string {
	// Format the dereferenced struct value: the value type has no String
	// method, so fmt prints its fields instead of recursing back here.
	return fmt.Sprintf("%v", *w)
}

// randomWRR implements WRR by choosing among its items at random, giving each
// item a chance of selection proportional to its weight.
type randomWRR struct {
	// mu is read-locked by Next and write-locked by Add around their
	// accesses to the fields below.
	mu sync.RWMutex

	// items holds every added item in insertion order; each entry's
	// accumulatedWeight is a running total over this slice.
	items []*weightedItem

	// equalWeights reports whether every item added so far carries the same
	// weight, which lets Next skip the weighted search and pick uniformly.
	equalWeights bool
}

// NewRandom returns an empty WRR that selects items by weighted random choice.
// Items are registered with Add before Next is called.
func NewRandom() WRR {
	return new(randomWRR)
}

// grpcrandInt63n is the random source used by Next, which expects it to return
// a value in [0, n) for the positive bound n it passes. It is held in a
// package-level variable rather than called directly, presumably so tests can
// substitute a deterministic implementation (NOTE(review): confirm against the
// package's tests).
var grpcrandInt63n = grpcrand.Int63n

// Next returns one of the added items chosen at random, or nil when no items
// have been added. When all weights are equal the choice is uniform.
// Otherwise, assuming non-negative weights and a uniform grpcrandInt63n, an
// item is chosen with probability weight / (sum of all weights).
func (rw *randomWRR) Next() (item interface{}) {
	rw.mu.RLock()
	defer rw.mu.RUnlock()

	n := len(rw.items)
	switch {
	case n == 0:
		return nil
	case rw.equalWeights:
		// Fast path: every item is equally likely, so index directly.
		return rw.items[grpcrandInt63n(int64(n))].item
	}

	// accumulatedWeight is a running total, so the last entry holds the sum
	// of all weights and the totals never decrease (for weights >= 0).
	total := rw.items[n-1].accumulatedWeight
	target := grpcrandInt63n(total) // in [0, total)

	// The first item whose running total exceeds target owns that part of
	// the range. Because target < total, some item always matches, so idx < n.
	idx := sort.Search(n, func(k int) bool {
		return rw.items[k].accumulatedWeight > target
	})
	return rw.items[idx].item
}

// Add appends item with the given weight. It extends the running weight total
// stored on the new entry and records whether all weights added so far are
// identical. Weights are expected to be non-negative, because Next relies on
// the running totals never decreasing.
func (rw *randomWRR) Add(item interface{}, weight int64) {
	rw.mu.Lock()
	defer rw.mu.Unlock()

	entry := &weightedItem{item: item, weight: weight, accumulatedWeight: weight}
	if n := len(rw.items); n > 0 {
		prev := rw.items[n-1]
		entry.accumulatedWeight += prev.accumulatedWeight
		// All weights stay equal only if they were equal before and the
		// new weight matches the previous one.
		rw.equalWeights = rw.equalWeights && prev.weight == weight
	} else {
		// A single item trivially has equal weights.
		rw.equalWeights = true
	}
	rw.items = append(rw.items, entry)
}

// String renders the items held by rw, in insertion order, for debugging.
//
// The read lock is required: Add appends to rw.items while holding mu, so
// reading the slice without it would race with a concurrent Add. Formatting
// each *weightedItem does not touch mu, so holding the read lock here cannot
// self-deadlock.
func (rw *randomWRR) String() string {
	rw.mu.RLock()
	defer rw.mu.RUnlock()
	return fmt.Sprint(rw.items)
}