diff options
author | hcpp <hcpp@ydb.tech> | 2023-11-08 12:09:41 +0300 |
---|---|---|
committer | hcpp <hcpp@ydb.tech> | 2023-11-08 12:56:14 +0300 |
commit | a361f5b98b98b44ea510d274f6769164640dd5e1 (patch) | |
tree | c47c80962c6e2e7b06798238752fd3da0191a3f6 /library/go/httputil/middleware | |
parent | 9478806fde1f4d40bd5a45e7cbe77237dab613e9 (diff) | |
download | ydb-a361f5b98b98b44ea510d274f6769164640dd5e1.tar.gz |
metrics have been added
Diffstat (limited to 'library/go/httputil/middleware')
-rw-r--r-- | library/go/httputil/middleware/tvm/gotest/ya.make | 3 | ||||
-rw-r--r-- | library/go/httputil/middleware/tvm/middleware.go | 112 | ||||
-rw-r--r-- | library/go/httputil/middleware/tvm/middleware_opts.go | 46 | ||||
-rw-r--r-- | library/go/httputil/middleware/tvm/middleware_test.go | 126 | ||||
-rw-r--r-- | library/go/httputil/middleware/tvm/ya.make | 12 |
5 files changed, 299 insertions, 0 deletions
diff --git a/library/go/httputil/middleware/tvm/gotest/ya.make b/library/go/httputil/middleware/tvm/gotest/ya.make new file mode 100644 index 0000000000..f8ad1ffb46 --- /dev/null +++ b/library/go/httputil/middleware/tvm/gotest/ya.make @@ -0,0 +1,3 @@ +GO_TEST_FOR(library/go/httputil/middleware/tvm) + +END() diff --git a/library/go/httputil/middleware/tvm/middleware.go b/library/go/httputil/middleware/tvm/middleware.go new file mode 100644 index 0000000000..2e578ffca1 --- /dev/null +++ b/library/go/httputil/middleware/tvm/middleware.go @@ -0,0 +1,112 @@ +package tvm + +import ( + "context" + "net/http" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/core/log/ctxlog" + "github.com/ydb-platform/ydb/library/go/core/log/nop" + "github.com/ydb-platform/ydb/library/go/httputil/headers" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "golang.org/x/xerrors" +) + +const ( + // XYaServiceTicket is http header that should be used for service ticket transfer. + XYaServiceTicket = headers.XYaServiceTicketKey + // XYaUserTicket is http header that should be used for user ticket transfer. + XYaUserTicket = headers.XYaUserTicketKey +) + +type ( + MiddlewareOption func(*middleware) + + middleware struct { + l log.Structured + + clients []tvm.Client + + authClient func(context.Context, tvm.ClientID, tvm.ClientID) error + + onError func(w http.ResponseWriter, r *http.Request, err error) + } +) + +func defaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, err.Error(), http.StatusForbidden) +} + +func getMiddleware(clients []tvm.Client, opts ...MiddlewareOption) middleware { + m := middleware{ + clients: clients, + onError: defaultErrorHandler, + } + + for _, opt := range opts { + opt(&m) + } + + if m.authClient == nil { + panic("must provide authorization policy") + } + + if m.l == nil { + m.l = &nop.Logger{} + } + + return m +} + +// CheckServiceTicketMultiClient returns http middleware that validates service tickets for all incoming requests. +// It tries to check ticket with all the given clients in the given order +// ServiceTicket is stored on request context. It might be retrieved by calling tvm.ContextServiceTicket. +func CheckServiceTicketMultiClient(clients []tvm.Client, opts ...MiddlewareOption) func(next http.Handler) http.Handler { + m := getMiddleware(clients, opts...) + return m.wrap +} + +// CheckServiceTicket returns http middleware that validates service tickets for all incoming requests. +// +// ServiceTicket is stored on request context. It might be retrieved by calling tvm.ContextServiceTicket. +func CheckServiceTicket(client tvm.Client, opts ...MiddlewareOption) func(next http.Handler) http.Handler { + m := getMiddleware([]tvm.Client{client}, opts...) + return m.wrap +} + +func (m *middleware) wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + serviceTicket := r.Header.Get(XYaServiceTicket) + if serviceTicket == "" { + ctxlog.Error(r.Context(), m.l.Logger(), "missing service ticket") + m.onError(w, r, xerrors.New("missing service ticket")) + return + } + var ( + ticket *tvm.CheckedServiceTicket + err error + ) + for _, client := range m.clients { + ticket, err = client.CheckServiceTicket(r.Context(), serviceTicket) + if err == nil { + break + } + } + if err != nil { + ctxlog.Error(r.Context(), m.l.Logger(), "service ticket check failed", log.Error(err)) + m.onError(w, r, xerrors.Errorf("service ticket check failed: %w", err)) + return + } + + if err := m.authClient(r.Context(), ticket.SrcID, ticket.DstID); err != nil { + ctxlog.Error(r.Context(), m.l.Logger(), "client authorization failed", + log.String("ticket", ticket.LogInfo), + log.Error(err)) + m.onError(w, r, xerrors.Errorf("client authorization failed: %w", err)) + return + } + + r = r.WithContext(tvm.WithServiceTicket(r.Context(), ticket)) + next.ServeHTTP(w, r) + }) +} diff --git a/library/go/httputil/middleware/tvm/middleware_opts.go b/library/go/httputil/middleware/tvm/middleware_opts.go new file mode 100644 index 0000000000..4e33b4ee59 --- /dev/null +++ b/library/go/httputil/middleware/tvm/middleware_opts.go @@ -0,0 +1,46 @@ +package tvm + +import ( + "context" + "net/http" + + "github.com/ydb-platform/ydb/library/go/core/log" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" + "golang.org/x/xerrors" +) + +// WithAllowedClients sets list of allowed clients. +func WithAllowedClients(allowedClients []tvm.ClientID) MiddlewareOption { + return func(m *middleware) { + m.authClient = func(_ context.Context, src tvm.ClientID, dst tvm.ClientID) error { + for _, allowed := range allowedClients { + if allowed == src { + return nil + } + } + + return xerrors.Errorf("client with tvm_id=%d is not whitelisted", dst) + } + } +} + +// WithClientAuth sets custom function for client authorization. +func WithClientAuth(authClient func(ctx context.Context, src tvm.ClientID, dst tvm.ClientID) error) MiddlewareOption { + return func(m *middleware) { + m.authClient = authClient + } +} + +// WithErrorHandler sets http handler invoked for rejected requests. +func WithErrorHandler(h func(w http.ResponseWriter, r *http.Request, err error)) MiddlewareOption { + return func(m *middleware) { + m.onError = h + } +} + +// WithLogger sets logger. +func WithLogger(l log.Structured) MiddlewareOption { + return func(m *middleware) { + m.l = l + } +} diff --git a/library/go/httputil/middleware/tvm/middleware_test.go b/library/go/httputil/middleware/tvm/middleware_test.go new file mode 100644 index 0000000000..e6005a76a6 --- /dev/null +++ b/library/go/httputil/middleware/tvm/middleware_test.go @@ -0,0 +1,126 @@ +package tvm + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/ydb-platform/ydb/library/go/yandex/tvm" +) + +type fakeClient struct { + ticket *tvm.CheckedServiceTicket + err error +} + +func (f *fakeClient) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) { + panic("implement me") +} + +func (f *fakeClient) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) { + panic("implement me") +} + +func (f *fakeClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) { + return f.ticket, f.err +} + +func (f *fakeClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) { + panic("implement me") +} + +func (f *fakeClient) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) { + panic("implement me") +} + +func (f *fakeClient) GetRoles(ctx context.Context) (*tvm.Roles, error) { + panic("implement me") +} + +func TestMiddlewareOkTicket(t *testing.T) { + var f fakeClient + f.ticket = &tvm.CheckedServiceTicket{SrcID: 42} + + m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42})) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(XYaServiceTicket, "123") + + var handlerCalled bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + require.Equal(t, f.ticket, tvm.ContextServiceTicket(r.Context())) + }) + + m(handler).ServeHTTP(nil, r) + require.True(t, handlerCalled) +} + +func TestMiddlewareClientNotAllowed(t *testing.T) { + var f fakeClient + f.ticket = &tvm.CheckedServiceTicket{SrcID: 43} + + m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42})) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(XYaServiceTicket, "123") + w := httptest.NewRecorder() + + m(nil).ServeHTTP(w, r) + require.Equal(t, 403, w.Code) +} + +func TestMiddlewareMissingTicket(t *testing.T) { + m := CheckServiceTicket(nil, WithAllowedClients([]tvm.ClientID{42})) + + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + m(nil).ServeHTTP(w, r) + require.Equal(t, 403, w.Code) +} + +func TestMiddlewareInvalidTicket(t *testing.T) { + var f fakeClient + f.err = &tvm.Error{} + + m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42})) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(XYaServiceTicket, "123") + w := httptest.NewRecorder() + + m(nil).ServeHTTP(w, r) + require.Equal(t, 403, w.Code) +} + +func TestMiddlewareMultipleDsts(t *testing.T) { + var f1, f2, f3 fakeClient + f1.err = &tvm.Error{} + f2.err = &tvm.Error{} + f3.ticket = &tvm.CheckedServiceTicket{SrcID: 42, DstID: 43} + + m := CheckServiceTicketMultiClient([]tvm.Client{ + &f1, + &f3, + &f2, + }, WithClientAuth(func(ctx context.Context, src tvm.ClientID, dst tvm.ClientID) error { + require.Equal(t, tvm.ClientID(43), dst) + require.Equal(t, tvm.ClientID(42), src) + return nil + })) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set(XYaServiceTicket, "123") + + var handlerCalled bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + require.Equal(t, f3.ticket, tvm.ContextServiceTicket(r.Context())) + }) + + m(handler).ServeHTTP(nil, r) + require.True(t, handlerCalled) +} diff --git a/library/go/httputil/middleware/tvm/ya.make b/library/go/httputil/middleware/tvm/ya.make new file mode 100644 index 0000000000..7aab530b70 --- /dev/null +++ b/library/go/httputil/middleware/tvm/ya.make @@ -0,0 +1,12 @@ +GO_LIBRARY() + +SRCS( + middleware.go + middleware_opts.go +) + +GO_TEST_SRCS(middleware_test.go) + +END() + +RECURSE(gotest) |