aboutsummaryrefslogblamecommitdiffstats
path: root/library/python/func/ut/test_func.py
blob: 90fe125a118b5b89579acc3ecc0d367f6543e0bc (plain) (tree)
1
2
3
4
5
6
7
8
             
                      
                
           
          
 
                                  

                
                                                   
                                             
                                       
















                                                          
                   

                               
                   

                               
                   

                               
                   

                               
                   

                                       
                   

                            


                               


























                                 








































                                                                       






                                                                                                   
























                                                                              
























                                          




































































































































                                                                    
                           
import pytest
import multiprocessing
import random
import threading
import time
import six

import library.python.func as func


def test_map0():
    """Check ``func.map0`` on ``None`` and non-``None`` inputs.

    The assertions expect ``None`` to be returned untouched (the callable is
    not applied to it) and any other value to be passed through the callable.
    """
    def increment(x):
        return x + 1

    # None in -> None out, for both a user function and a builtin.
    assert func.map0(increment, None) is None
    assert func.map0(len, None) is None

    # Non-None values are transformed by the callable.
    assert func.map0(increment, 2) == 3
    assert func.map0(len, [1, 2]) == 2


def test_single():
    """Check that ``func.single`` returns the only element and rejects other sizes.

    A one-element list must yield that element; an empty list and a
    two-element list must both raise.
    """
    assert 1 == func.single([1])

    # The calls below are deliberately bare. Wrapping them in ``assert`` inside
    # ``pytest.raises(Exception)`` would let a wrong *return value* pass the
    # test: the failed assert raises AssertionError, which is an Exception and
    # therefore satisfies the context manager.
    with pytest.raises(Exception):
        func.single([])
    with pytest.raises(Exception):
        func.single([1, 2])


def test_memoize():
    """Check ``func.memoize`` caching for plain functions and for a method.

    Every memoized function returns the next value of one shared call counter,
    so the expected numbers below encode exactly how many times a wrapped body
    has really run so far. The order of the calls therefore matters: a repeated
    call must return the previously seen counter value (cache hit), and a call
    with new arguments must return the next one (cache miss).
    """
    class Counter(object):
        # Global call counter shared by all memoized functions in this test.
        # The count lives in the class attribute ``_qty`` and starts at 1 on
        # the first call.
        @staticmethod
        def inc():
            Counter._qty = getattr(Counter, '_qty', 0) + 1
            return Counter._qty

    # t1..t4 are identical single-argument functions wrapped separately; the
    # assertions check that each one caches on its own.
    @func.memoize()
    def t1(a):
        return a, Counter.inc()

    @func.memoize()
    def t2(a):
        return a, Counter.inc()

    @func.memoize()
    def t3(a):
        return a, Counter.inc()

    @func.memoize()
    def t4(a):
        return a, Counter.inc()

    # Several positional arguments.
    @func.memoize()
    def t5(a, b, c):
        return a + b + c, Counter.inc()

    # No arguments at all.
    @func.memoize()
    def t6():
        return Counter.inc()

    # Cache restricted with ``limit=2``; see the eviction check below.
    @func.memoize(limit=2)
    def t7(a, _b):
        return a, Counter.inc()

    # For each function: first call is a miss, the repeat is a hit (same
    # counter value), a new argument is a miss again.
    assert (1, 1) == t1(1)
    assert (1, 1) == t1(1)
    assert (2, 2) == t1(2)
    assert (2, 2) == t1(2)

    # Same arguments as t1, but the counter advances: t2 does not reuse t1's
    # cached results.
    assert (1, 3) == t2(1)
    assert (1, 3) == t2(1)
    assert (2, 4) == t2(2)
    assert (2, 4) == t2(2)

    assert (1, 5) == t3(1)
    assert (1, 5) == t3(1)
    assert (2, 6) == t3(2)
    assert (2, 6) == t3(2)

    assert (1, 7) == t4(1)
    assert (1, 7) == t4(1)
    assert (2, 8) == t4(2)
    assert (2, 8) == t4(2)

    # Changing only the last argument is a different cache entry.
    assert (6, 9) == t5(1, 2, 3)
    assert (6, 9) == t5(1, 2, 3)
    assert (7, 10) == t5(1, 2, 4)
    assert (7, 10) == t5(1, 2, 4)

    assert 11 == t6()
    assert 11 == t6()

    # Two distinct argument tuples fit into the limit of 2, so all four calls
    # below cost only two real invocations.
    assert (1, 12) == t7(1, None)
    assert (2, 13) == t7(2, None)
    assert (1, 12) == t7(1, None)
    assert (2, 13) == t7(2, None)
    # A third distinct key exceeds the limit: the test expects the result for
    # (1, None) to be dropped, so asking for it again recomputes it (15).
    assert (3, 14) == t7(3, None)
    assert (1, 15) == t7(1, None)

    class ClassWithMemoizedMethod(object):
        def __init__(self):
            # Running total of every ``i`` the method body has actually seen.
            self.a = 0

        # NOTE(review): positional ``True`` is presumably the flag that lets
        # memoize wrap a method - confirm against func.memoize's signature.
        @func.memoize(True)
        def t(self, i):
            self.a += i
            return i

    obj = ClassWithMemoizedMethod()
    # ``obj.a`` only grows when the body really runs, so an unchanged value
    # after a repeated call shows the result came from the cache.
    assert 10 == obj.t(10)
    assert 10 == obj.a
    assert 10 == obj.t(10)
    assert 10 == obj.a

    assert 20 == obj.t(20)
    assert 30 == obj.a
    assert 20 == obj.t(20)
    assert 30 == obj.a


def test_first():
    """Check that ``func.first`` picks the first truthy item, else ``None``."""
    # Every element before '1' is falsy; the trailing 0 comes after the match.
    candidates = [0, [], (), None, False, {}, 0.0, '1', 0]
    assert '1' == func.first(candidates)

    # Nothing truthy to return: empty input, or input holding only a falsy item.
    assert None is func.first([])
    assert None is func.first([0])


def test_split():
    """Check that ``func.split`` partitions items by a predicate.

    The expected result is a pair of lists: items for which the predicate is
    truthy first, the remaining items second, each keeping the input order.
    """
    def identity(x):
        return x

    cases = (
        ([1, 1], ([1, 1], [])),        # everything matches
        ([0, 0], ([], [0, 0])),        # nothing matches
        ([], ([], [])),                # empty input
        ([1, 0, 1], ([1, 1], [0])),    # mixed input
    )
    for items, expected in cases:
        assert func.split(items, identity) == expected


def test_flatten_dict():
    """Check ``func.flatten_dict`` on flat and nested dictionaries.

    Flat dictionaries are expected back unchanged; nested keys are expected to
    be joined with ``"."`` by default or with the given ``separator``.
    """
    def nested():
        # Build a fresh nested input for every call.
        return {"a": 1, "b": {"c": {"d": 2}}}

    # Dictionaries without nesting come back as they are.
    for flat in ({"a": 1, "b": 2}, {"a": 1}, {}):
        assert func.flatten_dict(dict(flat)) == flat

    # Nested keys are joined into a single path.
    dotted = func.flatten_dict(nested())
    assert dotted == {"a": 1, "b.c.d": 2}

    slashed = func.flatten_dict(nested(), separator="/")
    assert slashed == {"a": 1, "b/c/d": 2}


def test_threadsafe_singleton():
    """Check that ``func.Singleton`` yields one instance under concurrent creation.

    A hundred threads instantiate the same class at roughly the same time; all
    of them must end up holding the very same object.
    """
    class ShouldBeSingle(six.with_metaclass(func.Singleton, object)):
        def __new__(cls, *args, **kwargs):
            # Slow construction widens the window in which several threads
            # could be creating an instance at once.
            time.sleep(0.1)
            return super(ShouldBeSingle, cls).__new__(cls, *args, **kwargs)

    threads_count = 100
    results = [None] * threads_count

    def class_factory(results, i):
        # Sleep first so that the threads reach the constructor close together.
        time.sleep(0.1)
        results[i] = ShouldBeSingle()

    threads = [
        threading.Thread(target=class_factory, args=(results, i))
        for i in range(threads_count)
    ]

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # One distinct object across all slots means every thread got the same instance.
    assert len(set(results)) == 1


def test_memoize_thread_local():
    """Check that ``memoize(thread_local=True)`` caches separately per thread.

    The same sequence of checks runs in the main thread and then in a second
    thread. Within one thread repeated calls must return the same ``Counter``
    (values keep growing); the second thread must start again from fresh
    counters even though the main thread has already incremented its own.
    """
    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_local=True)
    def get_counter(start):
        return Counter(start)

    def th_inc():
        assert get_counter(0).inc() == 1
        assert get_counter(0).inc() == 2
        assert get_counter(10).inc() == 11
        assert get_counter(10).inc() == 12

    # An exception raised inside a thread does not propagate to the code that
    # joins it, so a failed assertion there would otherwise leave the test
    # green. Collect failures and check them in the main thread instead.
    thread_errors = []

    def guarded_th_inc():
        try:
            th_inc()
        except Exception as exc:
            thread_errors.append(exc)

    th_inc()

    th = threading.Thread(target=guarded_th_inc)
    th.start()
    th.join()

    assert not thread_errors, "worker thread failed: %r" % (thread_errors,)


def test_memoize_not_thread_safe():
    """Check that ``memoize(thread_safe=False)`` lets cache misses run in parallel.

    Five workers each use keys no other worker touches (``n`` and ``n*10`` for
    ``n`` in 1..5). Each worker has two cache misses of 0.1 s, so running all
    misses one after another would take about 1 s. The 0.5 s bound can only
    hold if the sleeps of different workers overlap.
    """
    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=False)
    def io_job(n):
        time.sleep(0.1)
        return Counter(n)

    def worker(n):
        # Repeated calls must hand back the same Counter, hence growing values.
        assert io_job(n).inc() == n + 1
        assert io_job(n).inc() == n + 2
        assert io_job(n*10).inc() == n*10 + 1
        assert io_job(n*10).inc() == n*10 + 2
        assert io_job(n).inc() == n + 3

    # An exception raised inside a thread does not propagate to the code that
    # joins it, so a failed assertion in ``worker`` would otherwise leave the
    # test green. Collect failures and check them in the main thread instead.
    worker_errors = []

    def guarded_worker(n):
        try:
            worker(n)
        except Exception as exc:
            worker_errors.append(exc)

    threads = [
        threading.Thread(target=guarded_worker, args=(i + 1,))
        for i in range(5)
    ]

    st = time.time()

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    elapsed_time = time.time() - st

    assert not worker_errors, "worker thread failed: %r" % (worker_errors,)
    assert elapsed_time < 0.5


def test_memoize_not_thread_safe_concurrent():
    """Check which result wins when two threads miss the same key concurrently.

    Timeline the assertions rely on (``io_job`` sleeps 0.1 s on a miss):

    * th1 starts and misses the cache;
    * th2 starts 0.05 s later, while th1 is still computing, and misses too;
    * th1 finishes first, and its counter is the one later calls must see;
    * th3 starts after th1 is done and is expected to hit the cache;
    * th2 finishes last and is expected to increment th1's counter rather
      than a counter of its own.

    Every thread and every main-thread check increments the counter once, so
    the expected values count the increments made so far on th1's counter.
    """
    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=False)
    def io_job(n):
        time.sleep(0.1)
        return Counter(n)

    def worker():
        io_job(100).inc()

    th1 = threading.Thread(target=worker)
    th2 = threading.Thread(target=worker)
    th3 = threading.Thread(target=worker)

    th1.start()
    time.sleep(0.05)
    th2.start()

    th1.join()
    # th1 incremented once, this call is the second increment
    assert io_job(100).inc() == 100 + 2

    th3.start()
    # Wait for th3 before checking: ``start()`` only guarantees that the thread
    # has begun, not that ``worker`` has already run, so asserting right away
    # would race with th3's increment. Joining also avoids leaving the thread
    # behind when the test returns.
    th3.join()
    # th3 got the counter from memory, so its increment landed on th1's counter
    assert io_job(100).inc() == 100 + 4

    th2.join()
    # th2 should increase th1 counter
    assert io_job(100).inc() == 100 + 6


def test_memoize_not_thread_safe_stress():
    """Stress ``memoize(thread_safe=False)`` with many threads calling at once.

    The wrapped job returns a random 128-bit number, so two real executions
    would almost certainly return different values. The test expects every
    thread to end up with the same value.
    """
    @func.memoize(thread_safe=False)
    def job():
        # The loop only makes the job take some time; the last draw is returned.
        for _ in range(1000):
            value = random.getrandbits(128)
        return value

    num_threads = min(multiprocessing.cpu_count() * 4, 64)
    results = [None] * num_threads

    def worker(n):
        results[n] = job()

    threads = [
        threading.Thread(target=worker, args=(index,))
        for index in range(num_threads)
    ]

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # A single distinct value means all threads observed the same result.
    assert len(set(results)) == 1


def test_memoize_thread_safe():
    """Check that ``memoize(thread_safe=True)`` runs cache misses one at a time.

    Five workers each use keys no other worker touches (``n`` and ``n*10`` for
    ``n`` in 1..5), giving ten cache misses of 0.05 s each. The test expects
    the total time to be at least 0.5 s, i.e. the ten sleeps must not overlap
    even though the workers run in separate threads.
    """
    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=True)
    def io_job(n):
        time.sleep(0.05)
        return Counter(n)

    def worker(n):
        # Repeated calls must hand back the same Counter, hence growing values.
        assert io_job(n).inc() == n + 1
        assert io_job(n).inc() == n + 2
        assert io_job(n*10).inc() == n*10 + 1
        assert io_job(n*10).inc() == n*10 + 2

    # An exception raised inside a thread does not propagate to the code that
    # joins it, so a failed assertion in ``worker`` would otherwise leave the
    # test green. Collect failures and check them in the main thread instead.
    worker_errors = []

    def guarded_worker(n):
        try:
            worker(n)
        except Exception as exc:
            worker_errors.append(exc)

    threads = [
        threading.Thread(target=guarded_worker, args=(i + 1,))
        for i in range(5)
    ]

    st = time.time()

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    elapsed_time = time.time() - st

    assert not worker_errors, "worker thread failed: %r" % (worker_errors,)
    assert elapsed_time >= 0.5


if __name__ == '__main__':
    # pytest.main() returns the exit status instead of exiting the process.
    # Hand it to SystemExit so that running this file directly reports test
    # failures through a non-zero exit code.
    raise SystemExit(pytest.main([__file__]))