aboutsummaryrefslogtreecommitdiffstats
path: root/library/go/httputil/middleware
diff options
context:
space:
mode:
authorhcpp <hcpp@ydb.tech>2023-11-08 12:09:41 +0300
committerhcpp <hcpp@ydb.tech>2023-11-08 12:56:14 +0300
commita361f5b98b98b44ea510d274f6769164640dd5e1 (patch)
treec47c80962c6e2e7b06798238752fd3da0191a3f6 /library/go/httputil/middleware
parent9478806fde1f4d40bd5a45e7cbe77237dab613e9 (diff)
downloadydb-a361f5b98b98b44ea510d274f6769164640dd5e1.tar.gz
metrics have been added
Diffstat (limited to 'library/go/httputil/middleware')
-rw-r--r--library/go/httputil/middleware/tvm/gotest/ya.make3
-rw-r--r--library/go/httputil/middleware/tvm/middleware.go112
-rw-r--r--library/go/httputil/middleware/tvm/middleware_opts.go46
-rw-r--r--library/go/httputil/middleware/tvm/middleware_test.go126
-rw-r--r--library/go/httputil/middleware/tvm/ya.make12
5 files changed, 299 insertions, 0 deletions
diff --git a/library/go/httputil/middleware/tvm/gotest/ya.make b/library/go/httputil/middleware/tvm/gotest/ya.make
new file mode 100644
index 0000000000..f8ad1ffb46
--- /dev/null
+++ b/library/go/httputil/middleware/tvm/gotest/ya.make
@@ -0,0 +1,3 @@
+GO_TEST_FOR(library/go/httputil/middleware/tvm)
+
+END()
diff --git a/library/go/httputil/middleware/tvm/middleware.go b/library/go/httputil/middleware/tvm/middleware.go
new file mode 100644
index 0000000000..2e578ffca1
--- /dev/null
+++ b/library/go/httputil/middleware/tvm/middleware.go
@@ -0,0 +1,112 @@
+package tvm
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/ydb-platform/ydb/library/go/core/log"
+ "github.com/ydb-platform/ydb/library/go/core/log/ctxlog"
+ "github.com/ydb-platform/ydb/library/go/core/log/nop"
+ "github.com/ydb-platform/ydb/library/go/httputil/headers"
+ "github.com/ydb-platform/ydb/library/go/yandex/tvm"
+ "golang.org/x/xerrors"
+)
+
+const (
+ // XYaServiceTicket is http header that should be used for service ticket transfer.
+ XYaServiceTicket = headers.XYaServiceTicketKey
+ // XYaUserTicket is http header that should be used for user ticket transfer.
+ XYaUserTicket = headers.XYaUserTicketKey
+)
+
+type (
+ MiddlewareOption func(*middleware)
+
+ middleware struct {
+ l log.Structured
+
+ clients []tvm.Client
+
+ authClient func(context.Context, tvm.ClientID, tvm.ClientID) error
+
+ onError func(w http.ResponseWriter, r *http.Request, err error)
+ }
+)
+
+func defaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
+ http.Error(w, err.Error(), http.StatusForbidden)
+}
+
+func getMiddleware(clients []tvm.Client, opts ...MiddlewareOption) middleware {
+ m := middleware{
+ clients: clients,
+ onError: defaultErrorHandler,
+ }
+
+ for _, opt := range opts {
+ opt(&m)
+ }
+
+ if m.authClient == nil {
+ panic("must provide authorization policy")
+ }
+
+ if m.l == nil {
+ m.l = &nop.Logger{}
+ }
+
+ return m
+}
+
+// CheckServiceTicketMultiClient returns http middleware that validates service tickets for all incoming requests.
+// It tries to check ticket with all the given clients in the given order
+// ServiceTicket is stored on request context. It might be retrieved by calling tvm.ContextServiceTicket.
+func CheckServiceTicketMultiClient(clients []tvm.Client, opts ...MiddlewareOption) func(next http.Handler) http.Handler {
+ m := getMiddleware(clients, opts...)
+ return m.wrap
+}
+
+// CheckServiceTicket returns http middleware that validates service tickets for all incoming requests.
+//
+// ServiceTicket is stored on request context. It might be retrieved by calling tvm.ContextServiceTicket.
+func CheckServiceTicket(client tvm.Client, opts ...MiddlewareOption) func(next http.Handler) http.Handler {
+ m := getMiddleware([]tvm.Client{client}, opts...)
+ return m.wrap
+}
+
+func (m *middleware) wrap(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ serviceTicket := r.Header.Get(XYaServiceTicket)
+ if serviceTicket == "" {
+ ctxlog.Error(r.Context(), m.l.Logger(), "missing service ticket")
+ m.onError(w, r, xerrors.New("missing service ticket"))
+ return
+ }
+ var (
+ ticket *tvm.CheckedServiceTicket
+ err error
+ )
+ for _, client := range m.clients {
+ ticket, err = client.CheckServiceTicket(r.Context(), serviceTicket)
+ if err == nil {
+ break
+ }
+ }
+ if err != nil {
+ ctxlog.Error(r.Context(), m.l.Logger(), "service ticket check failed", log.Error(err))
+ m.onError(w, r, xerrors.Errorf("service ticket check failed: %w", err))
+ return
+ }
+
+ if err := m.authClient(r.Context(), ticket.SrcID, ticket.DstID); err != nil {
+ ctxlog.Error(r.Context(), m.l.Logger(), "client authorization failed",
+ log.String("ticket", ticket.LogInfo),
+ log.Error(err))
+ m.onError(w, r, xerrors.Errorf("client authorization failed: %w", err))
+ return
+ }
+
+ r = r.WithContext(tvm.WithServiceTicket(r.Context(), ticket))
+ next.ServeHTTP(w, r)
+ })
+}
diff --git a/library/go/httputil/middleware/tvm/middleware_opts.go b/library/go/httputil/middleware/tvm/middleware_opts.go
new file mode 100644
index 0000000000..4e33b4ee59
--- /dev/null
+++ b/library/go/httputil/middleware/tvm/middleware_opts.go
@@ -0,0 +1,46 @@
+package tvm
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/ydb-platform/ydb/library/go/core/log"
+ "github.com/ydb-platform/ydb/library/go/yandex/tvm"
+ "golang.org/x/xerrors"
+)
+
+// WithAllowedClients sets list of allowed clients.
+func WithAllowedClients(allowedClients []tvm.ClientID) MiddlewareOption {
+ return func(m *middleware) {
+ m.authClient = func(_ context.Context, src tvm.ClientID, dst tvm.ClientID) error {
+ for _, allowed := range allowedClients {
+ if allowed == src {
+ return nil
+ }
+ }
+
+ return xerrors.Errorf("client with tvm_id=%d is not whitelisted", dst)
+ }
+ }
+}
+
+// WithClientAuth sets custom function for client authorization.
+func WithClientAuth(authClient func(ctx context.Context, src tvm.ClientID, dst tvm.ClientID) error) MiddlewareOption {
+ return func(m *middleware) {
+ m.authClient = authClient
+ }
+}
+
+// WithErrorHandler sets http handler invoked for rejected requests.
+func WithErrorHandler(h func(w http.ResponseWriter, r *http.Request, err error)) MiddlewareOption {
+ return func(m *middleware) {
+ m.onError = h
+ }
+}
+
+// WithLogger sets logger.
+func WithLogger(l log.Structured) MiddlewareOption {
+ return func(m *middleware) {
+ m.l = l
+ }
+}
diff --git a/library/go/httputil/middleware/tvm/middleware_test.go b/library/go/httputil/middleware/tvm/middleware_test.go
new file mode 100644
index 0000000000..e6005a76a6
--- /dev/null
+++ b/library/go/httputil/middleware/tvm/middleware_test.go
@@ -0,0 +1,126 @@
+package tvm
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/ydb-platform/ydb/library/go/yandex/tvm"
+)
+
+type fakeClient struct {
+ ticket *tvm.CheckedServiceTicket
+ err error
+}
+
+func (f *fakeClient) GetServiceTicketForAlias(ctx context.Context, alias string) (string, error) {
+ panic("implement me")
+}
+
+func (f *fakeClient) GetServiceTicketForID(ctx context.Context, dstID tvm.ClientID) (string, error) {
+ panic("implement me")
+}
+
+func (f *fakeClient) CheckServiceTicket(ctx context.Context, ticket string) (*tvm.CheckedServiceTicket, error) {
+ return f.ticket, f.err
+}
+
+func (f *fakeClient) CheckUserTicket(ctx context.Context, ticket string, opts ...tvm.CheckUserTicketOption) (*tvm.CheckedUserTicket, error) {
+ panic("implement me")
+}
+
+func (f *fakeClient) GetStatus(ctx context.Context) (tvm.ClientStatusInfo, error) {
+ panic("implement me")
+}
+
+func (f *fakeClient) GetRoles(ctx context.Context) (*tvm.Roles, error) {
+ panic("implement me")
+}
+
+func TestMiddlewareOkTicket(t *testing.T) {
+ var f fakeClient
+ f.ticket = &tvm.CheckedServiceTicket{SrcID: 42}
+
+ m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42}))
+
+ r := httptest.NewRequest("GET", "/", nil)
+ r.Header.Set(XYaServiceTicket, "123")
+
+ var handlerCalled bool
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ handlerCalled = true
+ require.Equal(t, f.ticket, tvm.ContextServiceTicket(r.Context()))
+ })
+
+ m(handler).ServeHTTP(nil, r)
+ require.True(t, handlerCalled)
+}
+
+func TestMiddlewareClientNotAllowed(t *testing.T) {
+ var f fakeClient
+ f.ticket = &tvm.CheckedServiceTicket{SrcID: 43}
+
+ m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42}))
+
+ r := httptest.NewRequest("GET", "/", nil)
+ r.Header.Set(XYaServiceTicket, "123")
+ w := httptest.NewRecorder()
+
+ m(nil).ServeHTTP(w, r)
+ require.Equal(t, 403, w.Code)
+}
+
+func TestMiddlewareMissingTicket(t *testing.T) {
+ m := CheckServiceTicket(nil, WithAllowedClients([]tvm.ClientID{42}))
+
+ r := httptest.NewRequest("GET", "/", nil)
+ w := httptest.NewRecorder()
+
+ m(nil).ServeHTTP(w, r)
+ require.Equal(t, 403, w.Code)
+}
+
+func TestMiddlewareInvalidTicket(t *testing.T) {
+ var f fakeClient
+ f.err = &tvm.Error{}
+
+ m := CheckServiceTicket(&f, WithAllowedClients([]tvm.ClientID{42}))
+
+ r := httptest.NewRequest("GET", "/", nil)
+ r.Header.Set(XYaServiceTicket, "123")
+ w := httptest.NewRecorder()
+
+ m(nil).ServeHTTP(w, r)
+ require.Equal(t, 403, w.Code)
+}
+
+func TestMiddlewareMultipleDsts(t *testing.T) {
+ var f1, f2, f3 fakeClient
+ f1.err = &tvm.Error{}
+ f2.err = &tvm.Error{}
+ f3.ticket = &tvm.CheckedServiceTicket{SrcID: 42, DstID: 43}
+
+ m := CheckServiceTicketMultiClient([]tvm.Client{
+ &f1,
+ &f3,
+ &f2,
+ }, WithClientAuth(func(ctx context.Context, src tvm.ClientID, dst tvm.ClientID) error {
+ require.Equal(t, tvm.ClientID(43), dst)
+ require.Equal(t, tvm.ClientID(42), src)
+ return nil
+ }))
+
+ r := httptest.NewRequest("GET", "/", nil)
+ r.Header.Set(XYaServiceTicket, "123")
+
+ var handlerCalled bool
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ handlerCalled = true
+ require.Equal(t, f3.ticket, tvm.ContextServiceTicket(r.Context()))
+ })
+
+ m(handler).ServeHTTP(nil, r)
+ require.True(t, handlerCalled)
+}
diff --git a/library/go/httputil/middleware/tvm/ya.make b/library/go/httputil/middleware/tvm/ya.make
new file mode 100644
index 0000000000..7aab530b70
--- /dev/null
+++ b/library/go/httputil/middleware/tvm/ya.make
@@ -0,0 +1,12 @@
+GO_LIBRARY()
+
+SRCS(
+ middleware.go
+ middleware_opts.go
+)
+
+GO_TEST_SRCS(middleware_test.go)
+
+END()
+
+RECURSE(gotest)