diff options
| -rw-r--r-- | ydb/mvp/oidc_proxy/context.cpp | 47 | ||||
| -rw-r--r-- | ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp | 80 | ||||
| -rw-r--r-- | ydb/mvp/oidc_proxy/oidc_session_create.cpp | 4 | ||||
| -rw-r--r-- | ydb/mvp/oidc_proxy/oidc_settings.h | 1 | ||||
| -rw-r--r-- | ydb/mvp/oidc_proxy/openid_connect.cpp | 198 | ||||
| -rw-r--r-- | ydb/mvp/oidc_proxy/openid_connect.h | 42 | ||||
| -rw-r--r-- | ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py | 60 |
7 files changed, 311 insertions, 121 deletions
diff --git a/ydb/mvp/oidc_proxy/context.cpp b/ydb/mvp/oidc_proxy/context.cpp index d638be70d93..678b3f7c318 100644 --- a/ydb/mvp/oidc_proxy/context.cpp +++ b/ydb/mvp/oidc_proxy/context.cpp @@ -1,10 +1,14 @@ +#include "context.h" +#include "oidc_settings.h" +#include "openid_connect.h" + +#include <ydb/library/actors/http/http.h> + +#include <library/cpp/json/json_writer.h> +#include <library/cpp/string_utils/base64/base64.h> + #include <util/generic/string.h> #include <util/string/builder.h> -#include <library/cpp/string_utils/base64/base64.h> -#include <ydb/library/actors/http/http.h> -#include "openid_connect.h" -#include "oidc_settings.h" -#include "context.h" namespace NMVP::NOIDC { @@ -22,15 +26,13 @@ TContext::TContext(const NHttp::THttpIncomingRequestPtr& request) TString TContext::GetState(const TString& key) const { static const TDuration STATE_LIFE_TIME = TDuration::Minutes(10); - TInstant expirationTime = TInstant::Now() + STATE_LIFE_TIME; - TStringBuilder json; - json << "{\"state\":\"" << State - << "\",\"expiration_time\":\"" << ToString(expirationTime.TimeT()) << "\"}"; - TString digest = HmacSHA1(key, json); - TStringBuilder signedState; - signedState << "{\"container\":\"" << Base64Encode(json) << "\"," - "\"digest\":\"" << Base64Encode(digest) << "\"}"; - return Base64EncodeNoPadding(signedState); + TState payload; + payload.AntiForgeryToken = State; + payload.ExpirationTime = TInstant::Now() + STATE_LIFE_TIME; + if (!NavigationRequest) { + payload.CookieSuffix = TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX); + } + return EncodeState(payload, key); } bool TContext::IsNavigationRequest() const { @@ -43,7 +45,7 @@ TString TContext::GetRequestedAddress() const { TString TContext::CreateYdbOidcCookie(const TString& secret) const { static constexpr size_t COOKIE_MAX_AGE_SEC = 3600; - return TStringBuilder() << TOpenIdConnectSettings::YDB_OIDC_COOKIE << "=" + return TStringBuilder() << CreateNameYdbOidcCookie(NavigationRequest ? TStringBuf() : TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX) << "=" << GenerateCookie(secret) << ";" " Path=" << GetAuthCallbackUrl() << ";" " Max-Age=" << COOKIE_MAX_AGE_SEC << ";" @@ -51,13 +53,16 @@ TString TContext::CreateYdbOidcCookie(const TString& secret) const { } TString TContext::GenerateCookie(const TString& key) const { - TStringBuilder requestedAddressContext; - requestedAddressContext << "{\"requested_address\":\"" << RequestedAddress << "\"}"; + NJson::TJsonValue json(NJson::JSON_MAP); + json["requested_address"] = RequestedAddress; + const TString requestedAddressContext = NJson::WriteJson(json, false); + TString digest = HmacSHA256(key, requestedAddressContext); - TStringBuilder signedRequestedAddress; - signedRequestedAddress << "{\"requested_address_context\":\"" << Base64Encode(requestedAddressContext) - << "\",\"digest\":\"" << Base64Encode(digest) << "\"}"; - return Base64Encode(signedRequestedAddress); + + NJson::TJsonValue root(NJson::JSON_MAP); + root["requested_address_context"] = Base64Encode(requestedAddressContext); + root["digest"] = Base64Encode(digest); + return Base64Encode(NJson::WriteJson(root, false)); } bool TContext::IsPageNavigationRequest(const NHttp::THttpIncomingRequestPtr& request) { diff --git a/ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp b/ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp index 82db06429a7..23fbd69ee09 100644 --- a/ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp +++ b/ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp @@ -13,6 +13,7 @@ #include <ydb/mvp/core/mvp_test_runtime.h> #include <ydb/library/security/util.h> #include <library/cpp/json/json_reader.h> +#include <library/cpp/json/json_writer.h> #include <library/cpp/string_utils/base64/base64.h> #include <library/cpp/testing/unittest/registar.h> #include <util/generic/map.h> @@ -572,19 +573,21 @@ Y_UNIT_TEST_SUITE(Mvp) { redirectStrategy.CheckRedirectStatus(outgoingResponseEv); TString location = redirectStrategy.GetRedirectUrl(outgoingResponseEv); UNIT_ASSERT_STRING_CONTAINS(location, "https://auth.test.net/oauth/authorize"); - UNIT_ASSERT_STRING_CONTAINS(location, "response_type=code"); - UNIT_ASSERT_STRING_CONTAINS(location, "scope=openid"); - UNIT_ASSERT_STRING_CONTAINS(location, "client_id=" + settings.ClientId); - UNIT_ASSERT_STRING_CONTAINS(location, "redirect_uri=https://" + hostProxy + "/auth/callback"); NHttp::TUrlParameters urlParameters(location); - const TString state = urlParameters["state"]; + UNIT_ASSERT_STRINGS_EQUAL(urlParameters["response_type"], "code"); + UNIT_ASSERT_STRINGS_EQUAL(urlParameters["scope"], "openid"); + UNIT_ASSERT_STRINGS_EQUAL(urlParameters["client_id"], settings.ClientId); + UNIT_ASSERT_STRINGS_EQUAL(urlParameters["redirect_uri"], "https://" + hostProxy + "/auth/callback"); + const TString state = TString(urlParameters.Get("state")); const NHttp::THeaders headers(outgoingResponseEv->Response->Headers); UNIT_ASSERT(headers.Has("X-Request-Id")); UNIT_ASSERT(headers.Has("Set-Cookie")); TStringBuf setCookie = headers.Get("Set-Cookie"); - UNIT_ASSERT_STRING_CONTAINS(setCookie, TOpenIdConnectSettings::YDB_OIDC_COOKIE); + UNIT_ASSERT_STRING_CONTAINS( + setCookie, + CreateNameYdbOidcCookie(redirectStrategy.IsNavigationRequest() ? TStringBuf() : TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX)); redirectStrategy.CheckSpecificHeaders(headers); const NActors::TActorId sessionCreator = runtime.Register(new TSessionCreateHandler(edge, settings)); @@ -692,10 +695,15 @@ Y_UNIT_TEST_SUITE(Mvp) { TAutoPtr<IEventHandle> handle; NHttp::TEvHttpProxy::TEvHttpOutgoingResponse* outgoingResponseEv = runtime.GrabEdgeEvent<NHttp::TEvHttpProxy::TEvHttpOutgoingResponse>(handle); - UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Status, "302"); const NHttp::THeaders protectedPageHeaders(outgoingResponseEv->Response->Headers); - UNIT_ASSERT(protectedPageHeaders.Has("Location")); - UNIT_ASSERT_STRINGS_EQUAL(protectedPageHeaders.Get("Location"), "/requested/page"); + if (redirectStrategy.IsNavigationRequest()) { + UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Status, "302"); + UNIT_ASSERT(protectedPageHeaders.Has("Location")); + UNIT_ASSERT_STRINGS_EQUAL(protectedPageHeaders.Get("Location"), "/requested/page"); + } else { + UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Status, "400"); + UNIT_ASSERT(!protectedPageHeaders.Has("Location")); + } } Y_UNIT_TEST(OpenIdConnectWrongStateAuthorizationFlow) { @@ -708,6 +716,22 @@ Y_UNIT_TEST_SUITE(Mvp) { OidcWrongStateAuthorizationFlow(redirectStrategy); } + Y_UNIT_TEST(OpenIdConnectExpiredBackgroundStateKeepsCookieSuffix) { + TPortManager tp; + auto settings = BuildBaseSettings(tp); + TState sourcePayload; + sourcePayload.AntiForgeryToken = "state"; + sourcePayload.ExpirationTime = TInstant::Seconds(0); + sourcePayload.CookieSuffix = TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX); + + TCheckStateResult result = CheckState(EncodeState(sourcePayload, settings.ClientSecret), settings.ClientSecret); + const TString expectedCookieSuffix = TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX); + + UNIT_ASSERT(!result.Ok); + UNIT_ASSERT_STRINGS_EQUAL(result.CookieSuffix, expectedCookieSuffix); + UNIT_ASSERT_STRING_CONTAINS(result.ErrorMessage, "State life time expired"); + } + Y_UNIT_TEST(OpenIdConnectSessionServiceCreateAuthorizationFail) { TPortManager tp; TMvpTestRuntime runtime; @@ -1411,17 +1435,17 @@ Y_UNIT_TEST_SUITE(Mvp) { } static TString GetViewerResponse200() { - TStringBuilder body; - body << "{\"UserSID\":\"" << VIEWER_USER_ACCOUNT_ID - << "\",\"OriginalUserToken\":\"" << TProfileServiceMock::VALID_USER_TOKEN << "\"}"; - return MakeHttpResponse("200 OK", body, "application/json"); + NJson::TJsonValue json(NJson::JSON_MAP); + json["UserSID"] = VIEWER_USER_ACCOUNT_ID; + json["OriginalUserToken"] = TProfileServiceMock::VALID_USER_TOKEN; + return MakeHttpResponse("200 OK", NJson::WriteJson(json, false), "application/json"); } static TString GetViewerResponseService200() { - TStringBuilder body; - body << "{\"UserSID\":\"" << VIEWER_SERVICE_ACCOUNT_ID - << "\",\"OriginalUserToken\":\"" << TProfileServiceMock::VALID_SERVICE_TOKEN << "\"}"; - return MakeHttpResponse("200 OK", body, "application/json"); + NJson::TJsonValue json(NJson::JSON_MAP); + json["UserSID"] = VIEWER_SERVICE_ACCOUNT_ID; + json["OriginalUserToken"] = TProfileServiceMock::VALID_SERVICE_TOKEN; + return MakeHttpResponse("200 OK", NJson::WriteJson(json, false), "application/json"); } static TString GetViewerResponse403() { @@ -1607,6 +1631,28 @@ static void NavigationRequestTest(const TString& rawRequest, bool expectedNaviga } Y_UNIT_TEST_SUITE(Utils) { + Y_UNIT_TEST(OpenIdConnectStateRoundTrip) { + TPortManager tp; + auto settings = BuildBaseSettings(tp); + + TState sourcePayload; + sourcePayload.AntiForgeryToken = "state"; + sourcePayload.ExpirationTime = TInstant::Seconds(TInstant::Now().Seconds() + TDuration::Minutes(10).Seconds()); + sourcePayload.CookieSuffix = TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX); + + const TString state = EncodeState(sourcePayload, settings.ClientSecret); + const TCheckStateResult result = CheckState(state, settings.ClientSecret); + const TDecodeStateResult decodedResult = DecodeState(state); + + UNIT_ASSERT(result.Ok); + UNIT_ASSERT_STRINGS_EQUAL(result.CookieSuffix, sourcePayload.CookieSuffix); + UNIT_ASSERT(result.ErrorMessage.empty()); + UNIT_ASSERT(decodedResult.HasSignedStateJson); + UNIT_ASSERT(decodedResult.HasStateContainerJson); + UNIT_ASSERT(decodedResult.Payload == sourcePayload); + UNIT_ASSERT_STRINGS_EQUAL(EncodeState(decodedResult.Payload, settings.ClientSecret), state); + } + Y_UNIT_TEST(GenerateRandomBase64RandomUniqueness) { THashSet<TString> seen; for (size_t i = 0; i < 100; ++i) { diff --git a/ydb/mvp/oidc_proxy/oidc_session_create.cpp b/ydb/mvp/oidc_proxy/oidc_session_create.cpp index 8cd57a42959..bcac8054175 100644 --- a/ydb/mvp/oidc_proxy/oidc_session_create.cpp +++ b/ydb/mvp/oidc_proxy/oidc_session_create.cpp @@ -31,10 +31,10 @@ void THandlerSessionCreate::Bootstrap() { NHttp::THeaders headers(Request->Headers); NHttp::TCookies cookies(headers.Get("cookie")); - TRestoreOidcContextResult restoreContextResult = RestoreOidcContext(cookies, Settings.ClientSecret); + TRestoreOidcContextResult restoreContextResult = RestoreOidcContext(cookies, Settings.ClientSecret, checkStateResult.CookieSuffix); Context = restoreContextResult.Context; - if (checkStateResult.IsSuccess()) { + if (checkStateResult.Ok) { if (restoreContextResult.IsSuccess()) { if (code.empty()) { BLOG_D("Restore oidc session failed: receive empty 'code' parameter"); diff --git a/ydb/mvp/oidc_proxy/oidc_settings.h b/ydb/mvp/oidc_proxy/oidc_settings.h index 2306eb1e21d..a26f72366ab 100644 --- a/ydb/mvp/oidc_proxy/oidc_settings.h +++ b/ydb/mvp/oidc_proxy/oidc_settings.h @@ -10,6 +10,7 @@ namespace NMVP::NOIDC { struct TOpenIdConnectSettings { static const inline TString YDB_OIDC_COOKIE = "ydb_oidc_cookie"; + static const inline TStringBuf YDB_OIDC_COOKIE_BACKGROUND_SUFFIX = "_background"; static const inline TString SESSION_COOKIE = "session_cookie"; static const inline TString IMPERSONATED_COOKIE = "impersonated_cookie"; diff --git a/ydb/mvp/oidc_proxy/openid_connect.cpp b/ydb/mvp/oidc_proxy/openid_connect.cpp index 3b8cdaf6c46..c8470cbf570 100644 --- a/ydb/mvp/oidc_proxy/openid_connect.cpp +++ b/ydb/mvp/oidc_proxy/openid_connect.cpp @@ -5,7 +5,9 @@ #include <ydb/core/util/wildcard.h> #include <ydb/library/security/util.h> +#include <library/cpp/cgiparam/cgiparam.h> #include <library/cpp/json/json_reader.h> +#include <library/cpp/json/json_writer.h> #include <library/cpp/string_utils/base64/base64.h> #include <openssl/evp.h> @@ -14,7 +16,6 @@ #include <util/string/builder.h> #include <util/string/hex.h> - namespace NMVP::NOIDC { TRestoreOidcContextResult::TRestoreOidcContextResult(const TStatus& status, const TContext& context) @@ -27,13 +28,18 @@ bool TRestoreOidcContextResult::IsSuccess() const { return Status.IsSuccess; } -TCheckStateResult::TCheckStateResult(bool success, const TString& errorMessage) - : Success(success) +TCheckStateResult::TCheckStateResult(bool ok, const TString& cookieSuffix, const TString& errorMessage) + : Ok(ok) , ErrorMessage(errorMessage) + , CookieSuffix(cookieSuffix) {} -bool TCheckStateResult::IsSuccess() const { - return Success; +TCheckStateResult TCheckStateResult::Error(const TString& errorMessage, const TString& cookieSuffix) { + return TCheckStateResult(false, cookieSuffix, errorMessage); +} + +TCheckStateResult TCheckStateResult::Success(const TString& cookieSuffix) { + return TCheckStateResult(true, cookieSuffix, ""); } void SetCORS(const NHttp::THttpIncomingRequestPtr& request, NHttp::THeadersBuilder* const headers) { @@ -79,14 +85,19 @@ void SetHeader(NYdbGrpc::TCallMeta& meta, const TString& name, const TString& va NHttp::THttpOutgoingResponsePtr GetHttpOutgoingResponsePtr(const NHttp::THttpIncomingRequestPtr& request, const TOpenIdConnectSettings& settings, TStringBuf requestId) { TContext context(request); - const TString redirectUrl = TStringBuilder() << settings.GetAuthEndpointURL() - << "?response_type=code" - << "&scope=openid" - << "&state=" << context.GetState(settings.ClientSecret) - << "&client_id=" << settings.ClientId - << "&redirect_uri=" << (request->Endpoint->Secure ? "https://" : "http://") - << request->Host - << GetAuthCallbackUrl(); + const TString redirectUri = TStringBuilder() + << (request->Endpoint->Secure ? "https://" : "http://") + << request->Host + << GetAuthCallbackUrl(); + + TCgiParameters authParams; + authParams.InsertUnescaped("response_type", "code"); + authParams.InsertUnescaped("scope", "openid"); + authParams.InsertUnescaped("state", context.GetState(settings.ClientSecret)); + authParams.InsertUnescaped("client_id", settings.ClientId); + authParams.InsertUnescaped("redirect_uri", redirectUri); + + const TString redirectUrl = settings.GetAuthEndpointURL() + "?" + authParams.Print(); NHttp::THeadersBuilder responseHeaders; SetCORS(request, &responseHeaders); SetRequestIdHeader(responseHeaders, requestId); @@ -96,12 +107,15 @@ NHttp::THttpOutgoingResponsePtr GetHttpOutgoingResponsePtr(const NHttp::THttpInc return request->CreateResponse("302", "Authorization required", responseHeaders); } responseHeaders.Set("Content-Type", "application/json; charset=utf-8"); - TString body {"{\"error\":\"Authorization Required\",\"authUrl\":\"" + redirectUrl + "\"}"}; + NJson::TJsonValue json(NJson::JSON_MAP); + json["error"] = "Authorization Required"; + json["authUrl"] = redirectUrl; + const TString body = NJson::WriteJson(json, false); return request->CreateResponse("401", "Unauthorized", responseHeaders, body); } -TString CreateNameYdbOidcCookie(TStringBuf key, TStringBuf state) { - return TOpenIdConnectSettings::YDB_OIDC_COOKIE + "_" + HexEncode(HmacSHA256(key, state)); +TString CreateNameYdbOidcCookie(TStringBuf suffix) { + return TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE) + TString(suffix); } TString CreateNameSessionCookie(TStringBuf key) { @@ -131,23 +145,24 @@ TString ClearSecureCookie(const TString& name) { return cookieBuilder; } -TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, const TString& key) { +TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, const TString& key, TStringBuf cookieSuffix) { TStringBuilder errorMessage; errorMessage << "Restore oidc context failed: "; - if (!cookies.Has(TOpenIdConnectSettings::YDB_OIDC_COOKIE)) { + TString cookieName = CreateNameYdbOidcCookie(cookieSuffix); + if (!cookies.Has(cookieName)) { return TRestoreOidcContextResult({.IsSuccess = false, .IsErrorRetryable = false, - .ErrorMessage = errorMessage << "Cannot find cookie " << TOpenIdConnectSettings::YDB_OIDC_COOKIE}); + .ErrorMessage = errorMessage << "Cannot find cookie " << cookieName}); } - TString signedRequestedAddress = Base64Decode(cookies.Get(TOpenIdConnectSettings::YDB_OIDC_COOKIE)); + TString signedRequestedAddress = Base64Decode(cookies.Get(cookieName)); TString requestedAddressContext; TString expectedDigest; NJson::TJsonValue jsonValue; NJson::TJsonReaderConfig jsonConfig; if (NJson::ReadJsonTree(signedRequestedAddress, &jsonConfig, &jsonValue)) { const NJson::TJsonValue* jsonRequestedAddressContext = nullptr; - if (jsonValue.GetValuePointer("requested_address_context", &jsonRequestedAddressContext)) { - requestedAddressContext = jsonRequestedAddressContext->GetStringRobust(); + if (jsonValue.GetValuePointer("requested_address_context", &jsonRequestedAddressContext) && jsonRequestedAddressContext->IsString()) { + requestedAddressContext = jsonRequestedAddressContext->GetString(); requestedAddressContext = Base64Decode(requestedAddressContext); } if (requestedAddressContext.empty()) { @@ -156,8 +171,8 @@ TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, con .ErrorMessage = errorMessage << "Struct with state is empty"}); } const NJson::TJsonValue* jsonDigest = nullptr; - if (jsonValue.GetValuePointer("digest", &jsonDigest)) { - expectedDigest = jsonDigest->GetStringRobust(); + if (jsonValue.GetValuePointer("digest", &jsonDigest) && jsonDigest->IsString()) { + expectedDigest = jsonDigest->GetString(); expectedDigest = Base64Decode(expectedDigest); } if (expectedDigest.empty()) { @@ -173,66 +188,113 @@ TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, con .ErrorMessage = errorMessage << "Calculated digest is not equal expected digest"}); } TString requestedAddress; - if (NJson::ReadJsonTree(requestedAddressContext, &jsonConfig, &jsonValue)) { - const NJson::TJsonValue* jsonRequestedAddress = nullptr; - if (jsonValue.GetValuePointer("requested_address", &jsonRequestedAddress)) { - requestedAddress = jsonRequestedAddress->GetStringRobust(); - } else { - return TRestoreOidcContextResult({.IsSuccess = false, - .IsErrorRetryable = false, - .ErrorMessage = errorMessage << "Requested address was not found in the cookie"}); - } + if (!NJson::ReadJsonTree(requestedAddressContext, &jsonConfig, &jsonValue)) { + return TRestoreOidcContextResult({.IsSuccess = false, + .IsErrorRetryable = false, + .ErrorMessage = errorMessage << "Requested address context is not valid json"}); + } + const NJson::TJsonValue* jsonRequestedAddress = nullptr; + if (jsonValue.GetValuePointer("requested_address", &jsonRequestedAddress) && jsonRequestedAddress->IsString()) { + requestedAddress = jsonRequestedAddress->GetString(); + } else { + return TRestoreOidcContextResult({.IsSuccess = false, + .IsErrorRetryable = false, + .ErrorMessage = errorMessage << "Requested address was not found in the cookie"}); } return TRestoreOidcContextResult({.IsSuccess = true, .IsErrorRetryable = true, .ErrorMessage = ""}, TContext({.RequestedAddress = requestedAddress})); } -TCheckStateResult CheckState(const TString& state, const TString& key) { - TStringBuilder errorMessage; - errorMessage << "Check state failed: "; - TString signedState = Base64DecodeUneven(state); - TString stateContainer; - TString expectedDigest; +TCheckStateResult TDecodeStateResult::Check(const TString& key) const { + static constexpr TStringBuf ErrorPrefix = "Check state failed: "; + if (!HasSignedStateJson) { + return TCheckStateResult::Error(TString(ErrorPrefix) + "Signed state is not valid json"); + } + if (StateContainer.empty()) { + return TCheckStateResult::Error(TString(ErrorPrefix) + "Container with state is empty"); + } + if (ExpectedDigest.empty()) { + return TCheckStateResult::Error(TString(ErrorPrefix) + "Expected digest is empty"); + } + + TString digest = HmacSHA1(key, StateContainer); + if (ExpectedDigest != digest) { + return TCheckStateResult::Error(TString(ErrorPrefix) + "Calculated digest is not equal expected digest"); + } + if (!HasStateContainerJson) { + return TCheckStateResult::Error(TString(ErrorPrefix) + "State container is not valid json"); + } + if (!Payload.ExpirationTime) { + return TCheckStateResult::Error(TString(ErrorPrefix) + "Expiration time not found in json", Payload.CookieSuffix); + } + if (TInstant::Now() > *Payload.ExpirationTime) { + return TCheckStateResult::Error(TString(ErrorPrefix) + "State life time expired", Payload.CookieSuffix); + } + return TCheckStateResult::Success(Payload.CookieSuffix); +} + +TString EncodeState(const TState& payload, TStringBuf signingKey) { + NJson::TJsonValue json(NJson::JSON_MAP); + json["state"] = payload.AntiForgeryToken; + if (payload.ExpirationTime) { + json["expiration_time"] = ToString(payload.ExpirationTime->TimeT()); + } + if (!payload.CookieSuffix.empty()) { + json["cookie_suffix"] = payload.CookieSuffix; + } + const TString stateContainer = NJson::WriteJson(json, false); + + TString digest = HmacSHA1(signingKey, stateContainer); + + NJson::TJsonValue root(NJson::JSON_MAP); + root["container"] = Base64Encode(stateContainer); + root["digest"] = Base64Encode(digest); + return Base64EncodeNoPadding(NJson::WriteJson(root, false)); +} + +TDecodeStateResult DecodeState(TStringBuf encodedState) { + TDecodeStateResult result; + TString signedState = Base64DecodeUneven(encodedState); NJson::TJsonValue jsonValue; NJson::TJsonReaderConfig jsonConfig; - if (NJson::ReadJsonTree(signedState, &jsonConfig, &jsonValue)) { + result.HasSignedStateJson = NJson::ReadJsonTree(signedState, &jsonConfig, &jsonValue); + if (result.HasSignedStateJson) { const NJson::TJsonValue* jsonStateContainer = nullptr; - if (jsonValue.GetValuePointer("container", &jsonStateContainer)) { - stateContainer = jsonStateContainer->GetStringRobust(); - stateContainer = Base64Decode(stateContainer); - } - if (stateContainer.empty()) { - return TCheckStateResult(false, errorMessage << "Container with state is empty"); + if (jsonValue.GetValuePointer("container", &jsonStateContainer) && jsonStateContainer->IsString()) { + result.StateContainer = jsonStateContainer->GetString(); + result.StateContainer = Base64Decode(result.StateContainer); } const NJson::TJsonValue* jsonDigest = nullptr; - if (jsonValue.GetValuePointer("digest", &jsonDigest)) { - expectedDigest = jsonDigest->GetStringRobust(); - expectedDigest = Base64Decode(expectedDigest); + if (jsonValue.GetValuePointer("digest", &jsonDigest) && jsonDigest->IsString()) { + result.ExpectedDigest = jsonDigest->GetString(); + result.ExpectedDigest = Base64Decode(result.ExpectedDigest); } - if (expectedDigest.empty()) { - return TCheckStateResult(false, errorMessage << "Expected digest is empty"); - } - } - TString digest = HmacSHA1(key, stateContainer); - if (expectedDigest != digest) { - return TCheckStateResult(false, errorMessage << "Calculated digest is not equal expected digest"); } - if (NJson::ReadJsonTree(stateContainer, &jsonConfig, &jsonValue)) { - const NJson::TJsonValue* jsonExpirationTime = nullptr; - if (jsonValue.GetValuePointer("expiration_time", &jsonExpirationTime)) { - timeval timeVal { - .tv_sec = jsonExpirationTime->GetIntegerRobust(), - .tv_usec = 0 - }; - if (TInstant::Now() > TInstant(timeVal)) { - return TCheckStateResult(false, errorMessage << "State life time expired"); + + if (!result.StateContainer.empty()) { + result.HasStateContainerJson = NJson::ReadJsonTree(result.StateContainer, &jsonConfig, &jsonValue); + if (result.HasStateContainerJson) { + const NJson::TJsonValue* jsonState = nullptr; + if (jsonValue.GetValuePointer("state", &jsonState) && jsonState->IsString()) { + result.Payload.AntiForgeryToken = jsonState->GetString(); + } + const NJson::TJsonValue* jsonCookieSuffix = nullptr; + if (jsonValue.GetValuePointer("cookie_suffix", &jsonCookieSuffix) && jsonCookieSuffix->IsString()) { + result.Payload.CookieSuffix = jsonCookieSuffix->GetString(); + } + const NJson::TJsonValue* jsonExpirationTime = nullptr; + if (jsonValue.GetValuePointer("expiration_time", &jsonExpirationTime)) { + result.Payload.ExpirationTime = TInstant::Seconds(jsonExpirationTime->GetIntegerRobust()); } - } else { - return TCheckStateResult(false, errorMessage << "Expiration time not found in json"); } } - return TCheckStateResult(); + + return result; +} + +TCheckStateResult CheckState(const TString& state, const TString& key) { + return DecodeState(state).Check(key); } TString DecodeToken(const TStringBuf& cookie) { diff --git a/ydb/mvp/oidc_proxy/openid_connect.h b/ydb/mvp/oidc_proxy/openid_connect.h index 670990c2f86..24c0efcc096 100644 --- a/ydb/mvp/oidc_proxy/openid_connect.h +++ b/ydb/mvp/oidc_proxy/openid_connect.h @@ -11,7 +11,9 @@ #include <ydb/mvp/core/core_ydb.h> #include <ydb/public/api/client/yc_private/oauth/session_service.grpc.pb.h> #include <ydb/public/api/client/nc_private/iam/v1/profile_service.grpc.pb.h> +#include <util/datetime/base.h> #include <util/generic/ptr.h> +#include <util/generic/maybe.h> #include <util/generic/string.h> namespace NMVP::NOIDC { @@ -45,27 +47,57 @@ struct TRestoreOidcContextResult { bool IsSuccess() const; }; +struct TCheckStateResult; + +struct TState { + TString AntiForgeryToken; + TString CookieSuffix; + TMaybe<TInstant> ExpirationTime; + + bool operator==(const TState& other) const { + return AntiForgeryToken == other.AntiForgeryToken + && CookieSuffix == other.CookieSuffix + && ExpirationTime == other.ExpirationTime; + } +}; + +struct TDecodeStateResult { + bool HasSignedStateJson = false; + bool HasStateContainerJson = false; + + TString StateContainer; + TString ExpectedDigest; + TState Payload; + + TCheckStateResult Check(const TString& key) const; +}; + struct TCheckStateResult { - bool Success = true; + bool Ok = true; TString ErrorMessage; + TString CookieSuffix; - TCheckStateResult(bool success = true, const TString& errorMessage = ""); + static TCheckStateResult Error(const TString& errorMessage, const TString& cookieSuffix = ""); + static TCheckStateResult Success(const TString& cookieSuffix = ""); - bool IsSuccess() const; +private: + TCheckStateResult(bool ok, const TString& cookieSuffix, const TString& errorMessage); }; TString HmacSHA256(TStringBuf key, TStringBuf data); TString HmacSHA1(TStringBuf key, TStringBuf data); void SetHeader(NYdbGrpc::TCallMeta& meta, const TString& name, const TString& value); NHttp::THttpOutgoingResponsePtr GetHttpOutgoingResponsePtr(const NHttp::THttpIncomingRequestPtr& request, const TOpenIdConnectSettings& settings, TStringBuf requestId); -TString CreateNameYdbOidcCookie(TStringBuf key, TStringBuf state); +TString CreateNameYdbOidcCookie(TStringBuf suffix = ""); TString CreateNameSessionCookie(TStringBuf key); TString CreateNameImpersonatedCookie(TStringBuf key); const TString& GetAuthCallbackUrl(); TString CreateSecureCookie(const TString& name, const TString& value, const ui32 expiredSeconds); TString ClearSecureCookie(const TString& name); void SetCORS(const NHttp::THttpIncomingRequestPtr& request, NHttp::THeadersBuilder* const headers); -TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, const TString& key); +TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, const TString& key, TStringBuf cookieSuffix = ""); +TString EncodeState(const TState& payload, TStringBuf signingKey); +TDecodeStateResult DecodeState(TStringBuf encodedState); TCheckStateResult CheckState(const TString& state, const TString& key); TString DecodeToken(const TStringBuf& cookie); TStringBuf GetCookie(const NHttp::TCookies& cookies, const TString& cookieName); diff --git a/ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py b/ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py index a2afc3c1a5b..06655c86715 100644 --- a/ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py +++ b/ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py @@ -1,5 +1,5 @@ import requests -from urllib.parse import urlparse +from urllib.parse import parse_qs, urlparse from oidc_proxy_testlib import ( CALLBACK_PATH, @@ -57,7 +57,7 @@ def build_base_expired_session_request_headers(host): return headers -def build_ajax_expired_session_request_headers(host): +def build_background_expired_session_request_headers(host): headers = build_base_expired_session_request_headers(host) headers.update({ "Accept": "*/*", @@ -93,11 +93,11 @@ def get_protected_landing_data_path(env): return f"/{protected_host(env)}{LANDING_DATA_PATH}" -def start_ajax_auth_challenge_request(env): +def start_background_auth_challenge_request(env): host = protected_host(env) return start_auth_challenge_request( env, - build_ajax_expired_session_request_headers(host), + build_background_expired_session_request_headers(host), ) @@ -109,7 +109,7 @@ def start_navigation_auth_challenge_request(env): ) -def assert_ajax_auth_challenge(response): +def assert_background_auth_challenge(response): assert response.status_code == 401, response.text response_json = response.json() assert response_json["error"] == "Authorization Required", response_json @@ -122,13 +122,57 @@ def assert_navigation_auth_redirect(env, response): assert response.status_code == 302, response.text redirect_url = response.headers["Location"] assert redirect_url.startswith(env.auth_service.endpoint + AUTH_AUTHORIZE_PATH), redirect_url + assert "Set-Cookie" in response.headers, response.headers + + +def start_navigation_auth_flow(env, start_path): + response = env.get( + start_path, + allow_redirects=False, + headers={"Host": "oidcproxy.net"}, + ) + assert_navigation_auth_redirect(env, response) + state = parse_qs(urlparse(response.headers["Location"]).query)["state"][0] + oidc_cookie = response.headers["Set-Cookie"].split(";", 1)[0] + return state, oidc_cookie + + +def finish_auth_callback(env, state, oidc_cookies): + return env.get( + "/auth/callback", + params={ + "code": "code_template#", + "state": state, + }, + allow_redirects=False, + headers={ + "Host": "oidcproxy.net", + "Cookie": "; ".join(oidc_cookies), + }, + ) -def test_ajax_request_returns_json_401_with_auth_url(oidc_proxy_full_flow_env): - response = start_ajax_auth_challenge_request(oidc_proxy_full_flow_env) - assert_ajax_auth_challenge(response) +def test_background_request_returns_json_401_with_auth_url(oidc_proxy_full_flow_env): + response = start_background_auth_challenge_request(oidc_proxy_full_flow_env) + assert_background_auth_challenge(response) def test_navigation_request_returns_oidc_redirect(oidc_proxy_full_flow_env): response = start_navigation_auth_challenge_request(oidc_proxy_full_flow_env) assert_navigation_auth_redirect(oidc_proxy_full_flow_env, response) + + +def test_navigation_callback_uses_cookie_from_state_when_background_cookie_exists(oidc_proxy_full_flow_env): + env = oidc_proxy_full_flow_env + host = protected_host(env) + navigation_target = f"/{host}{LANDING_DATA_PATH}&from=navigation" + + navigation_state, navigation_oidc_cookie = start_navigation_auth_flow(env, navigation_target) + background_response = start_background_auth_challenge_request(env) + assert_background_auth_challenge(background_response) + background_oidc_cookie = background_response.headers["Set-Cookie"].split(";", 1)[0] + + callback_response = finish_auth_callback(env, navigation_state, [navigation_oidc_cookie, background_oidc_cookie]) + + assert callback_response.status_code == 302, callback_response.text + assert callback_response.headers["Location"] == navigation_target, callback_response.headers |
