import pytest
import multiprocessing
import random
import threading
import time
import six
import library.python.func as func
def test_map0():
    """Check ``func.map0`` with a user function and a builtin, on None and non-None input."""

    def plus_one(value):
        return value + 1

    # A None argument is expected to come back as None, whatever the callable.
    assert func.map0(plus_one, None) is None
    assert func.map0(plus_one, 2) == 3

    assert func.map0(len, None) is None
    assert func.map0(len, [1, 2]) == 2
def test_single():
    """Check that ``func.single`` unwraps a one-element list and raises otherwise."""
    assert 1 == func.single([1])

    # The calls below are deliberately not wrapped in ``assert``: a failing
    # assertion raises AssertionError, which ``pytest.raises(Exception)``
    # would accept, so the test could pass even if ``single`` returned a
    # value instead of raising.
    with pytest.raises(Exception):
        func.single([])

    with pytest.raises(Exception):
        func.single([1, 2])
def test_memoize():
    """Exercise ``func.memoize`` on plain functions, a size-limited cache and a method.

    Every memoized function below draws from one shared call counter, so each
    expected tuple encodes how many real (uncached) invocations have happened
    so far.  The order of the calls is therefore significant.
    """

    class Counter(object):
        # The running count is stored on the class itself, so all callers share it.
        @staticmethod
        def inc():
            bumped = getattr(Counter, '_qty', 0) + 1
            Counter._qty = bumped
            return bumped

    @func.memoize()
    def t1(a):
        return a, Counter.inc()

    @func.memoize()
    def t2(a):
        return a, Counter.inc()

    @func.memoize()
    def t3(a):
        return a, Counter.inc()

    @func.memoize()
    def t4(a):
        return a, Counter.inc()

    @func.memoize()
    def t5(a, b, c):
        return a + b + c, Counter.inc()

    @func.memoize()
    def t6():
        return Counter.inc()

    @func.memoize(limit=2)
    def t7(a, _b):
        return a, Counter.inc()

    # Single-argument functions: each distinct argument costs exactly one
    # counter tick, and the repeated call must return the remembered tuple.
    expected_tick = 0
    for memoized in (t1, t2, t3, t4):
        for arg in (1, 2):
            expected_tick += 1
            for _ in range(2):
                assert memoized(arg) == (arg, expected_tick)

    # Several positional arguments.
    for _ in range(2):
        assert t5(1, 2, 3) == (6, 9)
    for _ in range(2):
        assert t5(1, 2, 4) == (7, 10)

    # No arguments at all.
    for _ in range(2):
        assert t6() == 11

    # Cache limited to two entries.
    assert t7(1, None) == (1, 12)
    assert t7(2, None) == (2, 13)
    assert t7(1, None) == (1, 12)
    assert t7(2, None) == (2, 13)
    # A third distinct key is expected to push out the result for (1, None) ...
    assert t7(3, None) == (3, 14)
    # ... so asking for (1, None) again triggers a fresh computation.
    assert t7(1, None) == (1, 15)

    class ClassWithMemoizedMethod(object):
        def __init__(self):
            self.a = 0

        # NOTE(review): what the positional ``True`` selects depends on the
        # signature of ``func.memoize``, which is not visible here - confirm.
        @func.memoize(True)
        def t(self, i):
            self.a += i
            return i

    obj = ClassWithMemoizedMethod()
    # ``a`` only grows when the method body really runs, so it must stay
    # unchanged across the repeated call with the same argument.
    for arg, accumulated in ((10, 10), (20, 30)):
        for _ in range(2):
            assert obj.t(arg) == arg
            assert obj.a == accumulated
def test_first():
    """Check ``func.first`` on a list mixing falsy values with one truthy value, and on lists with none."""
    mostly_falsy = [0, [], (), None, False, {}, 0.0, '1', 0]
    assert '1' == func.first(mostly_falsy)

    # With nothing truthy to pick, the expected answer is None.
    for nothing_truthy in ([], [0]):
        assert func.first(nothing_truthy) is None
def test_split():
    """Check that ``func.split`` partitions items by the predicate's truthiness, keeping order."""

    def identity(x):
        return x

    # (input, (items where the predicate is truthy, remaining items))
    cases = (
        ([1, 1], ([1, 1], [])),
        ([0, 0], ([], [0, 0])),
        ([], ([], [])),
        ([1, 0, 1], ([1, 1], [0])),
    )
    for items, expected in cases:
        assert func.split(items, identity) == expected
def test_flatten_dict():
    """Check ``func.flatten_dict`` on flat, empty and nested dicts, with default and custom separators."""
    # Dicts without nesting are expected to come back with equal contents.
    for flat in ({"a": 1, "b": 2}, {"a": 1}, {}):
        expected = dict(flat)
        assert func.flatten_dict(flat) == expected

    # Nested keys are joined into one key; "." is the separator unless overridden.
    dotted = func.flatten_dict({"a": 1, "b": {"c": {"d": 2}}})
    assert dotted == {"a": 1, "b.c.d": 2}

    slashed = func.flatten_dict({"a": 1, "b": {"c": {"d": 2}}}, separator="/")
    assert slashed == {"a": 1, "b/c/d": 2}
def test_threadsafe_singleton():
    """Check that a ``func.Singleton``-based class yields one instance across many threads.

    One hundred threads construct the class at roughly the same moment; every
    thread must end up holding the very same object.
    """

    class ShouldBeSingle(six.with_metaclass(func.Singleton, object)):
        def __new__(cls, *args, **kwargs):
            # Slow construction widens the window in which racing threads
            # could each end up creating their own instance.
            time.sleep(0.1)
            return super(ShouldBeSingle, cls).__new__(cls, *args, **kwargs)

    threads_count = 100
    results = [None] * threads_count

    def class_factory(results, i):
        # Sleep first so that all threads are started before any of them
        # reaches the constructor.
        time.sleep(0.1)
        results[i] = ShouldBeSingle()

    threads = [
        threading.Thread(target=class_factory, args=(results, i))
        for i in range(threads_count)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # A worker that died leaves None in its slot.  If every worker died the
    # list would be all None and the uniqueness check below would pass
    # vacuously, so rule that out explicitly.
    assert None not in results
    assert len(set(results)) == 1
def test_memoize_thread_local():
    """Check that ``func.memoize(thread_local=True)`` gives every thread its own cache.

    The same sequence of checks runs in the main thread and then in a second
    thread; both must observe counters that start from their initial value.
    """

    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_local=True)
    def get_counter(start):
        return Counter(start)

    def th_inc():
        # Within one thread the counter object is remembered, so the second
        # ``inc`` continues from the first.
        assert get_counter(0).inc() == 1
        assert get_counter(0).inc() == 2
        assert get_counter(10).inc() == 11
        assert get_counter(10).inc() == 12

    failures = []

    def th_inc_guarded():
        # An exception raised in a thread's target is not propagated to the
        # thread that joins it, so a failed assertion would go unnoticed.
        # Capture it here and re-raise it from the main thread.
        try:
            th_inc()
        except Exception as exc:
            failures.append(exc)

    th_inc()

    th = threading.Thread(target=th_inc_guarded)
    th.start()
    th.join()

    if failures:
        raise failures[0]
def test_memoize_not_thread_safe():
    """Check that ``func.memoize(thread_safe=False)`` caches results without serializing callers.

    Five workers each make two uncached 0.1 s calls with arguments no other
    worker uses.  Run strictly one after another, those ten calls would take
    about 1 s, so finishing in under 0.5 s shows that they overlapped.
    """

    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=False)
    def io_job(n):
        time.sleep(0.1)
        return Counter(n)

    def worker(n):
        assert io_job(n).inc() == n + 1
        assert io_job(n).inc() == n + 2
        assert io_job(n*10).inc() == n*10 + 1
        assert io_job(n*10).inc() == n*10 + 2
        assert io_job(n).inc() == n + 3

    failures = []

    def guarded_worker(n):
        # An exception raised in a thread's target is not propagated to the
        # thread that joins it, so a failed assertion would go unnoticed.
        # Capture it here and re-raise it from the main thread.
        try:
            worker(n)
        except Exception as exc:
            failures.append(exc)

    threads = [
        threading.Thread(target=guarded_worker, args=(i+1,))
        for i in range(5)
    ]

    st = time.time()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    elapsed_time = time.time() - st

    # Report a worker failure first: it is more informative than a timing miss.
    if failures:
        raise failures[0]
    assert elapsed_time < 0.5
def test_memoize_not_thread_safe_concurrent():
    """Check which counter overlapping callers of a ``thread_safe=False`` memoized function get.

    ``th1`` and ``th2`` both enter the slow function before a result has been
    stored; ``th3`` starts only after ``th1`` has finished.  The assertions
    expect all of them, and the main thread, to increment the same counter.
    The scenario relies on the sleeps below for its ordering.
    """

    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=False)
    def io_job(n):
        time.sleep(0.1)
        return Counter(n)

    def worker():
        io_job(100).inc()

    th1 = threading.Thread(target=worker)
    th2 = threading.Thread(target=worker)
    th3 = threading.Thread(target=worker)

    th1.start()
    # Start th2 while th1 is still sleeping inside io_job.
    time.sleep(0.05)
    th2.start()

    th1.join()
    # th1 incremented once; this call is the second increment.
    assert io_job(100).inc() == 100 + 2

    th3.start()
    # Wait for th3 before checking: without the join the main thread could
    # reach the assertion before th3 has incremented the counter, and the
    # thread would be left unjoined when the test returns.
    th3.join()
    # th3 got the counter from memory (no 0.1 s computation), so th2 is
    # still sleeping and has not contributed yet.
    assert io_job(100).inc() == 100 + 4

    th2.join()
    # th2 should increase th1's counter rather than a counter of its own.
    assert io_job(100).inc() == 100 + 6
def test_memoize_not_thread_safe_stress():
    """Check that many threads racing on a ``thread_safe=False`` memoized function agree on one value.

    Each thread stores what the memoized ``job`` returned to it; all stored
    values must be identical even though ``job`` itself yields a fresh random
    number on every real invocation.
    """

    @func.memoize(thread_safe=False)
    def job():
        # Draw many values so that one invocation takes long enough for
        # other threads to enter concurrently; only the last one is returned.
        for _ in range(1000):
            value = random.getrandbits(128)
        return value

    num_threads = min(multiprocessing.cpu_count()*4, 64)
    results = [None for _ in range(num_threads)]

    def worker(n):
        results[n] = job()

    threads = [
        threading.Thread(target=worker, args=(i,))
        for i in range(num_threads)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # A worker that died leaves None in its slot.  If every worker died the
    # list would be all None and the uniqueness check below would pass
    # vacuously, so rule that out explicitly.
    assert None not in results
    assert len(set(results)) == 1
def test_memoize_thread_safe():
    """Check that ``func.memoize(thread_safe=True)`` caches results and serializes computation.

    Five workers each make two uncached 0.05 s calls with arguments no other
    worker uses.  The lower bound of 0.5 s equals the total sleep time of
    those ten calls when they run strictly one after another.
    """

    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=True)
    def io_job(n):
        time.sleep(0.05)
        return Counter(n)

    def worker(n):
        assert io_job(n).inc() == n + 1
        assert io_job(n).inc() == n + 2
        assert io_job(n*10).inc() == n*10 + 1
        assert io_job(n*10).inc() == n*10 + 2

    failures = []

    def guarded_worker(n):
        # An exception raised in a thread's target is not propagated to the
        # thread that joins it, so a failed assertion would go unnoticed.
        # Capture it here and re-raise it from the main thread.
        try:
            worker(n)
        except Exception as exc:
            failures.append(exc)

    threads = [
        threading.Thread(target=guarded_worker, args=(i+1,))
        for i in range(5)
    ]

    st = time.time()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    elapsed_time = time.time() - st

    # Report a worker failure first: it is more informative than a timing miss.
    if failures:
        raise failures[0]
    assert elapsed_time >= 0.5
if __name__ == '__main__':
    # Propagate pytest's exit code so that running this file as a script
    # reports failure to the calling shell instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__]))