aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/tvmauth/client/misc/api/dynamic_dst/tvm_client.cpp
blob: cd6ec45406b489a416f09e48fc9734ea6f89680f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#include "tvm_client.h"

#include <util/string/builder.h>

#include <exception>

namespace NTvmAuth::NDynamicClient {
    /**
     * Factory: builds the client, runs Init() and starts the background worker.
     *
     * @param settings client settings forwarded to the base updater
     * @param logger   mandatory logger
     * @throws TNonRetriableException if logger is null
     *
     * The object is owned by a THolder until Release(), so it is freed if
     * Init() or StartWorker() throws.
     */
    TIntrusivePtr<TTvmClient> TTvmClient::Create(const NTvmApi::TClientSettings& settings, TLoggerPtr logger) {
        Y_ENSURE_EX(logger, TNonRetriableException() << "Logger is required");
        THolder<TTvmClient> p(new TTvmClient(settings, std::move(logger)));
        p->Init();
        p->StartWorker();
        return p.Release();
    }

    /**
     * Requests service tickets for additional destinations.
     *
     * - Empty input: returns a ready future with an empty response.
     * - Fast path: if every requested dst has a ticket in the start-up cache and
     *   those tickets are not invalid at TInstant::Now(), the tickets are merged into
     *   the current cache and dst set under ServiceTicketBatchUpdateMutex_, and a
     *   ready future reporting Success for every dst is returned.
     * - Otherwise: the dsts are enqueued as a task for the worker thread; the returned
     *   future is fulfilled later by ProcessTasks()/SetResponseForTask().
     *
     * @param dsts destinations to add; moved from on the enqueue path
     * @return future with a per-dst TDstResponse
     */
    NThreading::TFuture<TAddResponse> TTvmClient::Add(TDsts&& dsts) {
        if (dsts.empty()) {
            LogDebug("Adding dst: got empty task");
            return NThreading::MakeFuture<TAddResponse>(TAddResponse{});
        }

        TServiceTickets::TMapIdStr requestedTicketsFromStartUpCache = GetRequestedTicketsFromStartUpCache(dsts);

        if (requestedTicketsFromStartUpCache.size() == dsts.size() &&
            !IsInvalid(TServiceTickets::GetInvalidationTime(requestedTicketsFromStartUpCache), TInstant::Now())) {
            {
                std::unique_lock lock(*ServiceTicketBatchUpdateMutex_);

                TPairTicketsErrors newCache;
                TServiceTicketsPtr cache = GetCachedServiceTickets();

                NTvmApi::TDstSetPtr oldDsts = GetDsts();
                std::shared_ptr<TDsts> newDsts = std::make_shared<TDsts>(oldDsts->begin(), oldDsts->end());

                // GetOptionalServiceTicketFor() shows the cached pointer may be null;
                // only copy existing entries when there is something to copy.
                if (cache) {
                    for (const auto& ticket : cache->TicketsById) {
                        newCache.Tickets.insert(ticket);
                    }
                    for (const auto& error : cache->ErrorsById) {
                        newCache.Errors.insert(error);
                    }
                }

                // Inserted after the existing cache, so an id already present keeps its
                // current ticket (insert() does not overwrite).
                for (const auto& ticket : requestedTicketsFromStartUpCache) {
                    newCache.Tickets.insert(ticket);
                    newDsts->insert(ticket.first);
                }

                UpdateServiceTicketsCache(std::move(newCache), GetStartUpCacheBornDate());
                SetDsts(std::move(newDsts));
            } // lock released here: logging below does not need it

            TAddResponse response;

            for (const auto& dst : dsts) {
                response.emplace(dst, TDstResponse{EDstStatus::Success, TString()});
                LogDebug(TStringBuilder() << "Got ticket from disk cache"
                                          << ": dst=" << dst.Id << " got ticket");
            }

            return NThreading::MakeFuture<TAddResponse>(std::move(response));
        }

        NThreading::TPromise<TAddResponse> promise = NThreading::NewPromise<TAddResponse>();

        // Captured before dsts is moved into the task.
        const size_t size = dsts.size();
        const ui64 id = ++TaskIds_;

        TaskQueue_.Enqueue(TTask{id, promise, std::move(dsts)});

        LogDebug(TStringBuilder() << "Adding dst: got task #" << id << " with " << size << " dsts");
        return promise.GetFuture();
    }

    /**
     * Looks up the cached service ticket for dst.
     *
     * @param dst destination id
     * @return the ticket if cached; std::nullopt if dst is neither in the tickets
     *         nor in the errors of the cache
     * @throws TBrokenTvmClientSettings if there is no service ticket cache
     * @throws TMissingServiceTicket if an error was recorded for dst
     */
    std::optional<TString> TTvmClient::GetOptionalServiceTicketFor(const TTvmId dst) {
        TServiceTicketsPtr tickets = GetCachedServiceTickets();

        Y_ENSURE_EX(tickets,
                    TBrokenTvmClientSettings()
                        << "Need to enable fetching of service tickets in settings");

        if (auto ticketIt = tickets->TicketsById.find(dst); ticketIt != tickets->TicketsById.end()) {
            return ticketIt->second;
        }

        if (auto errorIt = tickets->ErrorsById.find(dst); errorIt != tickets->ErrorsById.end()) {
            ythrow TMissingServiceTicket()
                << "Failed to get ticket for '" << dst << "': "
                << errorIt->second;
        }

        return {};
    }

    /// Private constructor; use Create(), which also runs Init() and starts the worker.
    TTvmClient::TTvmClient(const NTvmApi::TClientSettings& settings, TLoggerPtr logger)
        : TBase(settings, logger)
    {
    }

    /// Stops the base worker while this object is still fully alive: Worker() below
    /// calls ProcessTasks(), which touches members of this class.
    TTvmClient::~TTvmClient() {
        TBase::StopWorker();
    }

    /// One worker iteration: the base updater's work, then pending Add() tasks.
    void TTvmClient::Worker() {
        TBase::Worker();
        ProcessTasks();
    }

    /**
     * Drains TaskQueue_, fetches tickets for the union of all requested dsts in one
     * call, and fulfils each task's promise.
     *
     * If UpdateMissingServiceTickets() throws, Tasks_ is left as-is, so the tasks
     * remain pending for a later iteration. A failure while answering a single task is
     * logged and forwarded to that task's promise, so the caller's future does not
     * hang forever.
     */
    void TTvmClient::ProcessTasks() {
        TaskQueue_.DequeueAll(&Tasks_);
        if (Tasks_.empty()) {
            return;
        }

        TDsts required;
        for (const TTask& task : Tasks_) {
            for (const auto& dst : task.Dsts) {
                required.insert(dst);
            }
        }

        TServiceTicketsPtr cache = UpdateMissingServiceTickets(required);
        for (TTask& task : Tasks_) {
            try {
                SetResponseForTask(task, *cache);
            } catch (const std::exception& e) {
                LogError(TStringBuilder()
                         << "Adding dst: task #" << task.Id << ": exception: " << e.what());
                // No-op if the promise has already been fulfilled.
                task.Promise.TrySetException(std::current_exception());
            } catch (...) {
                LogError(TStringBuilder()
                         << "Adding dst: task #" << task.Id << ": exception: " << CurrentExceptionMessage());
                task.Promise.TrySetException(std::current_exception());
            }
        }

        Tasks_.clear();
    }

    // Reported for a dst that is neither in TicketsById nor in ErrorsById.
    static const TString UNKNOWN = "Unknown reason";

    /**
     * Builds the per-dst response for one task from the ticket cache and sets it on
     * the task's promise. A dst is Success if the cache has a ticket for it;
     * otherwise it is Fail with the recorded error, or UNKNOWN.
     * Does nothing, apart from a warning, if the promise already has a value.
     */
    void TTvmClient::SetResponseForTask(TTvmClient::TTask& task, const TServiceTickets& cache) {
        if (task.Promise.HasValue()) {
            LogWarning(TStringBuilder() << "Adding dst: task #" << task.Id << " already has value");
            return;
        }

        TAddResponse response;

        for (const auto& dst : task.Dsts) {
            if (cache.TicketsById.contains(dst.Id)) {
                response.emplace(dst, TDstResponse{EDstStatus::Success, TString()});

                LogDebug(TStringBuilder() << "Adding dst: task #" << task.Id
                                          << ": dst=" << dst.Id << " got ticket");
                continue;
            }

            auto it = cache.ErrorsById.find(dst.Id);
            const TString& error = it == cache.ErrorsById.end() ? UNKNOWN : it->second;
            response.emplace(dst, TDstResponse{EDstStatus::Fail, error});

            LogWarning(TStringBuilder() << "Adding dst: task #" << task.Id
                                        << ": dst=" << dst.Id
                                        << " failed to get ticket: " << error);
        }

        LogDebug(TStringBuilder() << "Adding dst: task #" << task.Id << ": set value");
        task.Promise.SetValue(std::move(response));
    }
}