author    l4m3r <l4m3r@yandex-team.com>  2023-11-14 21:04:00 +0300
committer l4m3r <l4m3r@yandex-team.com>  2023-11-14 21:39:29 +0300
commit    a8c9782fb7c6454c0afef92c5e5cb16cce719515 (patch)
tree      afc96b93c83950b3104744281f533bd68074b4e4
parent    3156e444f4cd107df59f9cc1bb85f213fc4c32ee (diff)
download  ydb-a8c9782fb7c6454c0afef92c5e5cb16cce719515.tar.gz
Fix: memoize multithreading optimization
-rw-r--r--  library/python/func/__init__.py       28
-rw-r--r--  library/python/func/ut/test_func.py   137
-rw-r--r--  library/python/func/ya.make             2
3 files changed, 160 insertions, 7 deletions
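The change adds a thread_safe flag to memoize(). A minimal usage sketch, assuming the module is imported the same way as in the tests below (io_job and cpu_job are illustrative names, not part of the patch):

import threading
import time

import library.python.func as func


# thread_safe=False replaces the lock with a no-op context manager: threads
# that miss the cache at the same time may each run the function, but
# memory.setdefault() keeps only the first stored result.
@func.memoize(thread_safe=False)
def io_job(n):
    time.sleep(0.1)  # simulate blocking I/O; duplicate work is acceptable here
    return n * n


# thread_safe=True (the default) keeps the previous behaviour: the cache is
# guarded by a threading.Lock, so the body runs at most once per cached key.
@func.memoize(thread_safe=True)
def cpu_job(n):
    return sum(range(n))


if __name__ == '__main__':
    threads = [threading.Thread(target=io_job, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(io_job(3), cpu_job(10))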
diff --git a/library/python/func/__init__.py b/library/python/func/__init__.py
index 12a280bddc..15f2137f1d 100644
--- a/library/python/func/__init__.py
+++ b/library/python/func/__init__.py
@@ -1,6 +1,8 @@
 import functools
 import threading
 import collections
+import contextlib
+import six
 
 
 def map0(func, value):
@@ -76,20 +78,32 @@ class lazy_classproperty(object):
         return getattr(owner, attr_name)
 
 
-def memoize(limit=0, thread_local=False):
+class nullcontext(object):
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+
+def memoize(limit=0, thread_local=False, thread_safe=True):
     assert limit >= 0
+    assert limit <= 0 or thread_safe, 'memoize() is not thread-safe enough to work in limited, non-thread-safe mode'
 
     def decorator(func):
         memory = {}
-        lock = threading.Lock()
+
+        if six.PY3:
+            lock = contextlib.nullcontext()
+        else:
+            lock = nullcontext()
+        lock = threading.Lock() if thread_safe else lock
 
         if limit:
             keys = collections.deque()
 
             def get(args):
-                try:
-                    return memory[args]
-                except KeyError:
+                if args not in memory:
                     with lock:
                         if args not in memory:
                             fargs = args[-1]
@@ -97,7 +111,7 @@ def memoize(limit=0, thread_local=False):
                             keys.append(args)
                             if len(keys) > limit:
                                 del memory[keys.popleft()]
-                    return memory[args]
+                return memory[args]
 
         else:
 
@@ -106,7 +120,7 @@ def memoize(limit=0, thread_local=False):
                     with lock:
                         if args not in memory:
                             fargs = args[-1]
-                            memory[args] = func(*fargs)
+                            memory.setdefault(args, func(*fargs))
                 return memory[args]
 
         if thread_local:
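For reference, the lookup logic that results from these hunks, condensed into a standalone sketch (Python 3 only, placeholder names; the real code lives inside the decorator above): the first check is unsynchronized, the lock degenerates to a no-op when thread_safe=False, and setdefault() guarantees that even if two threads compute concurrently, all callers end up seeing the same stored value.

import contextlib
import threading


def cached(func, thread_safe=True):
    memory = {}
    # contextlib.nullcontext() is the Python 3 counterpart of the
    # nullcontext class added above for Python 2.
    lock = threading.Lock() if thread_safe else contextlib.nullcontext()

    def get(args):
        if args not in memory:          # unsynchronized fast path
            with lock:                  # no-op context manager when thread_safe=False
                if args not in memory:  # re-check after acquiring the (possible) lock
                    # with the null lock two threads may both reach this point;
                    # setdefault keeps the first result, the second is discarded
                    memory.setdefault(args, func(*args))
        return memory[args]

    return get

The limited branch keeps the plain assignment and the eviction bookkeeping, which is only safe under a real lock; that is what the new assert on limit and thread_safe enforces.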
diff --git a/library/python/func/ut/test_func.py b/library/python/func/ut/test_func.py
index 3c4fad1a07..70a10d62cb 100644
--- a/library/python/func/ut/test_func.py
+++ b/library/python/func/ut/test_func.py
@@ -1,5 +1,8 @@
 import pytest
+import multiprocessing
+import random
 import threading
+import time
 
 import library.python.func as func
 
@@ -158,5 +161,139 @@ def test_memoize_thread_local():
         th.join()
 
 
+def test_memoize_not_thread_safe():
+    class Counter(object):
+        def __init__(self, s):
+            self.val = s
+
+        def inc(self):
+            self.val += 1
+            return self.val
+
+    @func.memoize(thread_safe=False)
+    def io_job(n):
+        time.sleep(0.1)
+        return Counter(n)
+
+    def worker(n):
+        assert io_job(n).inc() == n + 1
+        assert io_job(n).inc() == n + 2
+        assert io_job(n*10).inc() == n*10 + 1
+        assert io_job(n*10).inc() == n*10 + 2
+        assert io_job(n).inc() == n + 3
+
+    threads = []
+    for i in range(5):
+        threads.append(threading.Thread(target=worker, args=(i+1,)))
+
+    st = time.time()
+
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+
+    elapsed_time = time.time() - st
+    assert elapsed_time < 0.5
+
+
+def test_memoize_not_thread_safe_concurrent():
+    class Counter(object):
+        def __init__(self, s):
+            self.val = s
+
+        def inc(self):
+            self.val += 1
+            return self.val
+
+    @func.memoize(thread_safe=False)
+    def io_job(n):
+        time.sleep(0.1)
+        return Counter(n)
+
+    def worker():
+        io_job(100).inc()
+
+    th1 = threading.Thread(target=worker)
+    th2 = threading.Thread(target=worker)
+    th3 = threading.Thread(target=worker)
+
+    th1.start()
+    time.sleep(0.05)
+    th2.start()
+
+    th1.join()
+    assert io_job(100).inc() == 100 + 2
+
+    th3.start()
+    # th3 instantly gets the counter from memory
+    assert io_job(100).inc() == 100 + 4
+
+    th2.join()
+    # th2 should increase th1's counter
+    assert io_job(100).inc() == 100 + 6
+
+
+def test_memoize_not_thread_safe_stress():
+    @func.memoize(thread_safe=False)
+    def job():
+        for _ in range(1000):
+            hash = random.getrandbits(128)
+        return hash
+
+    def worker(n):
+        hash = job()
+        results[n] = hash
+
+    num_threads = min(multiprocessing.cpu_count()*4, 64)
+    threads = []
+    results = [None for _ in range(num_threads)]
+
+    for i in range(num_threads):
+        thread = threading.Thread(target=worker, args=(i,))
+        threads.append(thread)
+
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+    assert len(set(results)) == 1
+
+
+def test_memoize_thread_safe():
+    class Counter(object):
+        def __init__(self, s):
+            self.val = s
+
+        def inc(self):
+            self.val += 1
+            return self.val
+
+    @func.memoize(thread_safe=True)
+    def io_job(n):
+        time.sleep(0.05)
+        return Counter(n)
+
+    def worker(n):
+        assert io_job(n).inc() == n + 1
+        assert io_job(n).inc() == n + 2
+        assert io_job(n*10).inc() == n*10 + 1
+        assert io_job(n*10).inc() == n*10 + 2
+
+    threads = []
+    for i in range(5):
+        threads.append(threading.Thread(target=worker, args=(i+1,)))
+
+    st = time.time()
+
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+
+    elapsed_time = time.time() - st
+    assert elapsed_time >= 0.5
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
diff --git a/library/python/func/ya.make b/library/python/func/ya.make
index 47ab634913..72ebac784a 100644
--- a/library/python/func/ya.make
+++ b/library/python/func/ya.make
@@ -2,6 +2,8 @@ PY23_LIBRARY()
 
 PY_SRCS(__init__.py)
 
+PEERDIR(contrib/python/six)
+
 END()
 
 RECURSE_FOR_TESTS(