diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index bff09d8c1..a5c341711 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -1,4 +1,4 @@ -#include "cuda.h" +#include #include #include #include diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 2b39fea29..3346eb954 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -12,7 +12,8 @@ from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver dirname = os.path.dirname(os.path.realpath(__file__)) -include_dirs = [os.path.join(dirname, "include")] +import shlex +include_dirs = [*shlex.split("@cudaToolkitIncludeDirs@"), os.path.join(dirname, "include")] libdevice_dir = os.path.join(dirname, "lib") libraries = ['cuda'] PyCUtensorMap = None @@ -265,7 +266,7 @@ def make_launcher(constants, signature, tensordesc_meta): params.append("&global_scratch") params.append("&profile_scratch") src = f""" -#include \"cuda.h\" +#include #include #include #include