Skip to content

Commit

Permalink
Add flop counter to elementwise for opencl/cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
adityapb committed Nov 4, 2020
1 parent bc3f1a6 commit 129d687
Show file tree
Hide file tree
Showing 9 changed files with 541 additions and 55 deletions.
14 changes: 14 additions & 0 deletions compyle/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self):
self._use_double = None
self._omp_schedule = None
self._profile = None
self._count_flops = None
self._use_local_memory = None
self._wgs = None
self._suppress_warnings = None
Expand Down Expand Up @@ -129,6 +130,19 @@ def profile(self, value):
def _profile_default(self):
return False

@property
def count_flops(self):
if self._count_flops is None:
self._count_flops = self._count_flops_default()
return self._count_flops

@count_flops.setter
def count_flops(self, value):
self._count_flops = value

def _count_flops_default(self):
return False

@property
def use_local_memory(self):
if self._use_local_memory is None:
Expand Down
26 changes: 15 additions & 11 deletions compyle/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
dtype_to_knowntype, annotate)
from .extern import Extern
from .utils import getsourcelines
from .profile import profile
from .profile import record_flops, profile

from . import array
from . import parallel
Expand Down Expand Up @@ -265,13 +265,9 @@ def visit_UnaryOp(self, node):
return self.visit(node.operand)

def visit_Return(self, node):
if isinstance(node.value, ast.Name) or \
isinstance(node.value, ast.Subscript) or \
isinstance(node.value, ast.Num) or \
isinstance(node.value, ast.BinOp) or \
isinstance(node.value, ast.Call) or \
isinstance(node.value, ast.IfExp) or \
isinstance(node.value, ast.UnaryOp):
valid_return_expr = (ast.Name, ast.Subscript, ast.Num, ast.BinOp,
ast.Call, ast.IfExp, ast.UnaryOp)
if isinstance(node.value, valid_return_expr):
result_type = self.visit(node.value)
if result_type:
self.arg_types['return_'] = result_type
Expand All @@ -287,11 +283,12 @@ def visit_Return(self, node):
class ElementwiseJIT(parallel.ElementwiseBase):
def __init__(self, func, backend=None):
backend = array.get_backend(backend)
self.tp = Transpiler(backend=backend)
self._config = get_config()
self.tp = Transpiler(backend=backend,
count_flops=self._config.count_flops)
self.backend = backend
self.name = 'elwise_%s' % func.__name__
self.func = func
self._config = get_config()
self.cython_gen = CythonGenerator()
self.source = '# Code jitted, call the function to generate the code.'
self.all_source = self.source
Expand Down Expand Up @@ -333,6 +330,10 @@ def _massage_arg(self, x):
def __call__(self, *args, **kw):
c_func = self._generate_kernel(*args)
c_args = [self._massage_arg(x) for x in args]
if self._config.count_flops:
flop_counter = array.zeros(args[0].length, np.int64,
backend=self.backend)
c_args.append(flop_counter.dev)

if self.backend == 'cython':
size = len(c_args[0])
Expand All @@ -347,6 +348,9 @@ def __call__(self, *args, **kw):
c_func(*c_args, **kw)
event.record()
event.synchronize()
if self._config.count_flops:
flops = array.sum(flop_counter)
record_flops(self.name, flops)


class ReductionJIT(parallel.ReductionBase):
Expand Down Expand Up @@ -523,7 +527,7 @@ def __call__(self, **kwargs):
c_args_dict = {k: self._massage_arg(x) for k, x in kwargs.items()}
if self._get_backend_key() in self.output_func.arg_keys:
output_arg_keys = self.output_func.arg_keys[
self._get_backend_key()]
self._get_backend_key()]
else:
raise ValueError("No kernel arguments found for backend = %s, "
"use_openmp = %s, use_double = %s" %
Expand Down
40 changes: 35 additions & 5 deletions compyle/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np

from .config import get_config
from .profile import profile
from .profile import record_flops, profile
from .cython_generator import get_parallel_range, CythonGenerator
from .transpiler import Transpiler, convert_to_float_if_needed
from .types import dtype_to_ctype
Expand Down Expand Up @@ -404,11 +404,12 @@ def get_common_cache_key(obj):
class ElementwiseBase(object):
def __init__(self, func, backend=None):
backend = array.get_backend(backend)
self.tp = Transpiler(backend=backend)
self._config = get_config()
self.tp = Transpiler(backend=backend,
count_flops=self._config.count_flops)
self.backend = backend
self.name = 'elwise_%s' % func.__name__
self.func = func
self._config = get_config()
self.cython_gen = CythonGenerator()
self.queue = None
# This is the source generated for the user code.
Expand Down Expand Up @@ -453,11 +454,17 @@ def _generate(self, declarations=None):
ctx = get_context()
self.queue = get_queue()
name = self.func.__name__
call_args = ', '.join(c_data[1])
if self._config.count_flops:
call_args += ', cpy_flop_counter'
expr = '{func}({args})'.format(
func=name,
args=', '.join(c_data[1])
args=call_args
)
arguments = convert_to_float_if_needed(', '.join(c_data[0][1:]))
if self._config.count_flops:
arguments += ', long* cpy_flop_counter'

preamble = convert_to_float_if_needed(self.tp.get_code())
cluda_preamble = Template(text=CLUDA_PREAMBLE).render(
double_support=True
Expand All @@ -483,11 +490,17 @@ def _generate(self, declarations=None):
from pycuda.elementwise import ElementwiseKernel
from pycuda._cluda import CLUDA_PREAMBLE
name = self.func.__name__
call_args = ', '.join(c_data[1])
if self._config.count_flops:
call_args += ', cpy_flop_counter'
expr = '{func}({args})'.format(
func=name,
args=', '.join(c_data[1])
args=call_args
)
arguments = convert_to_float_if_needed(', '.join(c_data[0][1:]))
if self._config.count_flops:
arguments += ', long* cpy_flop_counter'

preamble = convert_to_float_if_needed(self.tp.get_code())
cluda_preamble = Template(text=CLUDA_PREAMBLE).render(
double_support=True
Expand Down Expand Up @@ -519,6 +532,8 @@ def _add_address_space(arg):
return arg

args = [_add_address_space(arg) for arg in c_data[0]]
if self._config.count_flops:
args.append('GLOBAL_MEM long* cpy_flop_counter')
code[:header_idx] = wrap(
'WITHIN_KERNEL void {func}({args})'.format(
func=self.func.__name__,
Expand All @@ -527,6 +542,14 @@ def _add_address_space(arg):
width=78, subsequent_indent=' ' * 4, break_long_words=False
)
self.tp.blocks[-1].code = '\n'.join(code)
if self._config.count_flops:
for idx, block in enumerate(self.tp.blocks[:-1]):
self.tp.blocks[idx].code = block.code.replace(
'${offset}', '0'
)
self.tp.blocks[-1].code = self.tp.blocks[-1].code.replace(
'${offset}', 'i'
)

def _massage_arg(self, x):
if isinstance(x, array.Array):
Expand All @@ -539,6 +562,10 @@ def _massage_arg(self, x):
@profile
def __call__(self, *args, **kw):
c_args = [self._massage_arg(x) for x in args]
if self._config.count_flops:
flop_counter = array.zeros(args[0].length, np.int64,
backend=self.backend)
c_args.append(flop_counter.dev)
if self.backend == 'cython':
size = len(c_args[0])
c_args.insert(0, size)
Expand All @@ -552,6 +579,9 @@ def __call__(self, *args, **kw):
self.c_func(*c_args, **kw)
event.record()
event.synchronize()
if self._config.count_flops:
flops = array.sum(flop_counter)
record_flops(self.name, flops)


class Elementwise(object):
Expand Down
41 changes: 41 additions & 0 deletions compyle/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


_profile_info = defaultdict(lambda: {'calls': 0, 'time': 0})
_flops_info = defaultdict(lambda: {'calls': 0, 'flops': 0})


def _record_profile(name, time):
Expand All @@ -16,6 +17,12 @@ def _record_profile(name, time):
_profile_info[name]['calls'] += 1


def record_flops(name, flops):
global _flops_info
_flops_info[name]['flops'] += flops
_flops_info[name]['calls'] += 1


@contextmanager
def profile_ctx(name):
""" Context manager for profiling
Expand Down Expand Up @@ -54,6 +61,21 @@ def get_profile_info():
return _profile_info


def get_flops_info():
global _flops_info
return _flops_info


def reset_profile_info():
global _profile_info
_profile_info = defaultdict(lambda: {'calls': 0, 'time': 0})


def reset_flops_info():
global _flops_info
_flops_info = defaultdict(lambda: {'calls': 0, 'flops': 0})


def print_profile():
global _profile_info
profile_data = sorted(_profile_info.items(), key=lambda x: x[1]['time'],
Expand All @@ -73,6 +95,25 @@ def print_profile():
print("Total profiled time: %g secs" % tot_time)


def print_flops_info():
global _flops_info
flops_data = sorted(_flops_info.items(), key=lambda x: x[1]['flops'],
reverse=True)
if len(_flops_info) == 0:
print("No flops information available")
return
print("FLOPS info:")
print("{:<40} {:<10} {:<10}".format('Function', 'N calls', 'FLOPS'))
tot_flops = 0
for kernel, data in flops_data:
print("{:<40} {:<10} {:<10}".format(
kernel,
data['calls'],
data['flops']))
tot_flops += data['flops']
print("Total FLOPS: %i" % tot_flops)


def profile_kernel(kernel, name, backend=None):
"""For profiling raw PyCUDA/PyOpenCL kernels or cython functions
"""
Expand Down
Loading

0 comments on commit 129d687

Please sign in to comment.