diff options
author | l4m3r <l4m3r@yandex-team.com> | 2023-11-14 21:04:00 +0300 |
---|---|---|
committer | l4m3r <l4m3r@yandex-team.com> | 2023-11-14 21:39:29 +0300 |
commit | a8c9782fb7c6454c0afef92c5e5cb16cce719515 (patch) | |
tree | afc96b93c83950b3104744281f533bd68074b4e4 /library/python/func/ut/test_func.py | |
parent | 3156e444f4cd107df59f9cc1bb85f213fc4c32ee (diff) | |
download | ydb-a8c9782fb7c6454c0afef92c5e5cb16cce719515.tar.gz |
Fix: memoize multithreding optimization
Fix: memoize multithreding optimization
Diffstat (limited to 'library/python/func/ut/test_func.py')
-rw-r--r-- | library/python/func/ut/test_func.py | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/library/python/func/ut/test_func.py b/library/python/func/ut/test_func.py index 3c4fad1a07..70a10d62cb 100644 --- a/library/python/func/ut/test_func.py +++ b/library/python/func/ut/test_func.py @@ -1,5 +1,8 @@ import pytest +import multiprocessing +import random import threading +import time import library.python.func as func @@ -158,5 +161,139 @@ def test_memoize_thread_local(): th.join() +def test_memoize_not_thread_safe(): + class Counter(object): + def __init__(self, s): + self.val = s + + def inc(self): + self.val += 1 + return self.val + + @func.memoize(thread_safe=False) + def io_job(n): + time.sleep(0.1) + return Counter(n) + + def worker(n): + assert io_job(n).inc() == n + 1 + assert io_job(n).inc() == n + 2 + assert io_job(n*10).inc() == n*10 + 1 + assert io_job(n*10).inc() == n*10 + 2 + assert io_job(n).inc() == n + 3 + + threads = [] + for i in range(5): + threads.append(threading.Thread(target=worker, args=(i+1,))) + + st = time.time() + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + elapsed_time = time.time() - st + assert elapsed_time < 0.5 + + +def test_memoize_not_thread_safe_concurrent(): + class Counter(object): + def __init__(self, s): + self.val = s + + def inc(self): + self.val += 1 + return self.val + + @func.memoize(thread_safe=False) + def io_job(n): + time.sleep(0.1) + return Counter(n) + + def worker(): + io_job(100).inc() + + th1 = threading.Thread(target=worker) + th2 = threading.Thread(target=worker) + th3 = threading.Thread(target=worker) + + th1.start() + time.sleep(0.05) + th2.start() + + th1.join() + assert io_job(100).inc() == 100 + 2 + + th3.start() + # th3 instantly got counter from memory + assert io_job(100).inc() == 100 + 4 + + th2.join() + # th2 shoud increase th1 counter + assert io_job(100).inc() == 100 + 6 + + +def test_memoize_not_thread_safe_stress(): + @func.memoize(thread_safe=False) + def job(): + for _ in range(1000): + hash = random.getrandbits(128) + return hash + + def worker(n): + hash = job() + results[n] = hash + + num_threads = min(multiprocessing.cpu_count()*4, 64) + threads = [] + results = [None for _ in range(num_threads)] + + for i in range(num_threads): + thread = threading.Thread(target=worker, args=(i,)) + threads.append(thread) + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + assert len(set(results)) == 1 + + +def test_memoize_thread_safe(): + class Counter(object): + def __init__(self, s): + self.val = s + + def inc(self): + self.val += 1 + return self.val + + @func.memoize(thread_safe=True) + def io_job(n): + time.sleep(0.05) + return Counter(n) + + def worker(n): + assert io_job(n).inc() == n + 1 + assert io_job(n).inc() == n + 2 + assert io_job(n*10).inc() == n*10 + 1 + assert io_job(n*10).inc() == n*10 + 2 + + threads = [] + for i in range(5): + threads.append(threading.Thread(target=worker, args=(i+1,))) + + st = time.time() + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + elapsed_time = time.time() - st + assert elapsed_time >= 0.5 + + if __name__ == '__main__': pytest.main([__file__]) |