aboutsummaryrefslogtreecommitdiffstats
path: root/library/go/httputil/middleware/tvm/middleware.go
blob: cbcf1aae1d4f1b9f1b1eb792518c5bcc1512397f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package tvm

import (
	"context"
	"net/http"

	"golang.org/x/xerrors"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/library/go/yandex/tvm"
)

const (
	// XYaServiceTicket is http header that should be used for service ticket transfer.
	XYaServiceTicket = "X-Ya-Service-Ticket"
	// XYaUserTicket is http header that should be used for user ticket transfer.
	XYaUserTicket = "X-Ya-User-Ticket"
)

type (
	MiddlewareOption func(*middleware)

	middleware struct {
		l log.Structured

		tvm tvm.Client

		authClient func(context.Context, tvm.ClientID) error

		onError func(w http.ResponseWriter, r *http.Request, err error)
	}
)

func defaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
	http.Error(w, err.Error(), http.StatusForbidden)
}

// CheckServiceTicket returns http middleware that validates service tickets for all incoming requests.
//
// ServiceTicket is stored on request context. It might be retrieved by calling tvm.ContextServiceTicket.
func CheckServiceTicket(tvm tvm.Client, opts ...MiddlewareOption) func(next http.Handler) http.Handler {
	m := middleware{
		tvm:     tvm,
		onError: defaultErrorHandler,
	}

	for _, opt := range opts {
		opt(&m)
	}

	if m.authClient == nil {
		panic("must provide authorization policy")
	}

	if m.l == nil {
		m.l = &nop.Logger{}
	}

	return m.wrap
}

func (m *middleware) wrap(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		serviceTicket := r.Header.Get(XYaServiceTicket)
		if serviceTicket == "" {
			ctxlog.Error(r.Context(), m.l.Logger(), "missing service ticket")
			m.onError(w, r, xerrors.New("missing service ticket"))
			return
		}

		ticket, err := m.tvm.CheckServiceTicket(r.Context(), serviceTicket)
		if err != nil {
			ctxlog.Error(r.Context(), m.l.Logger(), "service ticket check failed", log.Error(err))
			m.onError(w, r, xerrors.Errorf("service ticket check failed: %w", err))
			return
		}

		if err := m.authClient(r.Context(), ticket.SrcID); err != nil {
			ctxlog.Error(r.Context(), m.l.Logger(), "client authorization failed",
				log.String("ticket", ticket.LogInfo),
				log.Error(err))
			m.onError(w, r, xerrors.Errorf("client authorization failed: %w", err))
			return
		}

		r = r.WithContext(tvm.WithServiceTicket(r.Context(), ticket))
		next.ServeHTTP(w, r)
	})
}