aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds/provider.go
blob: 5c699f16650563a0bcfae2da15f8b4592ec7240d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
package ec2rolecreds

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"math"
	"path"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
	sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
	"github.com/aws/aws-sdk-go-v2/internal/sdk"
	"github.com/aws/smithy-go"
	"github.com/aws/smithy-go/logging"
	"github.com/aws/smithy-go/middleware"
)

// ProviderName is the name of the EC2 role credentials provider. It is set as
// the Source of every aws.Credentials value returned by Provider.Retrieve.
const ProviderName = "EC2RoleProvider"

// GetMetadataAPIClient is the subset of the EC2 IMDS API client that the
// Provider depends on: the GetMetadata operation. Any type with a matching
// GetMetadata method can be supplied through Options.Client.
type GetMetadataAPIClient interface {
	GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
}

// A Provider retrieves credentials from the EC2 Instance Metadata Service
// (IMDS) and reports when those credentials expire.
//
// The New function must be used to create a Provider. To use a custom EC2
// IMDS client, set Options.Client with a functional option:
//
//	p := ec2rolecreds.New(func(o *ec2rolecreds.Options) {
//	     o.Client = imds.New(imds.Options{/* custom options */})
//	})
type Provider struct {
	// options holds the resolved configuration captured by New.
	options Options
}

// Options is the set of user-settable options that control the behavior of
// the Provider. Values are set through the functional options passed to New.
type Options struct {
	// The API client that will be used by the provider to make GetMetadata API
	// calls to EC2 IMDS.
	//
	// If nil, New will default to an EC2 IMDS client created with
	// imds.New(imds.Options{}).
	Client GetMetadataAPIClient
}

// New constructs a Provider that retrieves credentials from the EC2 Instance
// Metadata Service (IMDS).
//
// Each functional option in optFns is applied, in order, to a zero-value
// Options. If no option sets Options.Client, a default IMDS client created
// with imds.New(imds.Options{}) is used.
func New(optFns ...func(*Options)) *Provider {
	var opts Options
	for i := range optFns {
		optFns[i](&opts)
	}

	// Fall back to the stock IMDS client when the caller did not supply one.
	if opts.Client == nil {
		opts.Client = imds.New(imds.Options{})
	}

	return &Provider{options: opts}
}

// Retrieve fetches role credentials from EC2 IMDS.
//
// It first requests the list of role names and then requests the credentials
// for the first name in that list. An error is returned if either request
// fails, if the role list is empty, or if the credentials response cannot be
// decoded or reports a failure code. On error the returned credentials carry
// only Source set to ProviderName.
//
// The returned credentials always have CanExpire set. Expires is the
// expiration reported by IMDS, capped to one hour from now.
func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
	// Value returned alongside every error: identifies the provider only.
	failed := aws.Credentials{Source: ProviderName}

	roleNames, err := requestCredList(ctx, p.options.Client)
	if err != nil {
		return failed, err
	}
	if len(roleNames) == 0 {
		return failed, fmt.Errorf("unexpected empty EC2 IMDS role list")
	}

	// Only the first role name listed by IMDS is used.
	body, err := requestCred(ctx, p.options.Client, roleNames[0])
	if err != nil {
		return failed, err
	}

	// Cap role credentials Expires to 1 hour so they can be refreshed more
	// often. Jitter will be applied by the credentials cache if one is used.
	expires := body.Expiration
	if limit := sdk.NowTime().Add(1 * time.Hour); expires.After(limit) {
		expires = limit
	}

	return aws.Credentials{
		AccessKeyID:     body.AccessKeyID,
		SecretAccessKey: body.SecretAccessKey,
		SessionToken:    body.Token,
		Source:          ProviderName,

		CanExpire: true,
		Expires:   expires,
	}, nil
}

// HandleFailToRefresh extends the Expires time of the previous credentials
// when a refresh attempt has failed and they are expired or about to expire.
//
// Outcomes, checked in this order:
//   - prevCreds cannot expire: empty credentials and the original err are
//     returned.
//   - prevCreds expires more than 5 minutes from now: prevCreds is returned
//     unchanged with a nil error.
//   - otherwise: a copy of prevCreds is returned with Expires set to now plus
//     5 minutes plus a random fraction of 10 minutes, and a warning is logged
//     to the logger carried by ctx.
//
// If the random value cannot be generated, empty credentials and a wrapped
// error are returned.
func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
	aws.Credentials, error,
) {
	switch {
	case !prevCreds.CanExpire:
		return aws.Credentials{}, err
	case prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)):
		return prevCreds, nil
	}

	jitter, randErr := sdkrand.CryptoRandFloat64()
	if randErr != nil {
		return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", randErr)
	}

	// Random distribution of [5,15) minutes, assuming CryptoRandFloat64 yields
	// a value in [0,1) — confirm against the sdkrand package.
	extension := 5*time.Minute + time.Duration(jitter*float64(10*time.Minute))

	extended := prevCreds
	extended.Expires = sdk.NowTime().Add(extension)

	middleware.GetLogger(ctx).Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(extension.Minutes()))

	return extended, nil
}

// AdjustExpiresBy adds dur to the Expires time of creds, unless the
// credentials cannot expire or their Expires time is less than 15 minutes
// from now. The credentials are always returned, whether or not they were
// adjusted, and the returned error is always nil.
func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
	aws.Credentials, error,
) {
	if creds.CanExpire {
		// Leave credentials that are already close to expiring untouched.
		if cutoff := sdk.NowTime().Add(15 * time.Minute); !creds.Expires.Before(cutoff) {
			creds.Expires = creds.Expires.Add(dur)
		}
	}
	return creds, nil
}

// ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses. The fields carry no JSON tags, so encoding/json matches
// them to response keys by field name.
type ec2RoleCredRespBody struct {
	// Success State
	Expiration      time.Time
	AccessKeyID     string
	SecretAccessKey string
	Token           string

	// Error state
	//
	// requestCred treats any Code other than "Success" (compared
	// case-insensitively) as a failure and reports Code and Message in the
	// returned error.
	Code    string
	Message string
}

// iamSecurityCredsPath is the metadata path requested to list the available
// role names; a role name is joined onto it to request that role's
// credentials.
const iamSecurityCredsPath = "/iam/security-credentials/"

// requestCredList requests the list of credential (role) names from EC2
// IMDS. The response body is read line by line, and each line becomes one
// entry of the returned slice.
//
// An error is returned if the metadata request fails or the response body
// cannot be read. An empty response yields an empty, non-nil slice and a nil
// error.
func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
	input := &imds.GetMetadataInput{Path: iamSecurityCredsPath}

	out, err := client.GetMetadata(ctx, input)
	if err != nil {
		return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
	}
	defer out.Content.Close()

	names := []string{}
	for scanner := bufio.NewScanner(out.Content); ; {
		if !scanner.Scan() {
			// Scan stops on EOF or on a read error; only the latter is reported.
			if scanErr := scanner.Err(); scanErr != nil {
				return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", scanErr)
			}
			return names, nil
		}
		names = append(names, scanner.Text())
	}
}

// requestCred requests the credentials for the role named credsName from EC2
// IMDS and decodes the JSON response.
//
// An error is returned, together with a zero-value body, if the metadata
// request fails, if the response cannot be decoded, or if the response Code
// is not "Success" (compared case-insensitively). In the last case the error
// wraps a smithy.GenericAPIError carrying the response Code and Message.
func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
	var none ec2RoleCredRespBody

	input := &imds.GetMetadataInput{
		Path: path.Join(iamSecurityCredsPath, credsName),
	}
	out, err := client.GetMetadata(ctx, input)
	if err != nil {
		return none, fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w", credsName, err)
	}
	defer out.Content.Close()

	var body ec2RoleCredRespBody
	decoder := json.NewDecoder(out.Content)
	if decodeErr := decoder.Decode(&body); decodeErr != nil {
		return none, fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w", credsName, decodeErr)
	}

	if strings.EqualFold(body.Code, "Success") {
		return body, nil
	}

	// If an error code was returned something failed requesting the role.
	apiErr := &smithy.GenericAPIError{Code: body.Code, Message: body.Message}
	return none, fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w", credsName, apiErr)
}