Skip to content

Commit

Permalink
Revert "[reland][inductor] switch AotCodeCompiler to new cpp_builder (#…
Browse files Browse the repository at this point in the history
…130127)"

This reverts commit 9606d61.

Reverted #130127 on behalf of https://github.com/ZainRizvi due to broke internal tests ([comment](#130127 (comment)))
  • Loading branch information
pytorchmergebot committed Jul 30, 2024
1 parent 9027db1 commit 239d4d2
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 76 deletions.
5 changes: 5 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7173,6 +7173,11 @@ def fn(x):

self.common(fn, [torch.randn(64, 64)])

def test_new_cpp_build_logical(self):
from torch._inductor.codecache import validate_new_cpp_commands

validate_new_cpp_commands()

def test_as_strided(self):
def fn(x):
return (
Expand Down
194 changes: 118 additions & 76 deletions torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def get_include_and_linking_paths(
return ipaths, lpaths_str, libs_str, macros, build_arch_flags


def deprecated_cpp_compile_command(
def cpp_compile_command(
input: Union[str, List[str]],
output: str,
warning_all: bool = True,
Expand All @@ -1523,13 +1523,6 @@ def deprecated_cpp_compile_command(
use_mmap_weights: bool = False,
extra_flags: Sequence[str] = (),
) -> str:
"""
Please don't use this function in new development code.
It was planed to remove after we switched to new cpp_builder, but I can't access to Meta
internal environment to fix AotCodeCompiler fb_code.
TODO: need some Meta employee help on fix AotCodeCompiler fb_code, and then delete this
deprecated function.
"""
ipaths, lpaths, libs, macros, build_arch_flags = get_include_and_linking_paths(
include_pytorch, vec_isa, cuda, aot_mode
)
Expand Down Expand Up @@ -1584,20 +1577,6 @@ def deprecated_cpp_compile_command(
).strip()


def _temp_validate_new_and_old_command(new_cmd: List[str], old_cmd: List[str]) -> None:
"""
TODO: Will remove the temp code after switch to new cpp_builder
"""
new_diff: List[str] = [x for x in new_cmd if x not in old_cmd]
old_diff: List[str] = [y for y in old_cmd if y not in new_cmd]
if new_diff or old_diff:
print("!!! new_cmd: ", new_cmd)
print("!!! old_cmd: ", old_cmd)
print("!!! new_diff: ", new_diff)
print("!!! old_diff: ", old_diff)
raise RuntimeError("Error in new and old command different.")


def run_command_and_check(cmd: str) -> None:
cmd = shlex.split(cmd)
try:
Expand Down Expand Up @@ -1844,8 +1823,8 @@ def _compile_consts_darwin(consts: bytes) -> str:
if specified_so_name
else os.path.splitext(input_path)[0] + ".so"
)
output_o = os.path.splitext(input_path)[0] + ".o"

output_o = os.path.splitext(input_path)[0] + ".o"
consts_size = sum(
torch.ops.mkldnn._nbytes(tensor)
if tensor.is_mkldnn
Expand Down Expand Up @@ -1884,29 +1863,8 @@ def _compile_consts_darwin(consts: bytes) -> str:
object_build_options.save_flags_to_file(compile_flags)

else:
(
object_output_name,
object_output_dir,
) = get_name_and_dir_from_output_file_path(input_path)
object_build_options = CppTorchCudaOptions(
vec_isa=picked_vec_isa,
cuda=cuda,
aot_mode=graph.aot_mode,
compile_only=True,
use_absolute_path=use_absolute_path,
use_mmap_weights=use_mmap_weights,
)
object_builder = CppBuilder(
name=object_output_name,
sources=input_path,
output_dir=object_output_dir,
BuildOption=object_build_options,
)
compile_cmd = object_builder.get_command_line()
output_o = object_builder.get_target_file_path()

# TODO: replace this with using the CppBuilder above
compile_cmd_old = deprecated_cpp_compile_command(
compile_cmd = cpp_compile_command(
input=input_path,
output=output_o,
vec_isa=picked_vec_isa,
Expand All @@ -1916,13 +1874,6 @@ def _compile_consts_darwin(consts: bytes) -> str:
use_absolute_path=use_absolute_path,
use_mmap_weights=use_mmap_weights,
)
# TODO: Enable below code to debug in fb_code.
"""
_temp_validate_new_and_old_command(
compile_cmd.split(" "), compile_cmd_old.split(" ")
)
"""
compile_cmd = compile_cmd_old

log.debug("aot compilation command: %s", compile_cmd)
if fbcode_aot_cpu_re:
Expand Down Expand Up @@ -2020,38 +1971,15 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
archive_path = package_aoti(os.path.split(input_path)[0])
return archive_path

output_name, output_dir = get_name_and_dir_from_output_file_path(output_so)
so_build_options = CppTorchCudaOptions(
vec_isa=picked_vec_isa,
cuda=cuda,
aot_mode=graph.aot_mode,
use_absolute_path=use_absolute_path,
)
so_builder = CppBuilder(
name=output_name,
sources=[output_o, consts_o],
output_dir=output_dir,
BuildOption=so_build_options,
)
link_cmd = so_builder.get_command_line()
output_so = so_builder.get_target_file_path()

# TODO: replace this with using the CppBuilder above
link_cmd_old = deprecated_cpp_compile_command(
link_cmd = cpp_compile_command(
input=[output_o, consts_o],
output=output_so,
vec_isa=picked_vec_isa,
cuda=cuda,
aot_mode=graph.aot_mode,
use_absolute_path=use_absolute_path,
)
# TODO: Enable below code to debug in fb_code.
"""
_temp_validate_new_and_old_command(
link_cmd.split(" "), link_cmd_old.split(" ")
)
"""
link_cmd = link_cmd_old

log.debug("aot linkage command: %s", link_cmd)
if fbcode_aot_cpu_re:
Expand Down Expand Up @@ -2572,6 +2500,120 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
)


# TODO: Will remove the temp code after switch to new cpp_builder
def _temp_validate_new_and_old_command(new_cmd: List[str], old_cmd: List[str]) -> None:
new_diff: List[str] = [x for x in new_cmd if x not in old_cmd]
old_diff: List[str] = [y for y in old_cmd if y not in new_cmd]

if new_diff or old_diff:
print("!!! new_cmd: ", new_cmd)
print("!!! old_cmd: ", old_cmd)
print("!!! new_diff: ", new_diff)
print("!!! old_diff: ", old_diff)
raise RuntimeError("Error in new and old command different.")


def _do_validate_cpp_commands(
include_pytorch: bool,
cuda: bool,
compile_only: bool,
mmap_weights: bool,
use_absolute_path: bool,
aot_mode: bool,
) -> None:
# PreCI will failed if test machine can't run cuda.
temp_dir = tempfile.TemporaryDirectory()
test_dir_path = temp_dir.name
test_cuda = torch.cuda.is_available() and cuda
input_path = os.path.join(test_dir_path, "dummy_file.cpp")
output_path = os.path.join(test_dir_path, "dummy_file.so")
extra_flags = ["-D TEST_EXTRA_FLAGS"]
if compile_only:
output_path = os.path.join(test_dir_path, "dummy_file.o")
picked_isa = pick_vec_isa()

# Simulate fb_code env:
if not (aot_mode and not use_absolute_path):
input_path = os.path.basename(input_path)
output_path = os.path.basename(output_path)

# Fix test_new_cpp_build_logical failed on MacOS
if sys.platform != "linux":
aot_mode = False

old_cmd = cpp_compile_command(
input=input_path,
output=output_path,
include_pytorch=include_pytorch,
vec_isa=picked_isa,
cuda=test_cuda,
aot_mode=aot_mode,
compile_only=compile_only,
use_absolute_path=use_absolute_path,
use_mmap_weights=mmap_weights,
extra_flags=extra_flags,
).split(" ")

name, dir = get_name_and_dir_from_output_file_path(input_path)

dummy_build_option = CppTorchCudaOptions(
vec_isa=picked_isa,
include_pytorch=include_pytorch,
cuda=test_cuda,
aot_mode=aot_mode,
compile_only=compile_only,
use_absolute_path=use_absolute_path,
use_mmap_weights=mmap_weights,
extra_flags=extra_flags,
)

dummy_builder = CppBuilder(
name=name,
sources=input_path,
output_dir=dir,
BuildOption=dummy_build_option,
)
new_cmd = dummy_builder.get_command_line().split(" ")

_temp_validate_new_and_old_command(new_cmd, old_cmd)

temp_dir.cleanup()


# TODO: Will remove the temp code after switch to new cpp_builder
# It could help on sync new cpp_builder generate same command line as the old one.
def validate_new_cpp_commands() -> None:
cuda = [True, False]
use_mmap_weights = [True, False]
compile_only = [True, False]
include_pytorch = [True, False]
use_absolute_path = [True, False]
aot_mode = [False, True]

# Try to pass it in fb_code.
if config.is_fbcode():
return

for x in cuda:
for y in use_mmap_weights:
for z in compile_only:
for m in include_pytorch:
for n in use_absolute_path:
for o in aot_mode:
print(
f"!!! cuda:{x}, use_mmap_weights:{y}, compile_only:{z}, include_pytorch:{m},"
f" use_absolute_path:{n}, aot_mode:{o}"
)
_do_validate_cpp_commands(
include_pytorch=m,
cuda=x,
mmap_weights=y,
compile_only=z,
use_absolute_path=n,
aot_mode=o,
)


@clear_on_fresh_inductor_cache
class HalideCodeCache(CppPythonBindingsCodeCache):
cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
Expand Down

0 comments on commit 239d4d2

Please sign in to comment.