summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAndrei Rykov <[email protected]>2026-05-03 14:40:32 +0200
committerGitHub <[email protected]>2026-05-03 14:40:32 +0200
commite1e2d046cb0f8fc6f72d496d7a73374585eef3da (patch)
tree11965016a53e503da7b09a523072f7a10925b0d9
parent621e8550a009af531760cd5b2be3dd7c3ce0ba8a (diff)
OIDC: separate OIDC context cookie on navigation and background (#39051)oidc-1.2.8.12-devmeta-1.0.5
-rw-r--r--ydb/mvp/oidc_proxy/context.cpp47
-rw-r--r--ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp80
-rw-r--r--ydb/mvp/oidc_proxy/oidc_session_create.cpp4
-rw-r--r--ydb/mvp/oidc_proxy/oidc_settings.h1
-rw-r--r--ydb/mvp/oidc_proxy/openid_connect.cpp198
-rw-r--r--ydb/mvp/oidc_proxy/openid_connect.h42
-rw-r--r--ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py60
7 files changed, 311 insertions, 121 deletions
diff --git a/ydb/mvp/oidc_proxy/context.cpp b/ydb/mvp/oidc_proxy/context.cpp
index d638be70d93..678b3f7c318 100644
--- a/ydb/mvp/oidc_proxy/context.cpp
+++ b/ydb/mvp/oidc_proxy/context.cpp
@@ -1,10 +1,14 @@
+#include "context.h"
+#include "oidc_settings.h"
+#include "openid_connect.h"
+
+#include <ydb/library/actors/http/http.h>
+
+#include <library/cpp/json/json_writer.h>
+#include <library/cpp/string_utils/base64/base64.h>
+
#include <util/generic/string.h>
#include <util/string/builder.h>
-#include <library/cpp/string_utils/base64/base64.h>
-#include <ydb/library/actors/http/http.h>
-#include "openid_connect.h"
-#include "oidc_settings.h"
-#include "context.h"
namespace NMVP::NOIDC {
@@ -22,15 +26,13 @@ TContext::TContext(const NHttp::THttpIncomingRequestPtr& request)
TString TContext::GetState(const TString& key) const {
static const TDuration STATE_LIFE_TIME = TDuration::Minutes(10);
- TInstant expirationTime = TInstant::Now() + STATE_LIFE_TIME;
- TStringBuilder json;
- json << "{\"state\":\"" << State
- << "\",\"expiration_time\":\"" << ToString(expirationTime.TimeT()) << "\"}";
- TString digest = HmacSHA1(key, json);
- TStringBuilder signedState;
- signedState << "{\"container\":\"" << Base64Encode(json) << "\","
- "\"digest\":\"" << Base64Encode(digest) << "\"}";
- return Base64EncodeNoPadding(signedState);
+ TState payload;
+ payload.AntiForgeryToken = State;
+ payload.ExpirationTime = TInstant::Now() + STATE_LIFE_TIME;
+ if (!NavigationRequest) {
+ payload.CookieSuffix = TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX);
+ }
+ return EncodeState(payload, key);
}
bool TContext::IsNavigationRequest() const {
@@ -43,7 +45,7 @@ TString TContext::GetRequestedAddress() const {
TString TContext::CreateYdbOidcCookie(const TString& secret) const {
static constexpr size_t COOKIE_MAX_AGE_SEC = 3600;
- return TStringBuilder() << TOpenIdConnectSettings::YDB_OIDC_COOKIE << "="
+ return TStringBuilder() << CreateNameYdbOidcCookie(NavigationRequest ? TStringBuf() : TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX) << "="
<< GenerateCookie(secret) << ";"
" Path=" << GetAuthCallbackUrl() << ";"
" Max-Age=" << COOKIE_MAX_AGE_SEC << ";"
@@ -51,13 +53,16 @@ TString TContext::CreateYdbOidcCookie(const TString& secret) const {
}
TString TContext::GenerateCookie(const TString& key) const {
- TStringBuilder requestedAddressContext;
- requestedAddressContext << "{\"requested_address\":\"" << RequestedAddress << "\"}";
+ NJson::TJsonValue json(NJson::JSON_MAP);
+ json["requested_address"] = RequestedAddress;
+ const TString requestedAddressContext = NJson::WriteJson(json, false);
+
TString digest = HmacSHA256(key, requestedAddressContext);
- TStringBuilder signedRequestedAddress;
- signedRequestedAddress << "{\"requested_address_context\":\"" << Base64Encode(requestedAddressContext)
- << "\",\"digest\":\"" << Base64Encode(digest) << "\"}";
- return Base64Encode(signedRequestedAddress);
+
+ NJson::TJsonValue root(NJson::JSON_MAP);
+ root["requested_address_context"] = Base64Encode(requestedAddressContext);
+ root["digest"] = Base64Encode(digest);
+ return Base64Encode(NJson::WriteJson(root, false));
}
bool TContext::IsPageNavigationRequest(const NHttp::THttpIncomingRequestPtr& request) {
diff --git a/ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp b/ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp
index 82db06429a7..23fbd69ee09 100644
--- a/ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp
+++ b/ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp
@@ -13,6 +13,7 @@
#include <ydb/mvp/core/mvp_test_runtime.h>
#include <ydb/library/security/util.h>
#include <library/cpp/json/json_reader.h>
+#include <library/cpp/json/json_writer.h>
#include <library/cpp/string_utils/base64/base64.h>
#include <library/cpp/testing/unittest/registar.h>
#include <util/generic/map.h>
@@ -572,19 +573,21 @@ Y_UNIT_TEST_SUITE(Mvp) {
redirectStrategy.CheckRedirectStatus(outgoingResponseEv);
TString location = redirectStrategy.GetRedirectUrl(outgoingResponseEv);
UNIT_ASSERT_STRING_CONTAINS(location, "https://auth.test.net/oauth/authorize");
- UNIT_ASSERT_STRING_CONTAINS(location, "response_type=code");
- UNIT_ASSERT_STRING_CONTAINS(location, "scope=openid");
- UNIT_ASSERT_STRING_CONTAINS(location, "client_id=" + settings.ClientId);
- UNIT_ASSERT_STRING_CONTAINS(location, "redirect_uri=https://" + hostProxy + "/auth/callback");
NHttp::TUrlParameters urlParameters(location);
- const TString state = urlParameters["state"];
+ UNIT_ASSERT_STRINGS_EQUAL(urlParameters["response_type"], "code");
+ UNIT_ASSERT_STRINGS_EQUAL(urlParameters["scope"], "openid");
+ UNIT_ASSERT_STRINGS_EQUAL(urlParameters["client_id"], settings.ClientId);
+ UNIT_ASSERT_STRINGS_EQUAL(urlParameters["redirect_uri"], "https://" + hostProxy + "/auth/callback");
+ const TString state = TString(urlParameters.Get("state"));
const NHttp::THeaders headers(outgoingResponseEv->Response->Headers);
UNIT_ASSERT(headers.Has("X-Request-Id"));
UNIT_ASSERT(headers.Has("Set-Cookie"));
TStringBuf setCookie = headers.Get("Set-Cookie");
- UNIT_ASSERT_STRING_CONTAINS(setCookie, TOpenIdConnectSettings::YDB_OIDC_COOKIE);
+ UNIT_ASSERT_STRING_CONTAINS(
+ setCookie,
+ CreateNameYdbOidcCookie(redirectStrategy.IsNavigationRequest() ? TStringBuf() : TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX));
redirectStrategy.CheckSpecificHeaders(headers);
const NActors::TActorId sessionCreator = runtime.Register(new TSessionCreateHandler(edge, settings));
@@ -692,10 +695,15 @@ Y_UNIT_TEST_SUITE(Mvp) {
TAutoPtr<IEventHandle> handle;
NHttp::TEvHttpProxy::TEvHttpOutgoingResponse* outgoingResponseEv = runtime.GrabEdgeEvent<NHttp::TEvHttpProxy::TEvHttpOutgoingResponse>(handle);
- UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Status, "302");
const NHttp::THeaders protectedPageHeaders(outgoingResponseEv->Response->Headers);
- UNIT_ASSERT(protectedPageHeaders.Has("Location"));
- UNIT_ASSERT_STRINGS_EQUAL(protectedPageHeaders.Get("Location"), "/requested/page");
+ if (redirectStrategy.IsNavigationRequest()) {
+ UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Status, "302");
+ UNIT_ASSERT(protectedPageHeaders.Has("Location"));
+ UNIT_ASSERT_STRINGS_EQUAL(protectedPageHeaders.Get("Location"), "/requested/page");
+ } else {
+ UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Status, "400");
+ UNIT_ASSERT(!protectedPageHeaders.Has("Location"));
+ }
}
Y_UNIT_TEST(OpenIdConnectWrongStateAuthorizationFlow) {
@@ -708,6 +716,22 @@ Y_UNIT_TEST_SUITE(Mvp) {
OidcWrongStateAuthorizationFlow(redirectStrategy);
}
+ Y_UNIT_TEST(OpenIdConnectExpiredBackgroundStateKeepsCookieSuffix) {
+ TPortManager tp;
+ auto settings = BuildBaseSettings(tp);
+ TState sourcePayload;
+ sourcePayload.AntiForgeryToken = "state";
+ sourcePayload.ExpirationTime = TInstant::Seconds(0);
+ sourcePayload.CookieSuffix = TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX);
+
+ TCheckStateResult result = CheckState(EncodeState(sourcePayload, settings.ClientSecret), settings.ClientSecret);
+ const TString expectedCookieSuffix = TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX);
+
+ UNIT_ASSERT(!result.Ok);
+ UNIT_ASSERT_STRINGS_EQUAL(result.CookieSuffix, expectedCookieSuffix);
+ UNIT_ASSERT_STRING_CONTAINS(result.ErrorMessage, "State life time expired");
+ }
+
Y_UNIT_TEST(OpenIdConnectSessionServiceCreateAuthorizationFail) {
TPortManager tp;
TMvpTestRuntime runtime;
@@ -1411,17 +1435,17 @@ Y_UNIT_TEST_SUITE(Mvp) {
}
static TString GetViewerResponse200() {
- TStringBuilder body;
- body << "{\"UserSID\":\"" << VIEWER_USER_ACCOUNT_ID
- << "\",\"OriginalUserToken\":\"" << TProfileServiceMock::VALID_USER_TOKEN << "\"}";
- return MakeHttpResponse("200 OK", body, "application/json");
+ NJson::TJsonValue json(NJson::JSON_MAP);
+ json["UserSID"] = VIEWER_USER_ACCOUNT_ID;
+ json["OriginalUserToken"] = TProfileServiceMock::VALID_USER_TOKEN;
+ return MakeHttpResponse("200 OK", NJson::WriteJson(json, false), "application/json");
}
static TString GetViewerResponseService200() {
- TStringBuilder body;
- body << "{\"UserSID\":\"" << VIEWER_SERVICE_ACCOUNT_ID
- << "\",\"OriginalUserToken\":\"" << TProfileServiceMock::VALID_SERVICE_TOKEN << "\"}";
- return MakeHttpResponse("200 OK", body, "application/json");
+ NJson::TJsonValue json(NJson::JSON_MAP);
+ json["UserSID"] = VIEWER_SERVICE_ACCOUNT_ID;
+ json["OriginalUserToken"] = TProfileServiceMock::VALID_SERVICE_TOKEN;
+ return MakeHttpResponse("200 OK", NJson::WriteJson(json, false), "application/json");
}
static TString GetViewerResponse403() {
@@ -1607,6 +1631,28 @@ static void NavigationRequestTest(const TString& rawRequest, bool expectedNaviga
}
Y_UNIT_TEST_SUITE(Utils) {
+ Y_UNIT_TEST(OpenIdConnectStateRoundTrip) {
+ TPortManager tp;
+ auto settings = BuildBaseSettings(tp);
+
+ TState sourcePayload;
+ sourcePayload.AntiForgeryToken = "state";
+ sourcePayload.ExpirationTime = TInstant::Seconds(TInstant::Now().Seconds() + TDuration::Minutes(10).Seconds());
+ sourcePayload.CookieSuffix = TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE_BACKGROUND_SUFFIX);
+
+ const TString state = EncodeState(sourcePayload, settings.ClientSecret);
+ const TCheckStateResult result = CheckState(state, settings.ClientSecret);
+ const TDecodeStateResult decodedResult = DecodeState(state);
+
+ UNIT_ASSERT(result.Ok);
+ UNIT_ASSERT_STRINGS_EQUAL(result.CookieSuffix, sourcePayload.CookieSuffix);
+ UNIT_ASSERT(result.ErrorMessage.empty());
+ UNIT_ASSERT(decodedResult.HasSignedStateJson);
+ UNIT_ASSERT(decodedResult.HasStateContainerJson);
+ UNIT_ASSERT(decodedResult.Payload == sourcePayload);
+ UNIT_ASSERT_STRINGS_EQUAL(EncodeState(decodedResult.Payload, settings.ClientSecret), state);
+ }
+
Y_UNIT_TEST(GenerateRandomBase64RandomUniqueness) {
THashSet<TString> seen;
for (size_t i = 0; i < 100; ++i) {
diff --git a/ydb/mvp/oidc_proxy/oidc_session_create.cpp b/ydb/mvp/oidc_proxy/oidc_session_create.cpp
index 8cd57a42959..bcac8054175 100644
--- a/ydb/mvp/oidc_proxy/oidc_session_create.cpp
+++ b/ydb/mvp/oidc_proxy/oidc_session_create.cpp
@@ -31,10 +31,10 @@ void THandlerSessionCreate::Bootstrap() {
NHttp::THeaders headers(Request->Headers);
NHttp::TCookies cookies(headers.Get("cookie"));
- TRestoreOidcContextResult restoreContextResult = RestoreOidcContext(cookies, Settings.ClientSecret);
+ TRestoreOidcContextResult restoreContextResult = RestoreOidcContext(cookies, Settings.ClientSecret, checkStateResult.CookieSuffix);
Context = restoreContextResult.Context;
- if (checkStateResult.IsSuccess()) {
+ if (checkStateResult.Ok) {
if (restoreContextResult.IsSuccess()) {
if (code.empty()) {
BLOG_D("Restore oidc session failed: receive empty 'code' parameter");
diff --git a/ydb/mvp/oidc_proxy/oidc_settings.h b/ydb/mvp/oidc_proxy/oidc_settings.h
index 2306eb1e21d..a26f72366ab 100644
--- a/ydb/mvp/oidc_proxy/oidc_settings.h
+++ b/ydb/mvp/oidc_proxy/oidc_settings.h
@@ -10,6 +10,7 @@ namespace NMVP::NOIDC {
struct TOpenIdConnectSettings {
static const inline TString YDB_OIDC_COOKIE = "ydb_oidc_cookie";
+ static const inline TStringBuf YDB_OIDC_COOKIE_BACKGROUND_SUFFIX = "_background";
static const inline TString SESSION_COOKIE = "session_cookie";
static const inline TString IMPERSONATED_COOKIE = "impersonated_cookie";
diff --git a/ydb/mvp/oidc_proxy/openid_connect.cpp b/ydb/mvp/oidc_proxy/openid_connect.cpp
index 3b8cdaf6c46..c8470cbf570 100644
--- a/ydb/mvp/oidc_proxy/openid_connect.cpp
+++ b/ydb/mvp/oidc_proxy/openid_connect.cpp
@@ -5,7 +5,9 @@
#include <ydb/core/util/wildcard.h>
#include <ydb/library/security/util.h>
+#include <library/cpp/cgiparam/cgiparam.h>
#include <library/cpp/json/json_reader.h>
+#include <library/cpp/json/json_writer.h>
#include <library/cpp/string_utils/base64/base64.h>
#include <openssl/evp.h>
@@ -14,7 +16,6 @@
#include <util/string/builder.h>
#include <util/string/hex.h>
-
namespace NMVP::NOIDC {
TRestoreOidcContextResult::TRestoreOidcContextResult(const TStatus& status, const TContext& context)
@@ -27,13 +28,18 @@ bool TRestoreOidcContextResult::IsSuccess() const {
return Status.IsSuccess;
}
-TCheckStateResult::TCheckStateResult(bool success, const TString& errorMessage)
- : Success(success)
+TCheckStateResult::TCheckStateResult(bool ok, const TString& cookieSuffix, const TString& errorMessage)
+ : Ok(ok)
, ErrorMessage(errorMessage)
+ , CookieSuffix(cookieSuffix)
{}
-bool TCheckStateResult::IsSuccess() const {
- return Success;
+TCheckStateResult TCheckStateResult::Error(const TString& errorMessage, const TString& cookieSuffix) {
+ return TCheckStateResult(false, cookieSuffix, errorMessage);
+}
+
+TCheckStateResult TCheckStateResult::Success(const TString& cookieSuffix) {
+ return TCheckStateResult(true, cookieSuffix, "");
}
void SetCORS(const NHttp::THttpIncomingRequestPtr& request, NHttp::THeadersBuilder* const headers) {
@@ -79,14 +85,19 @@ void SetHeader(NYdbGrpc::TCallMeta& meta, const TString& name, const TString& va
NHttp::THttpOutgoingResponsePtr GetHttpOutgoingResponsePtr(const NHttp::THttpIncomingRequestPtr& request, const TOpenIdConnectSettings& settings, TStringBuf requestId) {
TContext context(request);
- const TString redirectUrl = TStringBuilder() << settings.GetAuthEndpointURL()
- << "?response_type=code"
- << "&scope=openid"
- << "&state=" << context.GetState(settings.ClientSecret)
- << "&client_id=" << settings.ClientId
- << "&redirect_uri=" << (request->Endpoint->Secure ? "https://" : "http://")
- << request->Host
- << GetAuthCallbackUrl();
+ const TString redirectUri = TStringBuilder()
+ << (request->Endpoint->Secure ? "https://" : "http://")
+ << request->Host
+ << GetAuthCallbackUrl();
+
+ TCgiParameters authParams;
+ authParams.InsertUnescaped("response_type", "code");
+ authParams.InsertUnescaped("scope", "openid");
+ authParams.InsertUnescaped("state", context.GetState(settings.ClientSecret));
+ authParams.InsertUnescaped("client_id", settings.ClientId);
+ authParams.InsertUnescaped("redirect_uri", redirectUri);
+
+ const TString redirectUrl = settings.GetAuthEndpointURL() + "?" + authParams.Print();
NHttp::THeadersBuilder responseHeaders;
SetCORS(request, &responseHeaders);
SetRequestIdHeader(responseHeaders, requestId);
@@ -96,12 +107,15 @@ NHttp::THttpOutgoingResponsePtr GetHttpOutgoingResponsePtr(const NHttp::THttpInc
return request->CreateResponse("302", "Authorization required", responseHeaders);
}
responseHeaders.Set("Content-Type", "application/json; charset=utf-8");
- TString body {"{\"error\":\"Authorization Required\",\"authUrl\":\"" + redirectUrl + "\"}"};
+ NJson::TJsonValue json(NJson::JSON_MAP);
+ json["error"] = "Authorization Required";
+ json["authUrl"] = redirectUrl;
+ const TString body = NJson::WriteJson(json, false);
return request->CreateResponse("401", "Unauthorized", responseHeaders, body);
}
-TString CreateNameYdbOidcCookie(TStringBuf key, TStringBuf state) {
- return TOpenIdConnectSettings::YDB_OIDC_COOKIE + "_" + HexEncode(HmacSHA256(key, state));
+TString CreateNameYdbOidcCookie(TStringBuf suffix) {
+ return TString(TOpenIdConnectSettings::YDB_OIDC_COOKIE) + TString(suffix);
}
TString CreateNameSessionCookie(TStringBuf key) {
@@ -131,23 +145,24 @@ TString ClearSecureCookie(const TString& name) {
return cookieBuilder;
}
-TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, const TString& key) {
+TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, const TString& key, TStringBuf cookieSuffix) {
TStringBuilder errorMessage;
errorMessage << "Restore oidc context failed: ";
- if (!cookies.Has(TOpenIdConnectSettings::YDB_OIDC_COOKIE)) {
+ TString cookieName = CreateNameYdbOidcCookie(cookieSuffix);
+ if (!cookies.Has(cookieName)) {
return TRestoreOidcContextResult({.IsSuccess = false,
.IsErrorRetryable = false,
- .ErrorMessage = errorMessage << "Cannot find cookie " << TOpenIdConnectSettings::YDB_OIDC_COOKIE});
+ .ErrorMessage = errorMessage << "Cannot find cookie " << cookieName});
}
- TString signedRequestedAddress = Base64Decode(cookies.Get(TOpenIdConnectSettings::YDB_OIDC_COOKIE));
+ TString signedRequestedAddress = Base64Decode(cookies.Get(cookieName));
TString requestedAddressContext;
TString expectedDigest;
NJson::TJsonValue jsonValue;
NJson::TJsonReaderConfig jsonConfig;
if (NJson::ReadJsonTree(signedRequestedAddress, &jsonConfig, &jsonValue)) {
const NJson::TJsonValue* jsonRequestedAddressContext = nullptr;
- if (jsonValue.GetValuePointer("requested_address_context", &jsonRequestedAddressContext)) {
- requestedAddressContext = jsonRequestedAddressContext->GetStringRobust();
+ if (jsonValue.GetValuePointer("requested_address_context", &jsonRequestedAddressContext) && jsonRequestedAddressContext->IsString()) {
+ requestedAddressContext = jsonRequestedAddressContext->GetString();
requestedAddressContext = Base64Decode(requestedAddressContext);
}
if (requestedAddressContext.empty()) {
@@ -156,8 +171,8 @@ TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, con
.ErrorMessage = errorMessage << "Struct with state is empty"});
}
const NJson::TJsonValue* jsonDigest = nullptr;
- if (jsonValue.GetValuePointer("digest", &jsonDigest)) {
- expectedDigest = jsonDigest->GetStringRobust();
+ if (jsonValue.GetValuePointer("digest", &jsonDigest) && jsonDigest->IsString()) {
+ expectedDigest = jsonDigest->GetString();
expectedDigest = Base64Decode(expectedDigest);
}
if (expectedDigest.empty()) {
@@ -173,66 +188,113 @@ TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, con
.ErrorMessage = errorMessage << "Calculated digest is not equal expected digest"});
}
TString requestedAddress;
- if (NJson::ReadJsonTree(requestedAddressContext, &jsonConfig, &jsonValue)) {
- const NJson::TJsonValue* jsonRequestedAddress = nullptr;
- if (jsonValue.GetValuePointer("requested_address", &jsonRequestedAddress)) {
- requestedAddress = jsonRequestedAddress->GetStringRobust();
- } else {
- return TRestoreOidcContextResult({.IsSuccess = false,
- .IsErrorRetryable = false,
- .ErrorMessage = errorMessage << "Requested address was not found in the cookie"});
- }
+ if (!NJson::ReadJsonTree(requestedAddressContext, &jsonConfig, &jsonValue)) {
+ return TRestoreOidcContextResult({.IsSuccess = false,
+ .IsErrorRetryable = false,
+ .ErrorMessage = errorMessage << "Requested address context is not valid json"});
+ }
+ const NJson::TJsonValue* jsonRequestedAddress = nullptr;
+ if (jsonValue.GetValuePointer("requested_address", &jsonRequestedAddress) && jsonRequestedAddress->IsString()) {
+ requestedAddress = jsonRequestedAddress->GetString();
+ } else {
+ return TRestoreOidcContextResult({.IsSuccess = false,
+ .IsErrorRetryable = false,
+ .ErrorMessage = errorMessage << "Requested address was not found in the cookie"});
}
return TRestoreOidcContextResult({.IsSuccess = true,
.IsErrorRetryable = true,
.ErrorMessage = ""}, TContext({.RequestedAddress = requestedAddress}));
}
-TCheckStateResult CheckState(const TString& state, const TString& key) {
- TStringBuilder errorMessage;
- errorMessage << "Check state failed: ";
- TString signedState = Base64DecodeUneven(state);
- TString stateContainer;
- TString expectedDigest;
+TCheckStateResult TDecodeStateResult::Check(const TString& key) const {
+ static constexpr TStringBuf ErrorPrefix = "Check state failed: ";
+ if (!HasSignedStateJson) {
+ return TCheckStateResult::Error(TString(ErrorPrefix) + "Signed state is not valid json");
+ }
+ if (StateContainer.empty()) {
+ return TCheckStateResult::Error(TString(ErrorPrefix) + "Container with state is empty");
+ }
+ if (ExpectedDigest.empty()) {
+ return TCheckStateResult::Error(TString(ErrorPrefix) + "Expected digest is empty");
+ }
+
+ TString digest = HmacSHA1(key, StateContainer);
+ if (ExpectedDigest != digest) {
+ return TCheckStateResult::Error(TString(ErrorPrefix) + "Calculated digest is not equal expected digest");
+ }
+ if (!HasStateContainerJson) {
+ return TCheckStateResult::Error(TString(ErrorPrefix) + "State container is not valid json");
+ }
+ if (!Payload.ExpirationTime) {
+ return TCheckStateResult::Error(TString(ErrorPrefix) + "Expiration time not found in json", Payload.CookieSuffix);
+ }
+ if (TInstant::Now() > *Payload.ExpirationTime) {
+ return TCheckStateResult::Error(TString(ErrorPrefix) + "State life time expired", Payload.CookieSuffix);
+ }
+ return TCheckStateResult::Success(Payload.CookieSuffix);
+}
+
+TString EncodeState(const TState& payload, TStringBuf signingKey) {
+ NJson::TJsonValue json(NJson::JSON_MAP);
+ json["state"] = payload.AntiForgeryToken;
+ if (payload.ExpirationTime) {
+ json["expiration_time"] = ToString(payload.ExpirationTime->TimeT());
+ }
+ if (!payload.CookieSuffix.empty()) {
+ json["cookie_suffix"] = payload.CookieSuffix;
+ }
+ const TString stateContainer = NJson::WriteJson(json, false);
+
+ TString digest = HmacSHA1(signingKey, stateContainer);
+
+ NJson::TJsonValue root(NJson::JSON_MAP);
+ root["container"] = Base64Encode(stateContainer);
+ root["digest"] = Base64Encode(digest);
+ return Base64EncodeNoPadding(NJson::WriteJson(root, false));
+}
+
+TDecodeStateResult DecodeState(TStringBuf encodedState) {
+ TDecodeStateResult result;
+ TString signedState = Base64DecodeUneven(encodedState);
NJson::TJsonValue jsonValue;
NJson::TJsonReaderConfig jsonConfig;
- if (NJson::ReadJsonTree(signedState, &jsonConfig, &jsonValue)) {
+ result.HasSignedStateJson = NJson::ReadJsonTree(signedState, &jsonConfig, &jsonValue);
+ if (result.HasSignedStateJson) {
const NJson::TJsonValue* jsonStateContainer = nullptr;
- if (jsonValue.GetValuePointer("container", &jsonStateContainer)) {
- stateContainer = jsonStateContainer->GetStringRobust();
- stateContainer = Base64Decode(stateContainer);
- }
- if (stateContainer.empty()) {
- return TCheckStateResult(false, errorMessage << "Container with state is empty");
+ if (jsonValue.GetValuePointer("container", &jsonStateContainer) && jsonStateContainer->IsString()) {
+ result.StateContainer = jsonStateContainer->GetString();
+ result.StateContainer = Base64Decode(result.StateContainer);
}
const NJson::TJsonValue* jsonDigest = nullptr;
- if (jsonValue.GetValuePointer("digest", &jsonDigest)) {
- expectedDigest = jsonDigest->GetStringRobust();
- expectedDigest = Base64Decode(expectedDigest);
+ if (jsonValue.GetValuePointer("digest", &jsonDigest) && jsonDigest->IsString()) {
+ result.ExpectedDigest = jsonDigest->GetString();
+ result.ExpectedDigest = Base64Decode(result.ExpectedDigest);
}
- if (expectedDigest.empty()) {
- return TCheckStateResult(false, errorMessage << "Expected digest is empty");
- }
- }
- TString digest = HmacSHA1(key, stateContainer);
- if (expectedDigest != digest) {
- return TCheckStateResult(false, errorMessage << "Calculated digest is not equal expected digest");
}
- if (NJson::ReadJsonTree(stateContainer, &jsonConfig, &jsonValue)) {
- const NJson::TJsonValue* jsonExpirationTime = nullptr;
- if (jsonValue.GetValuePointer("expiration_time", &jsonExpirationTime)) {
- timeval timeVal {
- .tv_sec = jsonExpirationTime->GetIntegerRobust(),
- .tv_usec = 0
- };
- if (TInstant::Now() > TInstant(timeVal)) {
- return TCheckStateResult(false, errorMessage << "State life time expired");
+
+ if (!result.StateContainer.empty()) {
+ result.HasStateContainerJson = NJson::ReadJsonTree(result.StateContainer, &jsonConfig, &jsonValue);
+ if (result.HasStateContainerJson) {
+ const NJson::TJsonValue* jsonState = nullptr;
+ if (jsonValue.GetValuePointer("state", &jsonState) && jsonState->IsString()) {
+ result.Payload.AntiForgeryToken = jsonState->GetString();
+ }
+ const NJson::TJsonValue* jsonCookieSuffix = nullptr;
+ if (jsonValue.GetValuePointer("cookie_suffix", &jsonCookieSuffix) && jsonCookieSuffix->IsString()) {
+ result.Payload.CookieSuffix = jsonCookieSuffix->GetString();
+ }
+ const NJson::TJsonValue* jsonExpirationTime = nullptr;
+ if (jsonValue.GetValuePointer("expiration_time", &jsonExpirationTime)) {
+ result.Payload.ExpirationTime = TInstant::Seconds(jsonExpirationTime->GetIntegerRobust());
}
- } else {
- return TCheckStateResult(false, errorMessage << "Expiration time not found in json");
}
}
- return TCheckStateResult();
+
+ return result;
+}
+
+TCheckStateResult CheckState(const TString& state, const TString& key) {
+ return DecodeState(state).Check(key);
}
TString DecodeToken(const TStringBuf& cookie) {
diff --git a/ydb/mvp/oidc_proxy/openid_connect.h b/ydb/mvp/oidc_proxy/openid_connect.h
index 670990c2f86..24c0efcc096 100644
--- a/ydb/mvp/oidc_proxy/openid_connect.h
+++ b/ydb/mvp/oidc_proxy/openid_connect.h
@@ -11,7 +11,9 @@
#include <ydb/mvp/core/core_ydb.h>
#include <ydb/public/api/client/yc_private/oauth/session_service.grpc.pb.h>
#include <ydb/public/api/client/nc_private/iam/v1/profile_service.grpc.pb.h>
+#include <util/datetime/base.h>
#include <util/generic/ptr.h>
+#include <util/generic/maybe.h>
#include <util/generic/string.h>
namespace NMVP::NOIDC {
@@ -45,27 +47,57 @@ struct TRestoreOidcContextResult {
bool IsSuccess() const;
};
+struct TCheckStateResult;
+
+struct TState {
+ TString AntiForgeryToken;
+ TString CookieSuffix;
+ TMaybe<TInstant> ExpirationTime;
+
+ bool operator==(const TState& other) const {
+ return AntiForgeryToken == other.AntiForgeryToken
+ && CookieSuffix == other.CookieSuffix
+ && ExpirationTime == other.ExpirationTime;
+ }
+};
+
+struct TDecodeStateResult {
+ bool HasSignedStateJson = false;
+ bool HasStateContainerJson = false;
+
+ TString StateContainer;
+ TString ExpectedDigest;
+ TState Payload;
+
+ TCheckStateResult Check(const TString& key) const;
+};
+
struct TCheckStateResult {
- bool Success = true;
+ bool Ok = true;
TString ErrorMessage;
+ TString CookieSuffix;
- TCheckStateResult(bool success = true, const TString& errorMessage = "");
+ static TCheckStateResult Error(const TString& errorMessage, const TString& cookieSuffix = "");
+ static TCheckStateResult Success(const TString& cookieSuffix = "");
- bool IsSuccess() const;
+private:
+ TCheckStateResult(bool ok, const TString& cookieSuffix, const TString& errorMessage);
};
TString HmacSHA256(TStringBuf key, TStringBuf data);
TString HmacSHA1(TStringBuf key, TStringBuf data);
void SetHeader(NYdbGrpc::TCallMeta& meta, const TString& name, const TString& value);
NHttp::THttpOutgoingResponsePtr GetHttpOutgoingResponsePtr(const NHttp::THttpIncomingRequestPtr& request, const TOpenIdConnectSettings& settings, TStringBuf requestId);
-TString CreateNameYdbOidcCookie(TStringBuf key, TStringBuf state);
+TString CreateNameYdbOidcCookie(TStringBuf suffix = "");
TString CreateNameSessionCookie(TStringBuf key);
TString CreateNameImpersonatedCookie(TStringBuf key);
const TString& GetAuthCallbackUrl();
TString CreateSecureCookie(const TString& name, const TString& value, const ui32 expiredSeconds);
TString ClearSecureCookie(const TString& name);
void SetCORS(const NHttp::THttpIncomingRequestPtr& request, NHttp::THeadersBuilder* const headers);
-TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, const TString& key);
+TRestoreOidcContextResult RestoreOidcContext(const NHttp::TCookies& cookies, const TString& key, TStringBuf cookieSuffix = "");
+TString EncodeState(const TState& payload, TStringBuf signingKey);
+TDecodeStateResult DecodeState(TStringBuf encodedState);
TCheckStateResult CheckState(const TString& state, const TString& key);
TString DecodeToken(const TStringBuf& cookie);
TStringBuf GetCookie(const NHttp::TCookies& cookies, const TString& cookieName);
diff --git a/ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py b/ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py
index a2afc3c1a5b..06655c86715 100644
--- a/ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py
+++ b/ydb/tests/functional/mvp/oidc_proxy/test_auth_failures.py
@@ -1,5 +1,5 @@
import requests
-from urllib.parse import urlparse
+from urllib.parse import parse_qs, urlparse
from oidc_proxy_testlib import (
CALLBACK_PATH,
@@ -57,7 +57,7 @@ def build_base_expired_session_request_headers(host):
return headers
-def build_ajax_expired_session_request_headers(host):
+def build_background_expired_session_request_headers(host):
headers = build_base_expired_session_request_headers(host)
headers.update({
"Accept": "*/*",
@@ -93,11 +93,11 @@ def get_protected_landing_data_path(env):
return f"/{protected_host(env)}{LANDING_DATA_PATH}"
-def start_ajax_auth_challenge_request(env):
+def start_background_auth_challenge_request(env):
host = protected_host(env)
return start_auth_challenge_request(
env,
- build_ajax_expired_session_request_headers(host),
+ build_background_expired_session_request_headers(host),
)
@@ -109,7 +109,7 @@ def start_navigation_auth_challenge_request(env):
)
-def assert_ajax_auth_challenge(response):
+def assert_background_auth_challenge(response):
assert response.status_code == 401, response.text
response_json = response.json()
assert response_json["error"] == "Authorization Required", response_json
@@ -122,13 +122,57 @@ def assert_navigation_auth_redirect(env, response):
assert response.status_code == 302, response.text
redirect_url = response.headers["Location"]
assert redirect_url.startswith(env.auth_service.endpoint + AUTH_AUTHORIZE_PATH), redirect_url
+ assert "Set-Cookie" in response.headers, response.headers
+
+
+def start_navigation_auth_flow(env, start_path):
+ response = env.get(
+ start_path,
+ allow_redirects=False,
+ headers={"Host": "oidcproxy.net"},
+ )
+ assert_navigation_auth_redirect(env, response)
+ state = parse_qs(urlparse(response.headers["Location"]).query)["state"][0]
+ oidc_cookie = response.headers["Set-Cookie"].split(";", 1)[0]
+ return state, oidc_cookie
+
+
+def finish_auth_callback(env, state, oidc_cookies):
+ return env.get(
+ "/auth/callback",
+ params={
+ "code": "code_template#",
+ "state": state,
+ },
+ allow_redirects=False,
+ headers={
+ "Host": "oidcproxy.net",
+ "Cookie": "; ".join(oidc_cookies),
+ },
+ )
-def test_ajax_request_returns_json_401_with_auth_url(oidc_proxy_full_flow_env):
- response = start_ajax_auth_challenge_request(oidc_proxy_full_flow_env)
- assert_ajax_auth_challenge(response)
+def test_background_request_returns_json_401_with_auth_url(oidc_proxy_full_flow_env):
+ response = start_background_auth_challenge_request(oidc_proxy_full_flow_env)
+ assert_background_auth_challenge(response)
def test_navigation_request_returns_oidc_redirect(oidc_proxy_full_flow_env):
response = start_navigation_auth_challenge_request(oidc_proxy_full_flow_env)
assert_navigation_auth_redirect(oidc_proxy_full_flow_env, response)
+
+
+def test_navigation_callback_uses_cookie_from_state_when_background_cookie_exists(oidc_proxy_full_flow_env):
+ env = oidc_proxy_full_flow_env
+ host = protected_host(env)
+ navigation_target = f"/{host}{LANDING_DATA_PATH}&from=navigation"
+
+ navigation_state, navigation_oidc_cookie = start_navigation_auth_flow(env, navigation_target)
+ background_response = start_background_auth_challenge_request(env)
+ assert_background_auth_challenge(background_response)
+ background_oidc_cookie = background_response.headers["Set-Cookie"].split(";", 1)[0]
+
+ callback_response = finish_auth_callback(env, navigation_state, [navigation_oidc_cookie, background_oidc_cookie])
+
+ assert callback_response.status_code == 302, callback_response.text
+ assert callback_response.headers["Location"] == navigation_target, callback_response.headers