#!/usr/bin/env python
# Ymake MatrixNet support

import sys
import os
import shutil
import re
import subprocess


def get_value(val):
    """Return everything after the first '=' in ``val`` ('' when there is no '=').

    Used to unpack ``key=value`` command-line arguments such as ``fml_tool=<path>``.
    """
    # partition() yields ('', '', '') style tails: the third element is already ''
    # when no separator is present, matching the "no value" case.
    return val.partition('=')[2]


class BuildMnBase(object):
    """Embed one MatrixNet model file into a generated C++ source.

    Subclasses must set ``self.archiver`` before calling :meth:`Run`, and
    ``self.fml_unused_tool`` / ``self.SrcRoot`` before :meth:`CheckMn` is
    reached (i.e. when ``Run`` is called with ``check=True``).
    """

    def Run(self, mninfo, mnname, mnrankingSuffix, mncppPath, check=False, ptr=False, multi=False):
        """Archive ``mninfo`` into a ``.rodata`` file and write ``mncppPath``.

        The generated source declares the archived bytes (``MN_External_<mnname>``)
        and their size (``MN_External_<mnname>Size``) as ``extern "C"`` symbols and
        defines a global named ``mnname`` constructed from them.

        :param mninfo: path to the model file to embed.
        :param mnname: C++ identifier of the generated global.
        :param mnrankingSuffix: stored as ``self.mnrankingSuffix``; not otherwise used here.
        :param mncppPath: output ``.cpp`` path; the ``.rodata`` file is placed in the
            same directory.
        :param check: run :meth:`CheckMn` on the model first.
        :param ptr: declare the global with a ``...Ptr`` type initialised from a
            ``new``-allocated model instead of a by-value object.
        :param multi: use ``NMatrixnet::TMnMultiCateg`` instead of ``NMatrixnet::TMnSseInfo``.
        :raises subprocess.CalledProcessError: if the archiver exits with a non-zero status.
        :raises RuntimeError: from :meth:`CheckMn` when ``check`` is true and the check fails.
        """
        self.mninfo = mninfo
        self.mnname = mnname
        self.mnrankingSuffix = mnrankingSuffix
        self.mncppPath = mncppPath
        self.check = check
        self.ptr = ptr
        self.multi = multi
        dataprefix = "MN_External_"
        mninfoName = os.path.basename(self.mninfo)
        data = dataprefix + mnname
        datasize = data + "Size"

        if self.multi:
            if self.ptr:
                mntype = "const NMatrixnet::TMnMultiCategPtr"
                mnload = "(new NMatrixnet::TMnMultiCateg( {1}, {2}, \"{0}\"))".format(mninfoName, data, datasize)
            else:
                mntype = "const NMatrixnet::TMnMultiCateg"
                mnload = "({1}, {2}, \"{0}\")".format(mninfoName, data, datasize)
        else:
            if self.ptr:
                mntype = "const NMatrixnet::TMnSsePtr"
                mnload = "(new NMatrixnet::TMnSseInfo({1}, {2}, \"{0}\"))".format(mninfoName, data, datasize)
            else:
                mntype = "const NMatrixnet::TMnSseInfo"
                mnload = "({1}, {2}, \"{0}\")".format(mninfoName, data, datasize)

        if self.check:
            self.CheckMn()

        # os.path.join instead of "dirname + '/'": for a bare file name dirname() is ""
        # and the concatenation would put the .rodata at the filesystem root.
        rodatapath = os.path.join(os.path.dirname(self.mncppPath), dataprefix + self.mnname + ".rodata")
        # Archive before writing the source, so a failure leaves no .tmp file behind.
        self._RunArchiver(rodatapath)

        mncpptmpPath = self.mncppPath + ".tmp"
        with open(mncpptmpPath, 'w') as mncpptmp:
            if self.multi:
                mncpptmp.write("#include <kernel/matrixnet/mn_multi_categ.h>\n")
            else:
                mncpptmp.write("#include <kernel/matrixnet/mn_sse.h>\n")

            mncpptmp.write("namespace{\n")
            mncpptmp.write("    extern \"C\" {\n")
            mncpptmp.write("        extern const unsigned char {1}{0}[];\n".format(self.mnname, dataprefix))
            mncpptmp.write("        extern const ui32 {1}{0}Size;\n".format(self.mnname, dataprefix))
            mncpptmp.write("    }\n")
            mncpptmp.write("}\n")
            mncpptmp.write("extern {0} {1};\n".format(mntype, self.mnname))
            mncpptmp.write("{0} {1}{2};".format(mntype, self.mnname, mnload))
        # Write to a temporary file and then move it, so the output never appears half-written.
        shutil.move(mncpptmpPath, self.mncppPath)

    def _RunArchiver(self, rodatapath):
        """Run ``self.archiver`` to pack ``self.mninfo`` into ``rodatapath``.

        The archiver's stderr is captured. It is discarded on success and forwarded
        on failure.

        :raises subprocess.CalledProcessError: if the archiver exits with a non-zero status.
        """
        cmd = [self.archiver, "-q", "-p", "-o", rodatapath, self.mninfo]
        # communicate() drains the pipe; wait() with stderr=PIPE can deadlock once it fills.
        proc = subprocess.Popen(cmd, stdout=None, stderr=subprocess.PIPE)
        _, err = proc.communicate()
        if proc.returncode != 0:
            if err:
                # err is raw bytes on Python 3 (str on Python 2): write to the byte stream if any.
                sys.stderr.flush()
                getattr(sys.stderr, "buffer", sys.stderr).write(err)
            raise subprocess.CalledProcessError(proc.returncode, cmd)

    def CheckMn(self):
        """Run ``fml_unused_tool -A <SrcRoot> -e -r <mninfo>`` and fail if it fails.

        :raises RuntimeError: if no tool is configured or the tool exits with a
            non-zero status. An explicit raise is used rather than ``assert``, which
            is stripped under ``python -O``.
        """
        tool = getattr(self, "fml_unused_tool", "")
        if not tool:
            # Calling subprocess with an empty executable path would only raise an opaque OSError.
            raise RuntimeError("fml_unused_tool undefined!")
        failed_msg = "fml_unused_tool failed: {0} -A {1} -e -r {2}".format(tool, self.SrcRoot, self.mninfo)
        if subprocess.call([tool, "-A", self.SrcRoot, "-e", "-r", self.mninfo]) != 0:
            raise RuntimeError(failed_msg)


class BuildMn(BuildMnBase):
    """Command-line front end that embeds a single MatrixNet model."""

    def Run(self, argv):
        """Parse ``argv`` and generate the C++ source for one model.

        ``argv`` layout: ``<ARCADIA_ROOT> <archiver> <mninfo> <mnname>
        <mnrankingSuffix> <cppOutput>`` followed by any of ``CHECK``, ``PTR``,
        ``MULTI`` and ``fml_tool=<path>``. Unrecognised trailing parameters are
        reported on stderr and ignored. Exits with status 1 when fewer than six
        positional arguments are given.
        """
        if len(argv) < 6:
            sys.stderr.write("BuildMn.Run(<ARCADIA_ROOT> <archiver> <mninfo> <mnname> <mnrankingSuffix> <cppOutput> [params...])\n")
            sys.exit(1)

        self.SrcRoot, self.archiver, mninfo, mnname, mnrankingSuffix, mncppPath = argv[:6]
        self.fml_unused_tool = ''
        flags = set()
        for param in argv[6:]:
            if param in ("CHECK", "PTR", "MULTI"):
                flags.add(param)
            elif param.startswith('fml_tool='):
                self.fml_unused_tool = get_value(param)
            else:
                # A warning, like the usage message above, belongs on stderr.
                sys.stderr.write("Unknown param: {0}\n".format(param))
        super(BuildMn, self).Run(mninfo, mnname, mnrankingSuffix, mncppPath,
                                 check="CHECK" in flags, ptr="PTR" in flags, multi="MULTI" in flags)


class BuildMns(BuildMnBase):
    """Generate sources for a list of MatrixNet models and a registry of them.

    ``.info`` models become ``const NMatrixnet::TMnSsePtr`` globals and ``.mnmc``
    models become ``const NMatrixnet::TMnMultiCategPtr`` globals. The generated
    header/source define two ``TMap``s keyed by each model's file name without
    its extension. Files with any other extension are skipped.
    """

    # Replaces every character that is not valid in a C++ identifier. '-' must be
    # included: in "[^-a-zA-Z0-9_]" a leading '-' is a literal, so hyphens were
    # kept and produced invalid names such as "staticMnfoo-barPtr".
    _NON_IDENT_RE = re.compile(r"[^a-zA-Z0-9_]")

    @staticmethod
    def _Usage(message):
        """Print ``message`` to stderr and exit with status 1."""
        sys.stderr.write(message + "\n")
        sys.exit(1)

    @staticmethod
    def _SplitFlags(argv):
        """Split ``argv`` into ``(check, fml_unused_tool, positional_args)``.

        ``CHECK`` and ``fml_tool=<path>`` may appear anywhere. All other
        arguments are returned, in order, as positional.
        """
        check = False
        fml_tool = ''
        positional = []
        for arg in argv:
            if arg == "CHECK":
                check = True
            elif arg.startswith('fml_tool='):
                fml_tool = get_value(arg)
            else:
                positional.append(arg)
        return check, fml_tool, positional

    @staticmethod
    def _SplitMnFile(path):
        """Return ``(name, ext)``: the base name of ``path`` without extension, and the extension."""
        return os.path.splitext(os.path.basename(path))

    def _MnSymbol(self, name, ext):
        """Return the C++ global for model ``name`` with extension ``ext``, or None if unsupported."""
        ident = self._NON_IDENT_RE.sub("_", name)
        if ext == ".info":
            return "staticMn{0}{1}Ptr".format(self.mnrankingSuffix, ident)
        if ext == ".mnmc":
            return "staticMnMulti{0}{1}Ptr".format(self.mnrankingSuffix, ident)
        return None

    def InitBase(self, listname, mnrankingSuffix):
        """Compute the registry names and declarations shared by the header and the source.

        :param listname: base name of both registries.
        :param mnrankingSuffix: appended to ``listname`` and embedded in model global names.
        """
        self.autogen = '// DO NOT EDIT THIS FILE DIRECTLY, AUTOGENERATED!\n'
        self.mnrankingSuffix = mnrankingSuffix
        self.mnlistname = listname + mnrankingSuffix
        self.mnlistelem = "const NMatrixnet::TMnSsePtr*"
        mnlisttype = "TMap< TString, {0} >".format(self.mnlistelem)
        self.mnlist = "const {0} {1}".format(mnlisttype, self.mnlistname)

        self.mnmultilistname = "{0}{1}Multi".format(listname, self.mnrankingSuffix)
        self.mnmultilistelem = "const NMatrixnet::TMnMultiCategPtr*"
        mnmultilisttype = "TMap< TString, {0} >".format(self.mnmultilistelem)
        self.mnmultilist = "const {0} {1}".format(mnmultilisttype, self.mnmultilistname)

    def InitForAll(self, argv):
        """Configure for generating the header, the registry source and every model source.

        ``argv``: ``<ARCADIA_ROOT> <BINDIR> <archiver> <listname> <key=suffix> <hdrfile>
        <srcfile> <mninfos...>``, with optional ``CHECK`` / ``fml_tool=<path>`` anywhere.
        """
        usage = "BuildMns.InitForAll(<ARCADIA_ROOT> <BINDIR> <archiver>  <listname> <mnranking_suffix> <hdrfile> <srcfile> <mninfos> [fml_tool=<fml_unused_tool> CHECK])"
        if len(argv) < 8:
            self._Usage(usage)

        self.check, self.fml_unused_tool, bmns_args = self._SplitFlags(argv)
        # Flags count towards len(argv) above; re-check the positional arguments so a
        # short list reports usage instead of failing with an IndexError.
        if len(bmns_args) < 7:
            self._Usage(usage)

        self.SrcRoot = bmns_args[0]
        self.BINDIR = bmns_args[1]
        self.archiver = bmns_args[2]
        self.listname = bmns_args[3]
        self.mnrankingSuffix = get_value(bmns_args[4])
        self.hdrfile = bmns_args[5]
        self.srcfile = bmns_args[6]
        self.mninfos = bmns_args[7:]

        self.InitBase(self.listname, self.mnrankingSuffix)

    def InitForHeader(self, argv):
        """Configure for generating only the header: ``<listname> <key=suffix> <hdrfile> <mninfos...>``."""
        if len(argv) < 4:
            self._Usage("BuildMns.InitForHeader(<listname> <rankingSuffix> <hdrfile> <mninfos...>)")

        self.listname = argv[0]
        self.mnrankingSuffix = get_value(argv[1])
        self.hdrfile = argv[2]
        self.mninfos = argv[3:]

        self.InitBase(self.listname, self.mnrankingSuffix)

    def InitForCpp(self, argv):
        """Configure for generating only the registry source: ``<listname> <key=suffix> <hdrfile> <srcfile> <mninfos...>``."""
        if len(argv) < 5:
            self._Usage("BuildMns.InitForCpp(<listname> <rankingSuffix> <hdrfile> <srcfile> <mninfos...>)")

        self.listname = argv[0]
        self.mnrankingSuffix = get_value(argv[1])
        self.hdrfile = argv[2]
        self.srcfile = argv[3]
        self.mninfos = argv[4:]

        self.InitBase(self.listname, self.mnrankingSuffix)

    def InitForFiles(self, argv):
        """Configure for generating only the per-model sources.

        ``argv``: ``<ARCADIA_ROOT> <BINDIR> <archiver> <listname> <key=suffix> <mninfos...>``,
        with optional ``CHECK`` / ``fml_tool=<path>`` anywhere.
        """
        usage = "BuildMns.InitForFiles(<ARCADIA_ROOT> <BINDIR> <archiver> <listname> <rankingSuffix> <mninfos...> [fml_tool=<fml_unused_tool>] [CHECK])"
        if len(argv) < 7:
            self._Usage(usage)

        self.check, self.fml_unused_tool, bmns_args = self._SplitFlags(argv)
        # See InitForAll: flags are counted by the check above, positional args are not.
        if len(bmns_args) < 5:
            self._Usage(usage)

        self.SrcRoot = bmns_args[0]
        self.BINDIR = bmns_args[1]
        self.archiver = bmns_args[2]
        self.listname = bmns_args[3]
        self.mnrankingSuffix = get_value(bmns_args[4])
        self.mninfos = bmns_args[5:]

    def BuildMnsHeader(self):
        """Write ``self.hdrfile`` declaring both registries and every model global.

        ``self.mninfos`` is sorted and de-duplicated in place first.
        """
        if self.mninfos:
            self.mninfos = sorted(set(self.mninfos))

        tmpHdrPath = self.hdrfile + ".tmp"
        with open(tmpHdrPath, 'w') as tmpHdrFile:
            tmpHdrFile.write(self.autogen)
            tmpHdrFile.write("#include <kernel/matrixnet/mn_sse.h>\n")
            tmpHdrFile.write("#include <kernel/matrixnet/mn_multi_categ.h>\n\n")
            tmpHdrFile.write("extern {0};\n".format(self.mnlist))
            tmpHdrFile.write("extern {0};\n".format(self.mnmultilist))

            for item in self.mninfos:
                name, ext = self._SplitMnFile(item)
                mnname = self._MnSymbol(name, ext)
                if ext == ".info":
                    tmpHdrFile.write("extern const NMatrixnet::TMnSsePtr {0};\n".format(mnname))
                elif ext == ".mnmc":
                    tmpHdrFile.write("extern const NMatrixnet::TMnMultiCategPtr {0};\n".format(mnname))
        shutil.move(tmpHdrPath, self.hdrfile)

    def BuildMnFiles(self):
        """Generate ``mn.<global>.cpp`` / ``mnmulti.<global>.cpp`` in ``self.BINDIR`` for every model."""
        for item in self.mninfos:
            name, ext = self._SplitMnFile(item)
            mnname = self._MnSymbol(name, ext)
            if ext == ".info":
                super(BuildMns, self).Run(item, mnname, self.mnrankingSuffix, self.BINDIR + "/mn.{0}.cpp".format(mnname), check=self.check, ptr=True, multi=False)
            elif ext == ".mnmc":
                # BUILD_MN_PTR_MULTI; multi-category models are never passed through CheckMn.
                super(BuildMns, self).Run(item, mnname, self.mnrankingSuffix, self.BINDIR + "/mnmulti.{0}.cpp".format(mnname), check=False, ptr=True, multi=True)

    def _WriteRegistry(self, out, names, ext, listname, elemtype, listdecl, tail):
        """Write the definition of one registry ``TMap`` mapping model names to their globals.

        :param names: model names (file names without extension) of type ``ext``.
        :param tail: text written after the final ``;``.
        """
        if not names:
            out.write("{0};{1}".format(listdecl, tail))
            return
        data = listname + "_data"
        out.write("static const std::pair< TString, {0} > {1}[] = {{\n".format(elemtype, data))
        for name in names:
            out.write("    std::make_pair(TString(\"{0}\"), &{1}),\n".format(name, self._MnSymbol(name, ext)))
        out.write("};\n")
        out.write("{0}({1},{1} + sizeof({1}) / sizeof({1}[0]));{2}".format(listdecl, data, tail))

    def BuildMnsCpp(self):
        """Write ``self.srcfile`` defining both registries.

        ``self.mninfos`` is sorted and de-duplicated in place first.
        """
        if self.mninfos:
            self.mninfos = sorted(set(self.mninfos))

        mnnames = []
        mnmultinames = []
        for item in self.mninfos:
            name, ext = self._SplitMnFile(item)
            if ext == ".info":
                mnnames.append(name)
            elif ext == ".mnmc":
                mnmultinames.append(name)

        tmpSrcPath = self.srcfile + ".tmp"
        with open(tmpSrcPath, 'w') as tmpSrcFile:
            tmpSrcFile.write(self.autogen)
            tmpSrcFile.write("#include \"{0}\"\n\n".format(os.path.basename(self.hdrfile)))
            self._WriteRegistry(tmpSrcFile, mnnames, ".info", self.mnlistname,
                                self.mnlistelem, self.mnlist, "\n\n")
            self._WriteRegistry(tmpSrcFile, mnmultinames, ".mnmc", self.mnmultilistname,
                                self.mnmultilistelem, self.mnmultilist, "\n")
        shutil.move(tmpSrcPath, self.srcfile)


def BuildMnsAllF(argv):
    """Entry point: generate the registry source, the header and every per-model source."""
    builder = BuildMns()
    builder.InitForAll(argv)
    # Order matters: the first two steps sort/de-duplicate builder.mninfos in place.
    for step in (builder.BuildMnsCpp, builder.BuildMnsHeader, builder.BuildMnFiles):
        step()


def BuildMnsCppF(argv):
    """Entry point: generate only the registry source file."""
    builder = BuildMns()
    builder.InitForCpp(argv)
    builder.BuildMnsCpp()


def BuildMnsHeaderF(argv):
    """Entry point: generate only the registry header file."""
    builder = BuildMns()
    builder.InitForHeader(argv)
    builder.BuildMnsHeader()


def BuildMnsFilesF(argv):
    """Entry point: generate only the per-model C++ sources."""
    builder = BuildMns()
    builder.InitForFiles(argv)
    builder.BuildMnFiles()


def BuildMnF(argv):
    """Entry point: embed a single model; ``argv`` is parsed by ``BuildMn.Run``."""
    BuildMn().Run(argv)


if __name__ == '__main__':
    # Only these functions are command-line entry points; each takes the argument list.
    # A whitelist replaces globals()[name], which would call any module-level name.
    _ENTRY_POINTS = {
        'BuildMnsAllF': BuildMnsAllF,
        'BuildMnsCppF': BuildMnsCppF,
        'BuildMnsHeaderF': BuildMnsHeaderF,
        'BuildMnsFilesF': BuildMnsFilesF,
        'BuildMnF': BuildMnF,
    }

    if len(sys.argv) < 2 or sys.argv[1] not in _ENTRY_POINTS:
        sys.stderr.write("Usage: build_mn.py <funcName> <args...>\n")
        sys.stderr.write("funcName is one of: {0}\n".format(", ".join(sorted(_ENTRY_POINTS))))
        sys.exit(1)

    # Always pass the (possibly empty) argument list. Every entry point requires it,
    # and with too few arguments each one prints its own usage message.
    _ENTRY_POINTS[sys.argv[1]](sys.argv[2:])