diff options
author | l4m3r <l4m3r@yandex-team.com> | 2023-11-14 21:04:00 +0300 |
---|---|---|
committer | l4m3r <l4m3r@yandex-team.com> | 2023-11-14 21:39:29 +0300 |
commit | a8c9782fb7c6454c0afef92c5e5cb16cce719515 (patch) | |
tree | afc96b93c83950b3104744281f533bd68074b4e4 | |
parent | 3156e444f4cd107df59f9cc1bb85f213fc4c32ee (diff) | |
download | ydb-a8c9782fb7c6454c0afef92c5e5cb16cce719515.tar.gz |
Fix: memoize multithreading optimization
Fix: memoize multithreading optimization
-rw-r--r-- | library/python/func/__init__.py | 28 | ||||
-rw-r--r-- | library/python/func/ut/test_func.py | 137 | ||||
-rw-r--r-- | library/python/func/ya.make | 2 |
3 files changed, 160 insertions, 7 deletions
diff --git a/library/python/func/__init__.py b/library/python/func/__init__.py index 12a280bddc..15f2137f1d 100644 --- a/library/python/func/__init__.py +++ b/library/python/func/__init__.py @@ -1,6 +1,8 @@ import functools import threading import collections +import contextlib +import six def map0(func, value): @@ -76,20 +78,32 @@ class lazy_classproperty(object): return getattr(owner, attr_name) -def memoize(limit=0, thread_local=False): +class nullcontext(object): + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + pass + + +def memoize(limit=0, thread_local=False, thread_safe=True): assert limit >= 0 + assert limit <= 0 or thread_safe, 'memoize() it not thread safe enough to work in limiting and non-thread safe mode' def decorator(func): memory = {} - lock = threading.Lock() + + if six.PY3: + lock = contextlib.nullcontext() + else: + lock = nullcontext() + lock = threading.Lock() if thread_safe else lock if limit: keys = collections.deque() def get(args): - try: - return memory[args] - except KeyError: + if args not in memory: with lock: if args not in memory: fargs = args[-1] @@ -97,7 +111,7 @@ def memoize(limit=0, thread_local=False): keys.append(args) if len(keys) > limit: del memory[keys.popleft()] - return memory[args] + return memory[args] else: @@ -106,7 +120,7 @@ def memoize(limit=0, thread_local=False): with lock: if args not in memory: fargs = args[-1] - memory[args] = func(*fargs) + memory.setdefault(args, func(*fargs)) return memory[args] if thread_local: diff --git a/library/python/func/ut/test_func.py b/library/python/func/ut/test_func.py index 3c4fad1a07..70a10d62cb 100644 --- a/library/python/func/ut/test_func.py +++ b/library/python/func/ut/test_func.py @@ -1,5 +1,8 @@ import pytest +import multiprocessing +import random import threading +import time import library.python.func as func @@ -158,5 +161,139 @@ def test_memoize_thread_local(): th.join() +def test_memoize_not_thread_safe(): + class Counter(object): 
+ def __init__(self, s): + self.val = s + + def inc(self): + self.val += 1 + return self.val + + @func.memoize(thread_safe=False) + def io_job(n): + time.sleep(0.1) + return Counter(n) + + def worker(n): + assert io_job(n).inc() == n + 1 + assert io_job(n).inc() == n + 2 + assert io_job(n*10).inc() == n*10 + 1 + assert io_job(n*10).inc() == n*10 + 2 + assert io_job(n).inc() == n + 3 + + threads = [] + for i in range(5): + threads.append(threading.Thread(target=worker, args=(i+1,))) + + st = time.time() + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + elapsed_time = time.time() - st + assert elapsed_time < 0.5 + + +def test_memoize_not_thread_safe_concurrent(): + class Counter(object): + def __init__(self, s): + self.val = s + + def inc(self): + self.val += 1 + return self.val + + @func.memoize(thread_safe=False) + def io_job(n): + time.sleep(0.1) + return Counter(n) + + def worker(): + io_job(100).inc() + + th1 = threading.Thread(target=worker) + th2 = threading.Thread(target=worker) + th3 = threading.Thread(target=worker) + + th1.start() + time.sleep(0.05) + th2.start() + + th1.join() + assert io_job(100).inc() == 100 + 2 + + th3.start() + # th3 instantly got counter from memory + assert io_job(100).inc() == 100 + 4 + + th2.join() + # th2 shoud increase th1 counter + assert io_job(100).inc() == 100 + 6 + + +def test_memoize_not_thread_safe_stress(): + @func.memoize(thread_safe=False) + def job(): + for _ in range(1000): + hash = random.getrandbits(128) + return hash + + def worker(n): + hash = job() + results[n] = hash + + num_threads = min(multiprocessing.cpu_count()*4, 64) + threads = [] + results = [None for _ in range(num_threads)] + + for i in range(num_threads): + thread = threading.Thread(target=worker, args=(i,)) + threads.append(thread) + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + assert len(set(results)) == 1 + + +def test_memoize_thread_safe(): + class Counter(object): + 
def __init__(self, s): + self.val = s + + def inc(self): + self.val += 1 + return self.val + + @func.memoize(thread_safe=True) + def io_job(n): + time.sleep(0.05) + return Counter(n) + + def worker(n): + assert io_job(n).inc() == n + 1 + assert io_job(n).inc() == n + 2 + assert io_job(n*10).inc() == n*10 + 1 + assert io_job(n*10).inc() == n*10 + 2 + + threads = [] + for i in range(5): + threads.append(threading.Thread(target=worker, args=(i+1,))) + + st = time.time() + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + elapsed_time = time.time() - st + assert elapsed_time >= 0.5 + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/library/python/func/ya.make b/library/python/func/ya.make index 47ab634913..72ebac784a 100644 --- a/library/python/func/ya.make +++ b/library/python/func/ya.make @@ -2,6 +2,8 @@ PY23_LIBRARY() PY_SRCS(__init__.py) +PEERDIR(contrib/python/six) + END() RECURSE_FOR_TESTS( |