From 1f28f2d7281b932f88c79e006ae6c2a463de9faa Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Sat, 30 Jul 2022 02:17:35 -0300 Subject: [PATCH] Numba target extension --- rbc/externals/__init__.py | 14 +- rbc/heavydb/__init__.py | 1 + rbc/heavydb/buffer.py | 1 + rbc/heavydb/heavydb_compiler.py | 352 +++++++++++++++++++++++++ rbc/heavydb/mathimpl.py | 70 ++++- rbc/heavydb/remoteheavydb.py | 4 +- rbc/irtools.py | 319 ++++++++-------------- rbc/remotejit.py | 1 + rbc/tests/heavydb/test_column_basic.py | 22 ++ rbc/tests/heavydb/test_heavydb.py | 8 +- rbc/tests/heavydb/test_math.py | 2 +- rbc/tests/test_externals_libdevice.py | 9 +- utils/client_ssh_tunnel.conf | 2 +- 13 files changed, 573 insertions(+), 232 deletions(-) create mode 100644 rbc/heavydb/heavydb_compiler.py diff --git a/rbc/externals/__init__.py b/rbc/externals/__init__.py index 13a8edad..98f126b8 100644 --- a/rbc/externals/__init__.py +++ b/rbc/externals/__init__.py @@ -4,11 +4,15 @@ def gen_codegen(fn_name): - def codegen(context, builder, sig, args): - # Need to retrieve the function name again - fndesc = funcdesc.ExternalFunctionDescriptor(fn_name, sig.return_type, sig.args) - func = context.declare_external_function(builder.module, fndesc) - return builder.call(func, args) + if fn_name.startswith('llvm.'): + def codegen(context, builder, sig, args): + func = builder.module.declare_intrinsic(fn_name, [a.type for a in args]) + return builder.call(func, args) + else: + def codegen(context, builder, sig, args): + fndesc = funcdesc.ExternalFunctionDescriptor(fn_name, sig.return_type, sig.args) + func = context.declare_external_function(builder.module, fndesc) + return builder.call(func, args) return codegen diff --git a/rbc/heavydb/__init__.py b/rbc/heavydb/__init__.py index 01cdb3b8..ebfcb300 100644 --- a/rbc/heavydb/__init__.py +++ b/rbc/heavydb/__init__.py @@ -10,6 +10,7 @@ from .text_encoding_none import * # noqa: F401, F403 from .timestamp import * # noqa: F401, F403 from .remoteheavydb import * # noqa: F401, F403 +from .heavydb_compiler import * # noqa: F401, F403 from . import mathimpl as math # noqa: F401 from . import npyimpl as np # noqa: F401 diff --git a/rbc/heavydb/buffer.py b/rbc/heavydb/buffer.py index b563c198..a9b4041c 100644 --- a/rbc/heavydb/buffer.py +++ b/rbc/heavydb/buffer.py @@ -553,6 +553,7 @@ def codegen(context, builder, signature, args): nv = ir.Constant(ir.IntType(T.bitwidth), null_value) if isinstance(T, types.Float): nv = builder.bitcast(nv, ty) + intrinsic(builder, (data, index, nv,)) return sig, codegen diff --git a/rbc/heavydb/heavydb_compiler.py b/rbc/heavydb/heavydb_compiler.py new file mode 100644 index 00000000..3611d3ef --- /dev/null +++ b/rbc/heavydb/heavydb_compiler.py @@ -0,0 +1,352 @@ +from contextlib import contextmanager +import llvmlite.binding as llvm +from rbc.targetinfo import TargetInfo +from numba.np import ufunc_db +from numba import _dynfunc +from numba.core import ( + codegen, compiler_lock, typing, + base, cpu, utils, descriptors, + dispatcher, callconv, imputils, + options,) +from numba.core.target_extension import ( + Generic, + target_registry, + dispatcher_registry, +) + + +class HeavyDB_CPU(Generic): + """Mark the target as HeavyDB CPU + """ + + +class HeavyDB_GPU(Generic): + """Mark the target as HeavyDB GPU + """ + + +target_registry['heavydb_cpu'] = HeavyDB_CPU +target_registry['heavydb_gpu'] = HeavyDB_GPU + +heavydb_cpu_registry = imputils.Registry(name='heavydb_cpu_registry') +heavydb_gpu_registry = imputils.Registry(name='heavydb_gpu_registry') + + +class _NestedContext(object): + _typing_context = None + _target_context = None + + @contextmanager + def nested(self, typing_context, target_context): + old_nested = self._typing_context, self._target_context + try: + self._typing_context = typing_context + self._target_context = target_context + yield + finally: + self._typing_context, self._target_context = old_nested + + +_options_mixin = options.include_default_options( + "no_rewrites", + "no_cpython_wrapper", + "no_cfunc_wrapper", + "fastmath", + "inline", + "boundscheck", + "nopython", + # Add "target_backend" as a accepted option for the CPU in @jit(...) + "target_backend", +) + + +class HeavyDBTargetOptions(_options_mixin, options.TargetOptions): + def finalize(self, flags, options): + flags.enable_pyobject = False + flags.enable_looplift = False + flags.nrt = False + flags.debuginfo = False + flags.boundscheck = False + flags.enable_pyobject_looplift = False + flags.no_rewrites = True + flags.auto_parallel = cpu.ParallelOptions(False) + flags.inherit_if_not_set("fastmath") + flags.inherit_if_not_set("error_model", default="python") + # Add "target_backend" as a option that inherits from the caller + flags.inherit_if_not_set("target_backend") + + +class HeavyDBTarget(descriptors.TargetDescriptor): + options = HeavyDBTargetOptions + _nested = _NestedContext() + + @utils.cached_property + def _toplevel_target_context(self): + # Lazily-initialized top-level target context, for all threads + return JITRemoteTargetContext(self.typing_context, self._target_name) + + @utils.cached_property + def _toplevel_typing_context(self): + # Lazily-initialized top-level typing context, for all threads + return JITRemoteTypingContext() + + @property + def target_context(self): + """ + The target context for CPU/GPU targets. + """ + nested = self._nested._target_context + if nested is not None: + return nested + else: + return self._toplevel_target_context + + @property + def typing_context(self): + """ + The typing context for CPU targets. + """ + nested = self._nested._typing_context + if nested is not None: + return nested + else: + return self._toplevel_typing_context + + def nested_context(self, typing_context, target_context): + """ + A context manager temporarily replacing the contexts with the + given ones, for the current thread of execution. + """ + return self._nested.nested(typing_context, target_context) + + +# Create a target instance +heavydb_cpu_target = HeavyDBTarget("heavydb_cpu") +heavydb_gpu_target = HeavyDBTarget("heavydb_gpu") + + +# Declare a dispatcher for the CPU/GPU targets +class HeavyDBCPUDispatcher(dispatcher.Dispatcher): + targetdescr = heavydb_cpu_target + + +class HeavyDBGPUDispatcher(dispatcher.Dispatcher): + targetdescr = heavydb_gpu_target + + +# Register a dispatcher for the target, a lot of the code uses this +# internally to work out what to do RE compilation +dispatcher_registry[target_registry["heavydb_cpu"]] = HeavyDBCPUDispatcher +dispatcher_registry[target_registry["heavydb_gpu"]] = HeavyDBGPUDispatcher + + +class JITRemoteCodeLibrary(codegen.JITCodeLibrary): + """JITRemoteCodeLibrary was introduce to prevent numba from calling functions + that checks if the module is final. See xnd-project/rbc issue #87. + """ + + def get_pointer_to_function(self, name): + """We can return any random number here! This is just to prevent numba from + trying to check if the symbol given by "name" is defined in the module. + In cases were RBC is calling an external function (i.e. allocate_varlen_buffer) + the symbol will not be defined in the module, resulting in an error. + """ + return 0 + + def _finalize_specific(self): + """Same as codegen.JITCodeLibrary._finalize_specific but without + calling _ensure_finalize at the end + """ + self._codegen._scan_and_fix_unresolved_refs(self._final_module) + + +class JITRemoteCodegen(codegen.JITCPUCodegen): + _library_class = JITRemoteCodeLibrary + + def _get_host_cpu_name(self): + target_info = TargetInfo() + return target_info.device_name + + def _get_host_cpu_features(self): + target_info = TargetInfo() + features = target_info.device_features + server_llvm_version = target_info.llvm_version + if server_llvm_version is None or target_info.is_gpu: + return '' + client_llvm_version = llvm.llvm_version_info + + # See https://github.com/xnd-project/rbc/issues/45 + remove_features = { + (11, 8): ['tsxldtrk', 'amx-tile', 'amx-bf16', 'serialize', 'amx-int8', + 'avx512vp2intersect', 'tsxldtrk', 'amx-tile', 'amx-bf16', + 'serialize', 'amx-int8', 'avx512vp2intersect', 'tsxldtrk', + 'amx-tile', 'amx-bf16', 'serialize', 'amx-int8', + 'avx512vp2intersect', 'cx8', 'enqcmd', 'avx512bf16'], + (11, 10): ['tsxldtrk', 'amx-tile', 'amx-bf16', 'serialize', 'amx-int8'], + (9, 8): ['cx8', 'enqcmd', 'avx512bf16'], + }.get((server_llvm_version[0], client_llvm_version[0]), []) + for f in remove_features: + features = features.replace('+' + f, '').replace('-' + f, '') + return features + + def _customize_tm_options(self, options): + super()._customize_tm_options(options) + # fix reloc_model as the base method sets it using local target + target_info = TargetInfo() + if target_info.arch.startswith('x86'): + reloc_model = 'static' + else: + reloc_model = 'default' + options['reloc'] = reloc_model + + def set_env(self, env_name, env): + return None + + +class JITRemoteTypingContext(typing.Context): + """JITRemote Typing Context + """ + + def load_additional_registries(self): + from . import mathimpl + self.install_registry(mathimpl.registry) + return super().load_additional_registries() + + +class JITRemoteTargetContext(base.BaseContext): + # Whether dynamic globals (CPU runtime addresses) is allowed + allow_dynamic_globals = True + + def __init__(self, typing_context, target): + if target not in ('heavydb_cpu', 'heavydb_gpu'): + raise ValueError(f'Target "{target}" not supported') + super().__init__(typing_context, target) + + @compiler_lock.global_compiler_lock + def init(self): + target_info = TargetInfo() + self.address_size = target_info.bits + self.is32bit = (self.address_size == 32) + self._internal_codegen = JITRemoteCodegen("numba.exec") + self._target_data = llvm.create_target_data(target_info.datalayout) + + def refresh(self): + if self.target_name == 'heavydb_cpu': + registry = heavydb_cpu_registry + else: + registry = heavydb_gpu_registry + + try: + loader = self._registries[registry] + except KeyError: + loader = imputils.RegistryLoader(registry) + self._registries[registry] = loader + + self.install_registry(registry) + # Also refresh typing context, since @overload declarations can + # affect it. + self.typing_context.refresh() + super().refresh() + + def load_additional_registries(self): + # Add implementations that work via import + from numba.cpython import (builtins, charseq, enumimpl, hashing, heapq, # noqa: F401 + iterators, listobj, numbers, rangeobj, + setobj, slicing, tupleobj, unicode,) + + self.install_registry(imputils.builtin_registry) + + # uncomment as needed! + # from numba.core import optional + from numba.np import linalg, polynomial, arraymath, arrayobj # noqa: F401 + # from numba.typed import typeddict, dictimpl + # from numba.typed import typedlist, listobject + # from numba.experimental import jitclass, function_type + # from numba.np import npdatetime + + # Add target specific implementations + from numba.np import npyimpl + from numba.cpython import mathimpl + # from numba.cpython import cmathimpl, mathimpl, printimpl, randomimpl + # from numba.misc import cffiimpl + # from numba.experimental.jitclass.base import ClassBuilder as \ + # jitclassimpl + # self.install_registry(cmathimpl.registry) + # self.install_registry(cffiimpl.registry) + self.install_registry(mathimpl.registry) + self.install_registry(npyimpl.registry) + # self.install_registry(printimpl.registry) + # self.install_registry(randomimpl.registry) + # self.install_registry(jitclassimpl.class_impl_registry) + + def codegen(self): + return self._internal_codegen + + @utils.cached_property + def call_conv(self): + return callconv.CPUCallConv(self) + + @property + def target_data(self): + return self._target_data + + def create_cpython_wrapper(self, + library, + fndesc, + env, + call_helper, + release_gil=False): + # There's no cpython wrapper on HeavyDB + pass + + def create_cfunc_wrapper(self, + library, + fndesc, + env, + call_helper, + release_gil=False): + # There's no cfunc wrapper on HeavyDB + pass + + def get_executable(self, library, fndesc, env): + """ + Returns + ------- + (cfunc, fnptr) + + - cfunc + callable function (Can be None) + - fnptr + callable function address + - env + an execution environment (from _dynfunc) + """ + # although we don't use this function, it seems to be required + # by some parts of codegen in Numba. + + # Code generation + fnptr = library.get_pointer_to_function( + fndesc.llvm_cpython_wrapper_name + ) + + # Note: we avoid reusing the original docstring to avoid encoding + # issues on Python 2, see issue #1908 + doc = "compiled wrapper for %r" % (fndesc.qualname,) + cfunc = _dynfunc.make_function( + fndesc.lookup_module(), + fndesc.qualname.split(".")[-1], + doc, + fnptr, + env, + # objects to keepalive with the function + (library,), + ) + library.codegen.set_env(self.get_env_name(fndesc), env) + return cfunc + + def post_lowering(self, mod, library): + pass + + # Overrides + def get_ufunc_info(self, ufunc_key): + return ufunc_db.get_ufunc_info(ufunc_key) diff --git a/rbc/heavydb/mathimpl.py b/rbc/heavydb/mathimpl.py index b508e8b5..3155179f 100644 --- a/rbc/heavydb/mathimpl.py +++ b/rbc/heavydb/mathimpl.py @@ -1,9 +1,17 @@ import math -from rbc.externals import gen_codegen, dispatch_codegen -from numba.core.typing.templates import infer_global -from numba.core.imputils import lower_builtin -from numba.core.typing.templates import ConcreteTemplate, signature +from rbc.externals import gen_codegen +from numba.core.typing.templates import ConcreteTemplate, signature, Registry from numba.types import float32, float64, int32, int64, uint64, intp +from numba.core.intrinsics import INTR_TO_CMATH +from .heavydb_compiler import heavydb_cpu_registry, heavydb_gpu_registry + + +lower_cpu = heavydb_cpu_registry.lower +lower_gpu = heavydb_gpu_registry.lower + + +registry = Registry() +infer_global = registry.register_global # Adding missing cases in Numba @@ -75,22 +83,31 @@ class Math_converter(ConcreteTemplate): binarys = [] binarys += [("copysign", "copysignf", math.copysign)] binarys += [("atan2", "atan2f", math.atan2)] -binarys += [("pow", "powf", math.pow)] binarys += [("fmod", "fmodf", math.fmod)] binarys += [("hypot", "hypotf", math.hypot)] binarys += [("remainder", "remainderf", math.remainder)] def impl_unary(fname, key, typ): - cpu = gen_codegen(fname) + if fname in INTR_TO_CMATH.values(): + # use llvm intrinsics when possible + cpu = gen_codegen(f'llvm.{fname}') + else: + cpu = gen_codegen(fname) gpu = gen_codegen(f"__nv_{fname}") - lower_builtin(key, typ)(dispatch_codegen(cpu, gpu)) + lower_cpu(key, typ)(cpu) + lower_gpu(key, typ)(gpu) def impl_binary(fname, key, typ): - cpu = gen_codegen(fname) + if fname in INTR_TO_CMATH.values(): + # use llvm intrinsics when possible + cpu = gen_codegen(f'llvm.{fname}') + else: + cpu = gen_codegen(fname) gpu = gen_codegen(f"__nv_{fname}") - lower_builtin(key, typ, typ)(dispatch_codegen(cpu, gpu)) + lower_cpu(key, typ, typ)(cpu) + lower_gpu(key, typ, typ)(gpu) for fname64, fname32, key in unarys: @@ -105,17 +122,42 @@ def impl_binary(fname, key, typ): # manual mapping def impl_ldexp(): + # cpu ldexp_cpu = gen_codegen('ldexp') - ldexp_gpu = gen_codegen('__nv_ldexp') - ldexpf_cpu = gen_codegen('ldexpf') - ldexpf_gpu = gen_codegen('__nv_ldexpf') + lower_cpu(math.ldexp, float64, int32)(ldexp_cpu) + lower_cpu(math.ldexp, float32, int32)(ldexpf_cpu) - lower_builtin(math.ldexp, float64, int32)(dispatch_codegen(ldexp_cpu, ldexp_gpu)) - lower_builtin(math.ldexp, float32, int32)(dispatch_codegen(ldexpf_cpu, ldexpf_gpu)) + # gpu + ldexp_gpu = gen_codegen('__nv_ldexp') + ldexpf_gpu = gen_codegen('__nv_ldexpf') + lower_gpu(math.ldexp, float64, int32)(ldexp_gpu) + lower_gpu(math.ldexp, float32, int32)(ldexpf_gpu) + + +def impl_pow(): + # cpu + pow_cpu = gen_codegen('pow') + powf_cpu = gen_codegen('powf') + lower_cpu(math.pow, float64, float64)(pow_cpu) + lower_cpu(math.pow, float32, float32)(powf_cpu) + lower_cpu(math.pow, float64, int32)(pow_cpu) + lower_cpu(math.pow, float32, int32)(powf_cpu) + + # gpu + pow_gpu = gen_codegen('__nv_pow') + powf_gpu = gen_codegen('__nv_powf') + powi_gpu = gen_codegen('__nv_powi') + powif_gpu = gen_codegen('__nv_powif') + lower_gpu(math.pow, float64, float64)(pow_gpu) + lower_gpu(math.pow, float32, float32)(powf_gpu) + lower_gpu(math.pow, float64, int32)(powi_gpu) + lower_gpu(math.pow, float32, int32)(powif_gpu) impl_ldexp() +impl_pow() + # CPU only: # math.gcd diff --git a/rbc/heavydb/remoteheavydb.py b/rbc/heavydb/remoteheavydb.py index c75c004f..f62eb2af 100644 --- a/rbc/heavydb/remoteheavydb.py +++ b/rbc/heavydb/remoteheavydb.py @@ -260,7 +260,7 @@ def is_sizer(t): def get_sizer_enum(t): - """Return sizer enum value as defined by the omniscidb server. + """Return sizer enum value as defined by the HeavyDB server. """ sizer = t.annotation()['sizer'] sizer = output_buffer_sizer_map.get(sizer or None, sizer) @@ -1508,5 +1508,5 @@ def remote_call(self, func, ftype: typesystem.Type, arguments: tuple, hold=False class RemoteOmnisci(RemoteHeavyDB): - """Omnisci - the previous brand of HeavyAI + """HeavyDB - the previous brand of HeavyAI """ diff --git a/rbc/irtools.py b/rbc/irtools.py index 371843e6..90f20152 100644 --- a/rbc/irtools.py +++ b/rbc/irtools.py @@ -3,7 +3,6 @@ import re import warnings -from contextlib import contextmanager from collections import defaultdict from llvmlite import ir import llvmlite.binding as llvm @@ -11,12 +10,33 @@ from .errors import UnsupportedError from . import libfuncs from rbc.externals import stdio -from numba.core import codegen, cpu, compiler_lock, \ +from numba.core import cpu, \ registry, typing, compiler, sigutils, cgutils, \ - extending, imputils + extending, target_extension, retarget, dispatcher +from numba import njit from numba.core import errors as nb_errors +class Retarget(retarget.BasicRetarget): + + def __init__(self, target_name): + self.target_name = target_name + super().__init__() + + @property + def output_target(self): + return self.target_name + + def compile_retarget(self, cpu_disp): + kernel = njit(_target=self.target_name)(cpu_disp.py_func) + return kernel + + +def switch_target(target_name): + tc = dispatcher.TargetConfigurationStack + return tc.switch_target(Retarget(target_name)) + + int32_t = ir.IntType(32) int1_t = ir.IntType(1) @@ -64,132 +84,6 @@ def get_called_functions(library, funcname=None): # --------------------------------------------------------------------------- -class JITRemoteCodeLibrary(codegen.JITCodeLibrary): - """JITRemoteCodeLibrary was introduce to prevent numba from calling functions - that checks if the module is final. See xnd-project/rbc issue #87. - """ - - def get_pointer_to_function(self, name): - """We can return any random number here! This is just to prevent numba from - trying to check if the symbol given by "name" is defined in the module. - In cases were RBC is calling an external function (i.e. allocate_varlen_buffer) - the symbol will not be defined in the module, resulting in an error. - """ - return 0 - - def _finalize_specific(self): - """Same as codegen.JITCodeLibrary._finalize_specific but without - calling _ensure_finalize at the end - """ - self._codegen._scan_and_fix_unresolved_refs(self._final_module) - - -class JITRemoteCodegen(codegen.JITCPUCodegen): - _library_class = JITRemoteCodeLibrary - - def _get_host_cpu_name(self): - target_info = TargetInfo() - return target_info.device_name - - def _get_host_cpu_features(self): - target_info = TargetInfo() - features = target_info.device_features - server_llvm_version = target_info.llvm_version - if server_llvm_version is None or target_info.is_gpu: - return '' - client_llvm_version = llvm.llvm_version_info - - # See https://github.com/xnd-project/rbc/issues/45 - remove_features = { - (12, 12): [], (11, 11): [], (10, 10): [], (9, 9): [], (8, 8): [], - (11, 8): ['tsxldtrk', 'amx-tile', 'amx-bf16', 'serialize', 'amx-int8', - 'avx512vp2intersect', 'tsxldtrk', 'amx-tile', 'amx-bf16', - 'serialize', 'amx-int8', 'avx512vp2intersect', 'tsxldtrk', - 'amx-tile', 'amx-bf16', 'serialize', 'amx-int8', - 'avx512vp2intersect', 'cx8', 'enqcmd', 'avx512bf16'], - (11, 10): ['tsxldtrk', 'amx-tile', 'amx-bf16', 'serialize', 'amx-int8'], - (9, 11): ['sse2', 'cx16', 'sahf', 'tbm', 'avx512ifma', 'sha', - 'gfni', 'fma4', 'vpclmulqdq', 'prfchw', 'bmi2', 'cldemote', - 'fsgsbase', 'ptwrite', 'xsavec', 'popcnt', 'mpx', - 'avx512bitalg', 'movdiri', 'xsaves', 'avx512er', - 'avx512vnni', 'avx512vpopcntdq', 'pconfig', 'clwb', - 'avx512f', 'clzero', 'pku', 'mmx', 'lwp', 'rdpid', 'xop', - 'rdseed', 'waitpkg', 'movdir64b', 'sse4a', 'avx512bw', - 'clflushopt', 'xsave', 'avx512vbmi2', '64bit', 'avx512vl', - 'invpcid', 'avx512cd', 'avx', 'vaes', 'cx8', 'fma', 'rtm', - 'bmi', 'enqcmd', 'rdrnd', 'mwaitx', 'sse4.1', 'sse4.2', 'avx2', - 'fxsr', 'wbnoinvd', 'sse', 'lzcnt', 'pclmul', 'prefetchwt1', - 'f16c', 'ssse3', 'sgx', 'shstk', 'cmov', 'avx512vbmi', - 'avx512bf16', 'movbe', 'xsaveopt', 'avx512dq', 'adx', - 'avx512pf', 'sse3'], - (9, 8): ['cx8', 'enqcmd', 'avx512bf16'], - }.get((server_llvm_version[0], client_llvm_version[0]), None) - if remove_features is None: - warnings.warn( - f'{type(self).__name__}._get_host_cpu_features: `remove_features` dictionary' - ' requires an update: detected different LLVM versions in server ' - f'{server_llvm_version} and client {client_llvm_version}.' - f' CPU features: {features}.') - else: - features += ',' - for f in remove_features: - features = features.replace('+' + f + ',', '').replace('-' + f + ',', '') - features.rstrip(',') - return features - - def _customize_tm_options(self, options): - super()._customize_tm_options(options) - # fix reloc_model as the base method sets it using local target - target_info = TargetInfo() - if target_info.arch.startswith('x86'): - reloc_model = 'static' - else: - reloc_model = 'default' - options['reloc'] = reloc_model - - def set_env(self, env_name, env): - return None - - -class JITRemoteTypingContext(typing.Context): - def load_additional_registries(self): - self.install_registry(typing.templates.builtin_registry) - super().load_additional_registries() - - -class JITRemoteTargetContext(cpu.CPUContext): - - @compiler_lock.global_compiler_lock - def init(self): - target_info = TargetInfo() - self.address_size = target_info.bits - self.is32bit = (self.address_size == 32) - self._internal_codegen = JITRemoteCodegen("numba.exec") - - def load_additional_registries(self): - self.install_registry(imputils.builtin_registry) - super().load_additional_registries() - - def get_executable(self, library, fndesc, env): - return None - - def post_lowering(self, mod, library): - pass - - -# --------------------------------------------------------------------------- -# Code generation methods - - -@contextmanager -def replace_numba_internals_hack(): - # Hackish solution to prevent numba from calling _ensure_finalize. See issue #87 - _internal_codegen_bkp = registry.cpu_target.target_context._internal_codegen - registry.cpu_target.target_context._internal_codegen = JITRemoteCodegen("numba.exec") - yield - registry.cpu_target.target_context._internal_codegen = _internal_codegen_bkp - - def make_wrapper(fname, atypes, rtype, cres, target: TargetInfo, verbose=False): """Make wrapper function to numba compile result. @@ -268,7 +162,7 @@ def make_wrapper(fname, atypes, rtype, cres, target: TargetInfo, verbose=False): def compile_instance(func, sig, - target: TargetInfo, + target_info: TargetInfo, typing_context, target_context, pipeline_class, @@ -309,7 +203,7 @@ def compile_instance(func, sig, result = get_called_functions(cres.library, cres.fndesc.llvm_func_name) for f in result['declarations']: - if target.supports(f): + if target_info.supports(f): continue warnings.warn(f'Skipping {fname} that uses undefined function `{f}`') return @@ -317,18 +211,18 @@ def compile_instance(func, sig, nvvmlib = libfuncs.Library.get('nvvm') llvmlib = libfuncs.Library.get('llvm') for f in result['intrinsics']: - if target.is_gpu: + if target_info.is_gpu: if f in nvvmlib: continue - if target.is_cpu: + if target_info.is_cpu: if f in llvmlib: continue warnings.warn(f'Skipping {fname} that uses unsupported intrinsic `{f}`') return - make_wrapper(fname, args, return_type, cres, target, verbose=debug) + make_wrapper(fname, args, return_type, cres, target_info, verbose=debug) main_module = main_library._final_module for lib in result['libraries']: @@ -373,83 +267,100 @@ def compile_to_LLVM(functions_and_signatures, LLVM module instance. To get the IR string, use `str(module)`. """ - target_desc = registry.cpu_target - - typing_context = JITRemoteTypingContext() - target_context = JITRemoteTargetContext(typing_context) + # avoid circula import error + # * remotejit imports irtools + # * irtools import heavydb + # * heavydb import remotejit + from rbc.heavydb import JITRemoteTypingContext, JITRemoteTargetContext, \ + heavydb_cpu_target, heavydb_gpu_target + + device = target_info.name + software = target_info.software[0] + + if software == 'HeavyDB': + target_name = f'heavydb_{device}' + target_desc = heavydb_cpu_target if device == 'cpu' else heavydb_gpu_target + typing_context = JITRemoteTypingContext() + target_context = JITRemoteTargetContext(typing_context, target_name) + else: + target_name = 'cpu' + target_desc = registry.cpu_target + typing_context = typing.Context() + target_context = cpu.CPUContext(typing_context, target_name) # Bring over Array overloads (a hack): target_context._defns = target_desc.target_context._defns - with replace_numba_internals_hack(): - codegen = target_context.codegen() - main_library = codegen.create_library('rbc.irtools.compile_to_IR') - main_module = main_library._final_module - - if user_defined_llvm_ir is not None: - if isinstance(user_defined_llvm_ir, str): - user_defined_llvm_ir = llvm.parse_assembly(user_defined_llvm_ir) - assert isinstance(user_defined_llvm_ir, llvm.ModuleRef) - main_module.link_in(user_defined_llvm_ir, preserve=True) - - succesful_fids = [] - function_names = [] - for func, signatures in functions_and_signatures: - for fid, sig in signatures.items(): - fname = compile_instance(func, sig, target_info, typing_context, - target_context, pipeline_class, - main_library, - debug=debug) - if fname is not None: - succesful_fids.append(fid) - function_names.append(fname) - - add_metadata_flag(main_library, - pass_column_arguments_by_value=0, - manage_memory_buffer=1) - main_library._optimize_final_module() + codegen = target_context.codegen() + main_library = codegen.create_library(f'rbc.irtools.compile_to_IR_{software}_{device}') + main_module = main_library._final_module - # Remove unused defined functions and declarations - used_symbols = defaultdict(set) - for fname in function_names: - for k, v in get_called_functions(main_library, fname).items(): - used_symbols[k].update(v) + if user_defined_llvm_ir is not None: + if isinstance(user_defined_llvm_ir, str): + user_defined_llvm_ir = llvm.parse_assembly(user_defined_llvm_ir) + assert isinstance(user_defined_llvm_ir, llvm.ModuleRef) + main_module.link_in(user_defined_llvm_ir, preserve=True) + + succesful_fids = [] + function_names = [] + for func, signatures in functions_and_signatures: + for fid, sig in signatures.items(): + with switch_target(target_name): + with target_extension.target_override(target_name): + fname = compile_instance(func, sig, target_info, typing_context, + target_context, pipeline_class, + main_library, + debug=debug) + if fname is not None: + succesful_fids.append(fid) + function_names.append(fname) + + add_metadata_flag(main_library, + pass_column_arguments_by_value=0, + manage_memory_buffer=1) + main_library._optimize_final_module() + + # Remove unused defined functions and declarations + used_symbols = defaultdict(set) + for fname in function_names: + for k, v in get_called_functions(main_library, fname).items(): + used_symbols[k].update(v) + + all_symbols = get_called_functions(main_library) + + unused_symbols = defaultdict(set) + for k, lst in all_symbols.items(): + if k == 'libraries': + continue + for fn in lst: + if fn not in used_symbols[k]: + unused_symbols[k].add(fn) + + changed = False + for f in main_module.functions: + fn = f.name + if fn.startswith('llvm.'): + if f.name in unused_symbols['intrinsics']: + f.linkage = llvm.Linkage.external + changed = True + elif f.is_declaration: + if f.name in unused_symbols['declarations']: + f.linkage = llvm.Linkage.external + changed = True + else: + if f.name in unused_symbols['defined']: + f.linkage = llvm.Linkage.private + changed = True - all_symbols = get_called_functions(main_library) + # TODO: determine unused global_variables and struct_types - unused_symbols = defaultdict(set) - for k, lst in all_symbols.items(): - if k == 'libraries': - continue - for fn in lst: - if fn not in used_symbols[k]: - unused_symbols[k].add(fn) - - changed = False - for f in main_module.functions: - fn = f.name - if fn.startswith('llvm.'): - if f.name in unused_symbols['intrinsics']: - f.linkage = llvm.Linkage.external - changed = True - elif f.is_declaration: - if f.name in unused_symbols['declarations']: - f.linkage = llvm.Linkage.external - changed = True - else: - if f.name in unused_symbols['defined']: - f.linkage = llvm.Linkage.private - changed = True - - # TODO: determine unused global_variables and struct_types - - if changed: - main_library._optimize_final_module() - - main_module.verify() - main_library._finalized = True - main_module.triple = target_info.triple - main_module.data_layout = target_info.datalayout + if changed: + main_library._optimize_final_module() + + main_module.verify() + main_library._finalized = True + main_module.triple = target_info.triple + main_module.data_layout = target_info.datalayout return main_module, succesful_fids diff --git a/rbc/remotejit.py b/rbc/remotejit.py index ee76aec5..4bf08816 100644 --- a/rbc/remotejit.py +++ b/rbc/remotejit.py @@ -949,6 +949,7 @@ def targets(self) -> dict: target_info = TargetInfo.host() target_info.set('has_numba', True) target_info.set('has_cpython', True) + target_info.set('software', 'remotejit') return dict(cpu=target_info.tojson()) @dispatchermethod diff --git a/rbc/tests/heavydb/test_column_basic.py b/rbc/tests/heavydb/test_column_basic.py index ddb292ff..299a8a91 100644 --- a/rbc/tests/heavydb/test_column_basic.py +++ b/rbc/tests/heavydb/test_column_basic.py @@ -3,6 +3,8 @@ from collections import defaultdict import pytest import numpy as np +import math +from numba import njit rbc_heavydb = pytest.importorskip('rbc.heavydb') @@ -822,6 +824,26 @@ def convolve(x, kernel, m, y): assert list(result) == expected +def test_issue343(heavydb): + # Before generating llvm code, the irtools entry point needs + # to switch the target context from CPU to GPU, so that functions + # are bind to the correct target. In the case below, math.exp + # is bind to '@llvm.exp.f64' on CPU and '@__nv_exp' on GPU. + if not heavydb.has_cuda: + pytest.skip('test requires heavydb build with GPU support') + + @njit + def bar(x): + return math.exp(x) + + @heavydb('double(double)', devices=['cpu', 'gpu']) + def foo(x): + return math.exp(x) + bar(x) + + assert '__nv_exp' in str(foo) + assert 'llvm.exp.f64' in str(foo) + + def test_column_dtype(heavydb): from numba import types table = heavydb.table_name diff --git a/rbc/tests/heavydb/test_heavydb.py b/rbc/tests/heavydb/test_heavydb.py index a707f335..5eaefc2a 100644 --- a/rbc/tests/heavydb/test_heavydb.py +++ b/rbc/tests/heavydb/test_heavydb.py @@ -362,7 +362,7 @@ def test_casting(heavydb): The following table defines the behavior of applying these UDFs to values with different types: - OmnisciDB version 5.9+ + HeavyDB version 5.9+ ---------------------- | Functions applied to itype | i8 | i16 | i32 | i64 | f32 | f64 | @@ -374,7 +374,7 @@ def test_casting(heavydb): float | FAIL | FAIL | FAIL | FAIL | OK | OK | double | FAIL | FAIL | FAIL | FAIL | FAIL | OK | - OmnisciDB version 5.8 + HeavyDB version 5.8 ---------------------- | Functions applied to itype | i8 | i16 | i32 | i64 | f32 | f64 | @@ -386,7 +386,7 @@ def test_casting(heavydb): float | FAIL | FAIL | FAIL | FAIL | OK | OK | double | FAIL | FAIL | FAIL | FAIL | FAIL | OK | - OmnisciDB version 5.7 and older + HeavyDB version 5.7 and older ------------------------------- | Functions applied to itype | i8 | i16 | i32 | i64 | f32 | f64 | @@ -839,7 +839,7 @@ def test_reconnect(heavydb): def test_non_admin_user(heavydb): - heavydb.require_version((5, 9), 'Requires omniscidb 5.9 or newer') + heavydb.require_version((5, 9), 'Requires HeavyDB 5.9 or newer') user = 'rbc_test_non_admin_user' password = 'Xy2kq_3lM' diff --git a/rbc/tests/heavydb/test_math.py b/rbc/tests/heavydb/test_math.py index 019c1abb..78ea68a3 100644 --- a/rbc/tests/heavydb/test_math.py +++ b/rbc/tests/heavydb/test_math.py @@ -53,7 +53,7 @@ def heavydb(): math_functions = [ # Number-theoretic and representation functions - ('ceil', 'int64(double)'), + ('ceil', 'double(double)'), ('comb', 'int64(int64, int64)'), ('copysign', 'double(double, double)'), ('fabs', 'double(double)'), diff --git a/rbc/tests/test_externals_libdevice.py b/rbc/tests/test_externals_libdevice.py index 223c6047..71b78459 100644 --- a/rbc/tests/test_externals_libdevice.py +++ b/rbc/tests/test_externals_libdevice.py @@ -15,6 +15,9 @@ funcs.append((fname, str(retty), argtys, has_ptr_arg)) +fns = {} + + @pytest.fixture(scope="module") def heavydb(): @@ -50,6 +53,7 @@ def fn(a, b, c): fn.__name__ = f"{heavydb.table_name}_{fname[5:]}" fn = heavydb(f"{retty}({', '.join(argtypes)})", devices=["gpu"])(fn) + fns[fname] = fn for fname, retty, argtys, has_ptr_arg in funcs: if has_ptr_arg: @@ -84,4 +88,7 @@ def test_externals_libdevice(heavydb, fname, retty, argtys, has_ptr_arg): cols = ", ".join(tuple(map(lambda x: cols_dict[x], argtys))) query = f"SELECT {func_name}({cols}) FROM {table}" - _, _ = heavydb.sql_execute(query) + _, result = heavydb.sql_execute(query) + + assert fname in str(fns[fname]) + # to-do: check results diff --git a/utils/client_ssh_tunnel.conf b/utils/client_ssh_tunnel.conf index ab32578b..52224310 100644 --- a/utils/client_ssh_tunnel.conf +++ b/utils/client_ssh_tunnel.conf @@ -6,7 +6,7 @@ # 1. Run omnscidb server with ssh port forwarding:: # # ssh -L 6274:127.0.0.1:16274 -# bin/omnisci_server --enable-runtime-udf --enable-table-functions -p 16274 --http-port 16278 --calcite-port 16279 +# bin/omnisci_server --enable-dev-table-functions --enable-runtime-udf --enable-table-functions -p 16274 --http-port 16278 --calcite-port 16279 # # 2. Relate the omniscidb server to client: #