# cgit view: about | summary | refs | log | tree | commit | diff | stats
# path: root/build/scripts/gen_mx_table.py
# blob: cce69e5cfbce9be5f81c84b573a0df7922af19b5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import sys 
 
# C++ source template for the generated formula table.  It is filled as
# ``tmpl % (registrations, wrappers, table_entries)``:
#   1st %s - statements run inside the TFormulas constructor, each storing a
#            TFml built from a blob of the MxFormulas archive (an extern
#            byte array opened with TArchiveReader);
#   2nd %s - the static per-formula wrapper function definitions;
#   3rd %s - the comma-separated initializers of the yabs_funcs array.
# Because the template is %-formatted, any literal '%' added here must be
# written as '%%'.
# NOTE(review): the 10000 in the yabs_mx_calc_table initializer presumably
# is the length of yabs_funcs; it matches the number of entries the
# generator emits, so keep the two in sync.
# NOTE(review): TFormulas::at() dereferences find() without an end() check;
# this is only safe because wrappers are generated solely for formula
# numbers that the constructor registers.
tmpl = """ 
#include "yabs_mx_calc_table.h" 
 
#include <kernel/matrixnet/mn_sse.h> 
 
#include <library/cpp/archive/yarchive.h>
 
#include <util/memory/blob.h> 
#include <util/generic/hash.h> 
#include <util/generic/ptr.h> 
#include <util/generic/singleton.h> 
 
using namespace NMatrixnet; 
 
extern "C" { 
    extern const unsigned char MxFormulas[]; 
    extern const ui32 MxFormulasSize; 
} 
 
namespace { 
    struct TFml: public TBlob, public TMnSseInfo { 
        inline TFml(const TBlob& b) 
            : TBlob(b) 
            , TMnSseInfo(Data(), Size()) 
        { 
        } 
    }; 
 
    struct TFormulas: public THashMap<size_t, TAutoPtr<TFml>> {
        inline TFormulas() { 
            TBlob b = TBlob::NoCopy(MxFormulas, MxFormulasSize); 
            TArchiveReader ar(b); 
            %s 
        } 
 
        inline const TMnSseInfo& at(size_t n) const noexcept {
            return *find(n)->second; 
        } 
    }; 
 
    %s 
 
    static func_descr_t yabs_funcs[] = { 
        %s 
    }; 
} 
 
yabs_mx_calc_table_t yabs_mx_calc_table = {YABS_MX_CALC_VERSION, 10000, 0, yabs_funcs}; 
""" 
 
if __name__ == '__main__':
    init = []
    body = []
    defs = {}
 
    for i in sys.argv[1:]:
        name = i.replace('.', '_')
        num = long(name.split('_')[1])
 
        init.append('(*this)[%s] = new TFml(ar.ObjectBlobByKey("%s"));' % (num, '/' + i))
 
        f1 = 'static void yabs_%s(size_t count, const float** args, double* res) {Singleton<TFormulas>()->at(%s).DoCalcRelevs(args, res, count);}' % (name, num)
        f2 = 'static size_t yabs_%s_factor_count() {return Singleton<TFormulas>()->at(%s).MaxFactorIndex() + 1;}' % (name, num)
 
        body.append(f1)
        body.append(f2)
 
        d1 = 'yabs_%s' % name
        d2 = 'yabs_%s_factor_count' % name
 
        defs[num] = '{%s, %s}' % (d1, d2)
 
    print tmpl % ('\n'.join(init), '\n\n'.join(body), ',\n'.join((defs.get(i, '{nullptr, nullptr}') for i in range(0, 10000))))