-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathsetup.py
More file actions
95 lines (72 loc) · 2.94 KB
/
setup.py
File metadata and controls
95 lines (72 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0
import os
import sys
# tomllib is only in the standard library from Python 3.11; on older
# interpreters the "tomli" package is imported under the same name so the
# rest of the file can use tomllib unconditionally.
if sys.version_info < (3, 11):
    import tomli as tomllib
else:
    import tomllib
from Cython.Build import cythonize
from setuptools import setup, Extension
def calculate_ext(module_: str, prefix: str = "", pre_module: str = "", source_suffix: str = "") -> Extension:
    """Create a C++ Extension object with a single .pyx source for a given module.

    Both the extension name and the source path are derived from ``module_``.
    ``prefix`` and ``pre_module`` change both of them. ``source_suffix``
    changes only the source path.

    Args:
        module_: The module name in dot notation, e.g. "package.subpackage.module".
        prefix: A string prepended to the last name component only. For example,
            prefix="cy" gives the extension "package.subpackage.cymodule".
        pre_module: A package component inserted immediately before the last
            name component. For example, pre_module="_internal" gives the
            extension "package.subpackage._internal.module".
        source_suffix: A string appended to the source file stem only; the
            extension name is unchanged. For example, source_suffix="_linux"
            gives the source package/subpackage/module_linux.pyx instead of
            package/subpackage/module.pyx (joined with os.path.join).

    Returns:
        A setuptools Extension with language="c++" whose only source is the
        relative .pyx path built from the (possibly modified) name components.
    """
    parts = module_.split(".")
    if pre_module:
        # insert(-1, ...) places the extra component just before the final name.
        parts.insert(-1, pre_module)
    parts[-1] = f"{prefix}{parts[-1]}"
    # Directories come from the leading components; the file stem is the
    # (prefixed) final component plus the optional source suffix.
    pyx = os.path.join(*parts[:-1], f"{parts[-1]}{source_suffix}.pyx")
    return Extension(
        ".".join(parts),
        sources=[pyx],
        language="c++",
    )
def get_ext_modules() -> list[Extension]:
    """Return a list of instantiated C++ Extensions with .pyx sources.

    Module names are read from the [tool.nvmath-bindings] table of
    ``pyproject.toml``, opened by relative path (i.e. from the current
    working directory). The table supplies three lists:

    * ``modules``, plus ``linux_modules`` when building on Linux. These are
      full module names such as "nvmath.bindings.cublas". Each one yields
      three extensions:

      - the module itself;
      - its "cy"-prefixed counterpart (e.g. "nvmath.bindings.cycublas");
      - an "_internal" variant (e.g. "nvmath.bindings._internal.cublas").
        This variant is built from "<name>_linux.pyx" on Linux and from
        "<name>_windows.pyx" on every other platform.

    * ``internal_modules``: full module names that each yield exactly one
      extension.

    Raises:
        KeyError: if ``modules`` or ``internal_modules`` is missing, or if
            ``linux_modules`` is missing when building on Linux.
    """
    with open("pyproject.toml", "rb") as f:
        data = tomllib.load(f)
    # Build settings for this project live under [tool.nvmath-bindings].
    pyproject_data = data.get("tool", {}).get("nvmath-bindings", {})

    # Extension modules in nvmath.bindings for the math libraries.
    # Copy first so extending the list does not mutate the parsed TOML data.
    modules = list(pyproject_data["modules"])
    if sys.platform == "linux":
        modules += pyproject_data["linux_modules"]

    # The platform choice does not change per module, so decide it once.
    # NOTE(review): any non-Linux platform (including macOS) gets "_windows".
    internal_suffix = "_linux" if sys.platform == "linux" else "_windows"

    ext_modules: list[Extension] = []
    for m in modules:
        ext_modules += [
            calculate_ext(m),
            calculate_ext(m, prefix="cy"),
            calculate_ext(m, pre_module="_internal", source_suffix=internal_suffix),
        ]

    # Extension modules in nvmath.internal for ndbuffer (temporary home).
    nvmath_internal_modules = pyproject_data["internal_modules"]
    ext_nvmath_internal_modules = [calculate_ext(m) for m in nvmath_internal_modules]
    return ext_modules + ext_nvmath_internal_modules
# Number of parallel cythonize workers; os.cpu_count() may return None when
# the CPU count cannot be determined.
nthreads = os.cpu_count()
# Cython compiler directives applied to every extension: embed a textual
# copy of each function's call signature in its docstring.
_cython_directives = {"embedsignature": True}
setup(
    ext_modules=cythonize(
        get_ext_modules(),
        nthreads=nthreads,
        language_level=3,
        compiler_directives=_cython_directives,
        verbose=True,
    ),
)