summaryrefslogtreecommitdiffstats
path: root/build/scripts/fatbinary_wrapper.py
blob: 80c9c7a6db97039ea0b6f3874af267b8c32d4d30 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import sys
import subprocess


def fix(args):
    """Translate wrapper arguments into a fatbinary command line.

    Arguments ending in ``.ptx`` or ``.cubin`` are rewritten to
    ``--image3=kind=<ptx|elf>,sm=<arch>,file=<path>``. Here ``<arch>`` is the
    dot-separated component immediately before the extension; for example,
    ``kernel.80.cubin`` yields ``sm=80``.

    Arguments ending in ``.module_id`` are consumed rather than passed
    through. Their contents are read and must be identical across all such
    files. Every other argument is yielded unchanged.

    Args:
        args: iterable of command-line argument strings.

    Yields:
        The rewritten argument strings, in input order.

    Exits:
        Prints a message to stderr and calls ``sys.exit(1)`` in two cases:
        when two ``.module_id`` files have different contents, or when a
        ``.ptx``/``.cubin`` name has no component before its extension to
        use as the arch.
    """
    prev_module_id = None

    for arg in args:
        if arg.endswith(".module_id"):
            # Close the file promptly instead of leaking the handle.
            with open(arg) as f:
                module_id = f.read()

            if prev_module_id is not None and module_id != prev_module_id:
                print(f".module_id mismatch: {module_id} vs {prev_module_id}", file=sys.stderr)
                sys.exit(1)

            prev_module_id = module_id
            continue

        if arg.endswith(".ptx"):
            kind = "ptx"
        elif arg.endswith(".cubin"):
            kind = "elf"
        else:
            yield arg
            continue

        # Expect "<name>.<arch>.<ext>"; a name with only one dot has no arch.
        parts = arg.rsplit(".", 2)
        if len(parts) != 3:
            print(f"cannot determine arch from image name: {arg}", file=sys.stderr)
            sys.exit(1)
        arch = parts[1]

        yield f"--image3=kind={kind},sm={arch},file={arg}"


def main():
    """Run the command built from this script's arguments and mirror its exit code.

    All arguments after the script name are passed through ``fix()``. The
    resulting list is executed as a command; its first element is the
    program to run. This process then exits with that command's return code.
    """
    forwarded = sys.argv[1:]
    command = [*fix(forwarded)]
    exit_code = subprocess.call(command)
    sys.exit(exit_code)


if __name__ == "__main__":
    # Run the wrapper only when executed as a script, not when imported.
    main()