1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
|
#pragma once
#include <Access/IAccessEntity.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <Parsers/IParser.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <functional>
#include <optional>
#include <vector>
#include <atomic>
namespace Poco { class Logger; }
namespace Poco::Net { class IPAddress; }
namespace DB
{
/// Forward declarations. In this header these types appear only in function signatures
/// (as references, or by value for the enum), so their full definitions are not required here.
struct User;
class Credentials;
class ExternalAuthenticators;
/// Opaque declaration of a scoped enum; it is only used as a parameter type in this header.
enum class AuthenticationType;
class BackupEntriesCollector;
class RestorerFromBackup;
/// Contains entities, i.e. instances of classes derived from IAccessEntity.
/// The implementations of this class MUST be thread-safe.
///
/// Instances cannot be copied (the class inherits boost::noncopyable).
/// Concrete storages customize the behaviour by overriding the protected virtual *Impl() functions.
/// The template read() functions defined after the class are built on top of readImpl();
/// the other public functions are defined outside of this header and are presumably wired the same way.
class IAccessStorage : public boost::noncopyable
{
public:
/// Creates a storage with the given name; the name is stored as-is and never changes afterwards.
explicit IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {}
/// Virtual destructor, so that a storage can be destroyed through a pointer to this base class.
virtual ~IAccessStorage() = default;
/// Returns the name of this storage.
const String & getStorageName() const { return storage_name; }
/// Returns the type of this storage as a C string. Every concrete storage must provide it.
virtual const char * getStorageType() const = 0;
/// Returns a JSON with the parameters of the storage. It's up to the storage type to fill the JSON.
/// The default implementation returns an empty JSON object.
virtual String getStorageParamsJSON() const { return "{}"; }
/// Returns true if this storage is readonly. By default a storage is writable.
virtual bool isReadOnly() const { return false; }
/// Returns true if this entity is readonly.
/// By default every entity shares the read-only state of the whole storage.
virtual bool isReadOnly(const UUID &) const { return isReadOnly(); }
/// Starts periodic reloading and updating of entities in this storage.
/// The default implementation does nothing.
virtual void startPeriodicReloading() {}
/// Stops periodic reloading and updating of entities in this storage.
/// The default implementation does nothing.
virtual void stopPeriodicReloading() {}
/// Selects which access storages have to be reloaded by reload().
enum class ReloadMode
{
/// Try to reload all access storages (including users.xml, local (disk) access storage, replicated (in zk) access storage).
/// This mode is invoked by the SYSTEM RELOAD USERS command.
ALL,
/// Only reloads users.xml
/// This mode is invoked by the SYSTEM RELOAD CONFIG command.
USERS_CONFIG_ONLY,
};
/// Makes this storage to reload and update access entities right now.
/// The default implementation ignores the mode and does nothing.
virtual void reload(ReloadMode /* reload_mode */) {}
/// Returns the identifiers of all the entities of a specified type contained in the storage.
/// The template overload takes the type from EntityClassT::TYPE.
std::vector<UUID> findAll(AccessEntityType type) const;
template <typename EntityClassT>
std::vector<UUID> findAll() const { return findAll(EntityClassT::TYPE); }
/// Searches for an entity with specified type and name. Returns std::nullopt if not found.
/// The template overloads take the type from EntityClassT::TYPE.
std::optional<UUID> find(AccessEntityType type, const String & name) const;
template <typename EntityClassT>
std::optional<UUID> find(const String & name) const { return find(EntityClassT::TYPE, name); }
/// Searches for multiple entities by their names.
/// NOTE(review): how names which are not found are treated is decided by the out-of-header definition - confirm there.
std::vector<UUID> find(AccessEntityType type, const Strings & names) const;
template <typename EntityClassT>
std::vector<UUID> find(const Strings & names) const { return find(EntityClassT::TYPE, names); }
/// Searches for an entity with specified name and type. Throws an exception if not found.
/// The template overloads take the type from EntityClassT::TYPE.
UUID getID(AccessEntityType type, const String & name) const;
template <typename EntityClassT>
UUID getID(const String & name) const { return getID(EntityClassT::TYPE, name); }
std::vector<UUID> getIDs(AccessEntityType type, const Strings & names) const;
template <typename EntityClassT>
std::vector<UUID> getIDs(const Strings & names) const { return getIDs(EntityClassT::TYPE, names); }
/// Returns whether there is an entity with such identifier in the storage.
virtual bool exists(const UUID & id) const = 0;
/// NOTE(review): presumably checks that all the specified identifiers exist - confirm in the definition.
bool exists(const std::vector<UUID> & ids) const;
/// Reads an entity and casts it to `EntityClassT` (by default no cast is done and the base type is returned).
/// What happens if the entity is not found is controlled by `throw_if_not_exists`:
/// an exception is expected if it's true, and nullptr is returned otherwise.
/// If the entity is found but it's not an `EntityClassT`, a bad cast error is raised via throwBadCast().
/// The overload reading multiple entities returns them in the order of `ids`.
template <typename EntityClassT = IAccessEntity>
std::shared_ptr<const EntityClassT> read(const UUID & id, bool throw_if_not_exists = true) const;
template <typename EntityClassT = IAccessEntity>
std::shared_ptr<const EntityClassT> read(const String & name, bool throw_if_not_exists = true) const;
template <typename EntityClassT = IAccessEntity>
std::vector<AccessEntityPtr> read(const std::vector<UUID> & ids, bool throw_if_not_exists = true) const;
/// Reads an entity. Returns nullptr if not found.
/// Same as read() called with `throw_if_not_exists` set to false.
template <typename EntityClassT = IAccessEntity>
std::shared_ptr<const EntityClassT> tryRead(const UUID & id) const;
template <typename EntityClassT = IAccessEntity>
std::shared_ptr<const EntityClassT> tryRead(const String & name) const;
/// Reads only name of an entity.
/// NOTE(review): the overloads returning std::optional presumably return std::nullopt
/// if the entity is not found and throwing is disabled - confirm in the definitions.
String readName(const UUID & id) const;
std::optional<String> readName(const UUID & id, bool throw_if_not_exists) const;
Strings readNames(const std::vector<UUID> & ids, bool throw_if_not_exists = true) const;
std::optional<String> tryReadName(const UUID & id) const;
Strings tryReadNames(const std::vector<UUID> & ids) const;
/// Reads only name and type of an entity.
std::pair<String, AccessEntityType> readNameWithType(const UUID & id) const;
std::optional<std::pair<String, AccessEntityType>> readNameWithType(const UUID & id, bool throw_if_not_exists) const;
std::optional<std::pair<String, AccessEntityType>> tryReadNameWithType(const UUID & id) const;
/// Reads all entities and returns them with their IDs.
/// The template overload skips identifiers for which tryRead() returns nullptr.
template <typename EntityClassT>
std::vector<std::pair<UUID, std::shared_ptr<const EntityClassT>>> readAllWithIDs() const;
std::vector<std::pair<UUID, AccessEntityPtr>> readAllWithIDs(AccessEntityType type) const;
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
/// Throws an exception if the specified name already exists.
/// The overloads taking `replace_if_exists` and `throw_if_exists` let the caller choose
/// what to do on a conflict (judging by the parameter names; the definitions are not in this header).
UUID insert(const AccessEntityPtr & entity);
std::optional<UUID> insert(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
bool insert(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
std::vector<UUID> insert(const std::vector<AccessEntityPtr> & multiple_entities, bool replace_if_exists = false, bool throw_if_exists = true);
std::vector<UUID> insert(const std::vector<AccessEntityPtr> & multiple_entities, const std::vector<UUID> & ids, bool replace_if_exists = false, bool throw_if_exists = true);
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
/// NOTE(review): presumably returns std::nullopt (or skips the entity) instead of throwing on a name conflict - confirm.
std::optional<UUID> tryInsert(const AccessEntityPtr & entity);
std::vector<UUID> tryInsert(const std::vector<AccessEntityPtr> & multiple_entities);
/// Inserts an entity to the storage. Returns ID of a new entry in the storage.
/// Replaces an existing entry in the storage if the specified name already exists.
UUID insertOrReplace(const AccessEntityPtr & entity);
std::vector<UUID> insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities);
/// Removes an entity from the storage. Throws an exception if couldn't remove.
bool remove(const UUID & id, bool throw_if_not_exists = true);
std::vector<UUID> remove(const std::vector<UUID> & ids, bool throw_if_not_exists = true);
/// Removes an entity from the storage. Returns false if couldn't remove.
bool tryRemove(const UUID & id);
/// Removes multiple entities from the storage. Returns the list of successfully dropped.
std::vector<UUID> tryRemove(const std::vector<UUID> & ids);
/// A function which receives the current version of an entity and returns the updated one.
using UpdateFunc = std::function<AccessEntityPtr(const AccessEntityPtr &)>;
/// Updates an entity stored in the storage. Throws an exception if couldn't update.
bool update(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists = true);
std::vector<UUID> update(const std::vector<UUID> & ids, const UpdateFunc & update_func, bool throw_if_not_exists = true);
/// Updates an entity stored in the storage. Returns false if couldn't update.
bool tryUpdate(const UUID & id, const UpdateFunc & update_func);
/// Updates multiple entities in the storage. Returns the list of successfully updated.
std::vector<UUID> tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func);
/// Finds a user, check the provided credentials and returns the ID of the user if they are valid.
/// Throws an exception if no such user or credentials are invalid.
/// The second overload additionally takes `throw_if_user_not_exists` and returns std::optional;
/// presumably std::nullopt is returned for an unknown user when that flag is false - confirm in the definition.
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool allow_no_password, bool allow_plaintext_password) const;
std::optional<UUID> authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const;
/// Returns true if this storage can be stored to or restored from a backup.
/// By default backups are not allowed, and restoring is allowed only if
/// backup is allowed and the storage is not readonly.
virtual bool isBackupAllowed() const { return false; }
virtual bool isRestoreAllowed() const { return isBackupAllowed() && !isReadOnly(); }
/// Makes a backup of this access storage.
virtual void backup(BackupEntriesCollector & backup_entries_collector, const String & data_path_in_backup, AccessEntityType type) const;
/// Restores this access storage from a backup.
virtual void restoreFromBackup(RestorerFromBackup & restorer);
protected:
/// Customization points for concrete storages.
/// findImpl(), findAllImpl() and readImpl() are pure virtual and must be overridden;
/// the other functions have default implementations defined outside of this header.
virtual std::optional<UUID> findImpl(AccessEntityType type, const String & name) const = 0;
virtual std::vector<UUID> findAllImpl(AccessEntityType type) const = 0;
virtual AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const = 0;
virtual std::optional<std::pair<String, AccessEntityType>> readNameWithTypeImpl(const UUID & id, bool throw_if_not_exists) const;
virtual bool insertImpl(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
virtual bool removeImpl(const UUID & id, bool throw_if_not_exists);
virtual bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const;
virtual bool areCredentialsValid(const User & user, const Credentials & credentials, const ExternalAuthenticators & external_authenticators) const;
virtual bool isAddressAllowed(const User & user, const Poco::Net::IPAddress & address) const;
/// Helpers for derived classes.
static UUID generateRandomID();
Poco::Logger * getLogger() const;
/// Formats the name of an entity together with its type, as defined by AccessEntityTypeInfo for that type.
static String formatEntityTypeWithName(AccessEntityType type, const String & name) { return AccessEntityTypeInfo::get(type).formatEntityNameWithType(name); }
static void clearConflictsInEntitiesList(std::vector<std::pair<UUID, AccessEntityPtr>> & entities, const Poco::Logger * log_);
/// Error helpers. They never return; judging by their names each of them throws an exception
/// describing the corresponding error (the definitions are not in this header).
[[noreturn]] void throwNotFound(const UUID & id) const;
[[noreturn]] void throwNotFound(AccessEntityType type, const String & name) const;
[[noreturn]] static void throwBadCast(const UUID & id, AccessEntityType type, const String & name, AccessEntityType required_type);
[[noreturn]] void throwIDCollisionCannotInsert(
const UUID & id, AccessEntityType type, const String & name, AccessEntityType existing_type, const String & existing_name) const;
[[noreturn]] void throwNameCollisionCannotInsert(AccessEntityType type, const String & name) const;
[[noreturn]] void throwNameCollisionCannotRename(AccessEntityType type, const String & old_name, const String & new_name) const;
[[noreturn]] void throwReadonlyCannotInsert(AccessEntityType type, const String & name) const;
[[noreturn]] void throwReadonlyCannotUpdate(AccessEntityType type, const String & name) const;
[[noreturn]] void throwReadonlyCannotRemove(AccessEntityType type, const String & name) const;
[[noreturn]] static void throwAddressNotAllowed(const Poco::Net::IPAddress & address);
[[noreturn]] static void throwInvalidCredentials();
[[noreturn]] static void throwAuthenticationTypeNotAllowed(AuthenticationType auth_type);
[[noreturn]] void throwBackupNotAllowed() const;
[[noreturn]] void throwRestoreNotAllowed() const;
private:
/// The name of this storage; set in the constructor and returned by getStorageName().
const String storage_name;
/// Pointer to the logger of this storage. It's atomic and mutable, so it can be set from const functions;
/// presumably it's initialized lazily by getLogger() - confirm in the definition.
mutable std::atomic<Poco::Logger *> log = nullptr;
};
/// Reads the entity with the specified identifier and returns it as `EntityClassT`.
/// Handling of a missing entity is delegated to readImpl() (controlled by `throw_if_not_exists`);
/// if readImpl() returns nullptr then nullptr is returned.
/// If the entity exists but has another type, throwBadCast() is called.
template <typename EntityClassT>
std::shared_ptr<const EntityClassT> IAccessStorage::read(const UUID & id, bool throw_if_not_exists) const
{
    auto base_entity = readImpl(id, throw_if_not_exists);

    if constexpr (!std::is_same_v<EntityClassT, IAccessEntity>)
    {
        /// A concrete type was requested, so the entity has to be downcasted.
        if (base_entity == nullptr)
            return nullptr;

        auto typed_entity = typeid_cast<std::shared_ptr<const EntityClassT>>(base_entity);
        if (!typed_entity)
            throwBadCast(id, base_entity->getType(), base_entity->getName(), EntityClassT::TYPE);

        return typed_entity;
    }
    else
    {
        /// The base type was requested, no cast is necessary.
        return base_entity;
    }
}
template <typename EntityClassT>
std::shared_ptr<const EntityClassT> IAccessStorage::read(const String & name, bool throw_if_not_exists) const
{
if (auto id = find<EntityClassT>(name))
return read<EntityClassT>(*id, throw_if_not_exists);
if (throw_if_not_exists)
throwNotFound(EntityClassT::TYPE, name);
else
return nullptr;
}
template <typename EntityClassT>
std::vector<AccessEntityPtr> IAccessStorage::read(const std::vector<UUID> & ids, bool throw_if_not_exists) const
{
std::vector<AccessEntityPtr> result;
result.reserve(ids.size());
for (const auto & id : ids)
result.push_back(read<EntityClassT>(id, throw_if_not_exists));
return result;
}
/// Reads the entity with the specified identifier.
/// It's read() with throwing on a missing entity disabled.
template <typename EntityClassT>
std::shared_ptr<const EntityClassT> IAccessStorage::tryRead(const UUID & id) const
{
    constexpr bool throw_if_not_exists = false;
    return read<EntityClassT>(id, throw_if_not_exists);
}
/// Reads the entity of type `EntityClassT` with the specified name.
/// It's read() with throwing on a missing entity disabled, so nullptr is returned for an unknown name.
template <typename EntityClassT>
std::shared_ptr<const EntityClassT> IAccessStorage::tryRead(const String & name) const
{
    constexpr bool throw_if_not_exists = false;
    return read<EntityClassT>(name, throw_if_not_exists);
}
template <typename EntityClassT>
std::vector<std::pair<UUID, std::shared_ptr<const EntityClassT>>> IAccessStorage::readAllWithIDs() const
{
std::vector<std::pair<UUID, std::shared_ptr<const EntityClassT>>> entities;
for (const auto & id : findAll<EntityClassT>())
{
if (auto entity = tryRead<EntityClassT>(id))
entities.emplace_back(id, entity);
}
return entities;
}
/// Parses the name of an access storage into `storage_name`.
/// The parsing is delegated to parseIdentifierOrStringLiteral(), whose result is returned.
inline bool parseAccessStorageName(IParser::Pos & pos, Expected & expected, String & storage_name)
{
    const bool parsed = parseIdentifierOrStringLiteral(pos, expected, storage_name);
    return parsed;
}
}
|