aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/gonum.org/v1/gonum/stat/roc.go
blob: 05c6b44d38518b75a793ede37968dd4c39824b0e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package stat

import (
	"math"
	"sort"
)

// ROC computes the receiver operating characteristic (ROC) curve obtained
// when the scores in y are used as a binary classifier for classes,
// optionally weighted by weights. It returns the true positive rate (TPR),
// the false positive rate (FPR) and the thresholds they were evaluated at,
// such that tpr[i] and fpr[i] are the rates for the rule y >= thresh[i].
// The returned slices are ordered by descending threshold.
//
// Both y and cutoffs must be sorted ascending, otherwise ROC panics.
// classes[i] and weights[i] describe the observation with score y[i];
// SortWeightedLabeled can be used to sort y together with classes and
// weights. ROC panics if len(classes) differs from len(y).
//
// For a given cutoff value, observations whose entry in y is greater than
// or equal to the cutoff are classified as true, and those whose entry is
// less than the cutoff are classified as false. These assigned labels are
// compared with the true values in classes to obtain the FPR and TPR.
//
// If weights is nil, every observation has weight 1. If weights is not nil
// it must have the same length as y and classes, otherwise ROC panics.
//
// If cutoffs is nil or empty, every distinct value in y is used as a cutoff,
// followed by +Inf, so tpr and fpr have length one greater than the number
// of unique values in y. In that case, when cutoffs is non-nil and has
// capacity of at least len(y)+1, its backing array is reused for thresh.
// Otherwise the provided cutoffs are copied (the argument is not modified)
// and tpr and fpr have the same length as cutoffs. floats.Span can be used
// to generate equally spaced cutoffs.
//
// If y is empty all three results are nil. The divisions by the total
// positive and total negative weight are not guarded, so when either total
// is zero the corresponding rates are not meaningful (typically NaN).
//
// More details about ROC curves are available at
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
	switch {
	case len(y) != len(classes):
		panic("stat: slice length mismatch")
	case weights != nil && len(y) != len(weights):
		panic("stat: slice length mismatch")
	case !sort.Float64sAreSorted(y):
		panic("stat: input must be sorted ascending")
	case !sort.Float64sAreSorted(cutoffs):
		panic("stat: cutoff values must be sorted ascending")
	case len(y) == 0:
		return nil, nil, nil
	}

	if len(cutoffs) != 0 {
		// Work on a private copy; the thresholds are reversed in place
		// below and the caller's slice must stay untouched.
		thresh = make([]float64, len(cutoffs))
		copy(thresh, cutoffs)
	} else {
		// No cutoffs given: use each distinct value of y, then +Inf.
		// A nil slice has zero capacity, so it always takes the
		// allocation path.
		need := len(y) + 1
		if cap(cutoffs) >= need {
			thresh = cutoffs[:need]
		} else {
			thresh = make([]float64, need)
		}
		thresh[0] = y[0]
		uniq := 0
		for k := 1; k < len(y); k++ {
			if y[k] != y[k-1] {
				uniq++
				thresh[uniq] = y[k]
			}
		}
		thresh[uniq+1] = math.Inf(1)
		thresh = thresh[:uniq+2]
	}

	last := len(thresh) - 1
	tpr = make([]float64, len(thresh))
	fpr = make([]float64, len(thresh))

	// While scanning, tpr[b] and fpr[b] hold the cumulative weight of
	// positive and negative observations lying below thresh[b], i.e. the
	// false negatives and true negatives for that threshold.
	var (
		b              int
		totPos, totNeg float64
	)
	for i, positive := range classes {
		v := y[i]
		// Advance to the first threshold above v (or the final one),
		// carrying the running totals through any empty bins.
		for b < last && v >= thresh[b] {
			b++
			tpr[b], fpr[b] = tpr[b-1], fpr[b-1]
		}

		w := 1.0
		if weights != nil {
			w = weights[i]
		}
		var wPos, wNeg float64
		if positive {
			wPos = w
		} else {
			wNeg = w
		}
		totPos += wPos
		totNeg += wNeg

		if v < thresh[b] {
			tpr[b] += wPos
			fpr[b] += wNeg
		}
	}

	// Turn the below-threshold weights into rates. Thresholds past the
	// last bin reached keep their zero value: nothing lies at or above
	// them, so both rates are zero there.
	posScale := 1 / totPos
	negScale := 1 / totNeg
	for i := 0; i <= b; i++ {
		// The explicit float64 conversions stop the compiler from
		// fusing the multiply and subtract into a single operation.
		tpr[i] = 1 - float64(tpr[i]*posScale)
		fpr[i] = 1 - float64(fpr[i]*negScale)
	}

	// All three slices share one length; reverse them together so the
	// results run from the highest threshold to the lowest.
	for lo, hi := 0, last; lo < hi; lo, hi = lo+1, hi-1 {
		tpr[lo], tpr[hi] = tpr[hi], tpr[lo]
		fpr[lo], fpr[hi] = fpr[hi], fpr[lo]
		thresh[lo], thresh[hi] = thresh[hi], thresh[lo]
	}

	return tpr, fpr, thresh
}

// TOC returns the Total Operating Characteristic for the classes provided
// and the minimum and maximum bounds for the TOC.
//
// The input y values that correspond to classes and weights must be sorted
// in ascending order. classes[i] is the class of value y[i] and weights[i]
// is the weight of y[i]. SortWeightedLabeled can be used to sort classes
// together with weights by the rank variable, i+1.
//
// The returned ntp values can be interpreted as the number of true positives
// where values above the given rank are assigned class true for each given
// rank from 1 to len(classes).
//
//	ntp_i = sum_{j ≥ len(ntp)-1 - i} [ classes_j ] * weights_j, where [x] = 1 if x else 0.
//
// The values of min and max provide the minimum and maximum possible number
// of false values for the set of classes. The first element of ntp, min and
// max are always zero as this corresponds to assigning all data class false
// and the last elements are always weighted sum of classes as this corresponds
// to assigning every data class true. For len(classes) != 0, the lengths of
// min, ntp and max are len(classes)+1.
//
// If weights is nil, all weights are treated as 1. When weights are not nil,
// the calculation of min and max allows for partial assignment of single data
// points. If weights is not nil it must have the same length as classes,
// otherwise TOC will panic.
//
// More details about TOC curves are available at
// https://en.wikipedia.org/wiki/Total_operating_characteristic
func TOC(classes []bool, weights []float64) (min, ntp, max []float64) {
	if weights != nil && len(classes) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if len(classes) == 0 {
		return nil, nil, nil
	}

	ntp = make([]float64, len(classes)+1)
	min = make([]float64, len(ntp))
	max = make([]float64, len(ntp))
	if weights == nil {
		for i := range ntp[1:] {
			ntp[i+1] = ntp[i]
			if classes[len(classes)-i-1] {
				ntp[i+1]++
			}
		}
		totalPositive := ntp[len(ntp)-1]
		for i := range ntp {
			min[i] = math.Max(0, totalPositive-float64(len(classes)-i))
			max[i] = math.Min(totalPositive, float64(i))
		}
		return min, ntp, max
	}

	cumw := max // Reuse max for cumulative weight. Update its elements last.
	for i := range ntp[1:] {
		ntp[i+1] = ntp[i]
		w := weights[len(weights)-i-1]
		cumw[i+1] = cumw[i] + w
		if classes[len(classes)-i-1] {
			ntp[i+1] += w
		}
	}
	totw := cumw[len(cumw)-1]
	totalPositive := ntp[len(ntp)-1]
	for i := range ntp {
		min[i] = math.Max(0, totalPositive-(totw-cumw[i]))
		max[i] = math.Min(totalPositive, cumw[i])
	}
	return min, ntp, max
}