Commit 752e770

Add eager mode stateful operators (#4016)
Add experimental exposure of eager.rng_state. All operators that depend on a state (excluding readers) are exposed as methods of eager.rng_state. Also adds a function for exposing eager operators as objects, in case we ever want to switch to the ops-like API.

Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
1 parent efdec48 commit 752e770
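For orientation, a minimal usage sketch of the API this commit introduces. The call signatures below are illustrative assumptions rather than lines from this diff; only `eager.rng_state` and method paths such as `rng_state.random.normal` are named in the changes that follow.

    from nvidia.dali.experimental import eager

    # A state object owns the operator cache and the seed generator for all
    # stateful (random) eager operators; the exact constructor arguments are assumed here.
    rng = eager.rng_state(seed=42)

    # Stateful operators are called as methods of the state object, e.g. via the
    # rng_state.random submodule, and operate directly on TensorLists.
    out = rng.random.normal(shape=[100], batch_size=8)

    # A second state created with the same seed replays the same per-operator seeds,
    # provided the order of operator calls is preserved.
    rng2 = eager.rng_state(seed=42)
    out2 = rng2.random.normal(shape=[100], batch_size=8)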

File tree

7 files changed, +519 -82 lines changed


dali/python/nvidia/dali/_debug_mode.py (+3 -2)
@@ -538,6 +538,7 @@ def __init__(self, exec_func, **kwargs):
         self._exec_func = exec_func
         self._cur_operator_id = -1
         self._next_logical_id = 0
+        self._seed_upper_bound = (1 << 31) - 1
         self._operators = {}
         self._operators_built = False
         self._cur_iter_batch_info = _IterBatchInfo(-1, None)  # Used for variable batch sizes.
@@ -549,7 +550,7 @@ def __init__(self, exec_func, **kwargs):
         import numpy as np
         seed = kwargs.get('seed', -1)
         if seed < 0:
-            seed = np.random.randint(0, 2**32)
+            seed = np.random.randint(self._seed_upper_bound)
         self._seed_generator = np.random.default_rng(seed)
 
     def __enter__(self):
@@ -625,7 +626,7 @@ def _create_op(self, op_class, op_name, key, cur_context, inputs, kwargs):
         """Creates direct operator."""
         self._operators[key] = _OperatorManager(
             op_class, op_name, self, cur_context, self._next_logical_id, self._max_batch_size,
-            self._device_id, self._seed_generator.integers(0, 2**32), inputs, kwargs)
+            self._device_id, self._seed_generator.integers(self._seed_upper_bound), inputs, kwargs)
 
         self._pipe.AddMultipleOperators(
             self._operators[key].op_spec, self._operators[key].logical_ids)
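The seed handling above is easy to model in isolation. A standalone sketch (plain NumPy, not DALI code) of the pattern: a parent generator is seeded once, per-operator seeds are then drawn from it deterministically, and the (1 << 31) - 1 bound keeps every generated seed within a signed 32-bit range.

    import numpy as np

    SEED_UPPER_BOUND = (1 << 31) - 1  # largest signed 32-bit value; all seeds stay int32-safe

    def make_seed_generator(seed=-1):
        # Mirrors the pattern above: pick a random starting seed when none is given.
        if seed < 0:
            seed = np.random.randint(SEED_UPPER_BOUND)
        return np.random.default_rng(seed)

    gen_a = make_seed_generator(1234)
    gen_b = make_seed_generator(1234)

    # Per-operator seeds are drawn deterministically from the parent generator, so two
    # runs with the same starting seed assign identical seeds to operators created in
    # the same order.
    seeds_a = [int(gen_a.integers(SEED_UPPER_BOUND)) for _ in range(3)]
    seeds_b = [int(gen_b.integers(SEED_UPPER_BOUND)) for _ in range(3)]
    assert seeds_a == seeds_b and all(0 <= s < SEED_UPPER_BOUND for s in seeds_a)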

dali/python/nvidia/dali/_utils/eager_utils.py (+224 -52)
@@ -212,7 +212,11 @@ def _arithm_op(name, *inputs):
     categories_idxs, inputs, integers, reals = _ops._group_inputs(
         inputs, edge_type=(_tensors.TensorListCPU, _tensors.TensorListGPU))
     input_desc = _ops._generate_input_desc(categories_idxs, integers, reals)
-    device = _ops._choose_device(inputs)
+
+    if any(isinstance(input, _tensors.TensorListGPU) for input in inputs):
+        device = 'gpu'
+    else:
+        device = 'cpu'
 
     if device == "gpu":
         inputs = list(input._as_gpu() if isinstance(
@@ -333,6 +337,87 @@ def _rxor(self, other):
 _stateless_operators_cache = {}
 
 
+def _create_backend_op(spec, device, num_inputs, num_outputs, call_args_names, op_name):
+    inp_device = 'cpu' if device == 'mixed' else device
+    out_device = 'gpu' if device == 'mixed' else device
+
+    for i in range(num_inputs):
+        spec.AddInput(op_name + f'[{i}]', inp_device)
+
+    for i in range(num_outputs):
+        spec.AddOutput(op_name + f'_out[{i}]', out_device)
+
+    for arg_name in call_args_names:
+        spec.AddArgumentInput(arg_name, '')
+
+    if device == 'cpu':
+        backend_op = _b.EagerOperatorCPU(spec)
+    elif device == 'gpu':
+        backend_op = _b.EagerOperatorGPU(spec)
+    elif device == 'mixed':
+        backend_op = _b.EagerOperatorMixed(spec)
+    else:
+        raise ValueError(
+            f"Incorrect device type '{device}' in eager operator '{op_name}'.")
+
+    return backend_op
+
+
+def _eager_op_object_factory(op_class, op_name):
+    """ Creates eager operator class to use with objective ops-like API. For completeness,
+    currently not used.
+    """
+    class EagerOperator(op_class):
+        def __init__(self, **kwargs):
+            self._batch_size = getattr(kwargs, 'batch_size', -1)
+
+            # Workaround for batch size deduction in _prep_args as we don't have inputs yet.
+            kwargs['batch_size'] = 0
+
+            _, init_args, _ = _prep_args(
+                [], kwargs, op_name, op_name, _callable_op_factory.disqualified_arguments)
+            device_id = init_args.pop('device_id')
+            init_args.pop('max_batch_size')
+
+            super().__init__(**init_args)
+
+            self._spec.AddArg('device_id', device_id)
+            self.built = False
+
+        def __call__(self, *inputs, **kwargs):
+            inputs, init_args, call_args = _prep_args(
+                inputs, kwargs, op_name, op_name, _callable_op_factory.disqualified_arguments)
+
+            if not self.built:
+                num_outputs = self.schema.CalculateOutputs(
+                    self._spec) + self.schema.CalculateAdditionalOutputs(self._spec)
+
+                self._spec.AddArg('max_batch_size', init_args['max_batch_size'])
+                self._backend_op = _create_backend_op(
+                    self._spec, self._device, len(inputs), num_outputs, call_args.keys(), op_name)
+                self.built = True
+
+            output = self._backend_op(inputs, kwargs)
+
+            if len(output) == 1:
+                return output[0]
+
+            return output
+
+    return EagerOperator
+
+
+def _expose_eager_op_as_object(op_class, submodule):
+    """ Exposes eager operators as objects. Can be used if we decide to change eager API from
+    functional to objective.
+    """
+
+    op_name = op_class.schema_name
+    module = _internal.get_submodule('nvidia.dali.experimental.eager', submodule)
+    op = _eager_op_object_factory(op_class, op_name)
+    setattr(module, op_name, op)
+
+
 def _eager_op_base_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperatorBase(op_class):
         def __init__(self, *, max_batch_size, device_id, **kwargs):
@@ -341,26 +426,55 @@ def __init__(self, *, max_batch_size, device_id, **kwargs):
             self._spec.AddArg('device_id', device_id)
             self._spec.AddArg('max_batch_size', max_batch_size)
 
-            for i in range(num_inputs):
-                self._spec.AddInput(op_name + f'[{i}]', self._device)
-
-            for arg_name in call_args_names:
-                self._spec.AddArgumentInput(arg_name, '')
+            num_outputs = self.schema.CalculateOutputs(
+                self._spec) + self.schema.CalculateAdditionalOutputs(self._spec)
 
-            if self._device == 'cpu':
-                self._backend_op = _b.EagerOperatorCPU(self._spec)
-            elif self._device == 'gpu':
-                self._backend_op = _b.EagerOperatorGPU(self._spec)
-            elif self._device == 'mixed':
-                self._backend_op = _b.EagerOperatorMixed(self._spec)
-            else:
-                raise ValueError(
-                    f"Incorrect device type '{self._device}' in eager operator '{op_name}'.")
+            self._backend_op = _create_backend_op(
+                self._spec, self._device, num_inputs, num_outputs, call_args_names, op_name)
 
     return EagerOperatorBase
 
 
-def _stateless_op_factory(op_class, op_name, num_inputs, call_args_names):
+def _create_module_class():
+    """ Creates a class imitating a module. Used for `rng_state` so we can have nested methods.
+    E.g. `rng_state.random.normal`.
+    """
+    class Module:
+        @classmethod
+        def _submodule(cls, name):
+            """ Returns submodule, creates new if it does not exist. """
+            if name not in cls._submodules:
+                # Register a new submodule class (object representing submodule will be created in
+                # the rng_state's constructor).
+                cls._submodules[name] = _create_state_submodule(name)
+
+            return cls._submodules[name]
+
+        _submodules = {}
+
+    return Module
+
+
+def _create_state_submodule(name):
+    """ Creates a class imitating a submodule. It can contain methods and nested submodules.
+    Used for submodules of rng_state, e.g. `rng_state.random`, `rng_state.noise`.
+    """
+
+    class StateSubmodule(_create_module_class()):
+        def __init__(self, operator_cache, seed_generator):
+            self._operator_cache = operator_cache
+            self._seed_generator = seed_generator
+
+            for name, submodule_class in StateSubmodule._submodules.items():
+                # Adds nested submodules.
+                setattr(self, name, submodule_class(self._operator_cache, self._seed_generator))
+
+        __name__ = name
+
+    return StateSubmodule
+
+
+def _callable_op_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperator(_eager_op_base_factory(op_class, op_name, num_inputs, call_args_names)):
         def __call__(self, inputs, kwargs):
             # Here all kwargs are supposed to be TensorLists.
@@ -374,6 +488,13 @@ def __call__(self, inputs, kwargs):
     return EagerOperator
 
 
+_callable_op_factory.disqualified_arguments = {
+    'bytes_per_sample_hint',
+    'preserve',
+    'seed'
+}
+
+
 def _iterator_op_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperator(_eager_op_base_factory(op_class, op_name, num_inputs, call_args_names)):
         def __init__(self, call_args, *, max_batch_size, **kwargs):
@@ -425,6 +546,12 @@ def __len__(self):
     return EagerOperator
 
 
+_iterator_op_factory.disqualified_arguments = {
+    'bytes_per_sample_hint',
+    'preserve',
+}
+
+
 def _choose_device(op_name, wrapper_name, inputs, device_param):
     """Returns device type and device_id based on inputs and device_param."""
 
@@ -541,32 +668,57 @@ def _desc_call_args(inputs, args):
         [(key, value.dtype, value.layout(), len(value[0].shape())) for key, value in args.items()]))
 
 
+def _gen_cache_key(op_name, inputs, init_args, call_args):
+    """ Creating cache key consisting of operator name, description of inputs, input arguments
+    and init args. Each call arg is described by dtype, layout and dim.
+    """
+    return op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))
+
+
 def _wrap_stateless(op_class, op_name, wrapper_name):
     """Wraps stateless Eager Operator in a function. Callable the same way as functions in fn API,
    but directly with TensorLists.
    """
    def wrapper(*inputs, **kwargs):
        inputs, init_args, call_args = _prep_args(
-            inputs, kwargs, op_name, wrapper_name, _wrap_stateless.disqualified_arguments)
+            inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments)
 
-        # Creating cache key consisting of operator name, description of inputs, input arguments
-        # and init args. Each call arg is described by dtype, layout and dim.
-        key = op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))
+        key = _gen_cache_key(op_name, inputs, init_args, call_args)
 
        if key not in _stateless_operators_cache:
-            _stateless_operators_cache[key] = _stateless_op_factory(
+            _stateless_operators_cache[key] = _callable_op_factory(
                op_class, wrapper_name, len(inputs), call_args.keys())(**init_args)
 
        return _stateless_operators_cache[key](inputs, call_args)
 
    return wrapper
 
 
-_wrap_stateless.disqualified_arguments = {
-    'bytes_per_sample_hint',
-    'preserve',
-    'seed'
-}
+def _wrap_stateful(op_class, op_name, wrapper_name):
+    """Wraps stateful Eager Operator as method of a class. Callable the same way as functions in
+    fn API, but directly with TensorLists.
+    """
+
+    def wrapper(self, *inputs, **kwargs):
+        inputs, init_args, call_args = _prep_args(
+            inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments)
+
+        key = _gen_cache_key(op_name, inputs, init_args, call_args)
+
+        if key not in self._operator_cache:
+            # Creating a new operator instance with deterministically generated seed, so if we
+            # preserve the order of operator calls in different instances of rng_state, they
+            # return the same results.
+            seed = self._seed_generator.integers(_wrap_stateful.seed_upper_bound)
+            self._operator_cache[key] = _callable_op_factory(
+                op_class, wrapper_name, len(inputs), call_args.keys())(**init_args, seed=seed)
+
+        return self._operator_cache[key](inputs, call_args)
+
+    return wrapper
+
+
+_wrap_stateful.seed_upper_bound = (1 << 31) - 1
 
 
 def _wrap_iterator(op_class, op_name, wrapper_name):
@@ -582,7 +734,7 @@ def wrapper(*inputs, **kwargs):
             raise ValueError("Iterator type eager operators should not receive any inputs.")
 
         inputs, init_args, call_args = _prep_args(
-            inputs, kwargs, op_name, wrapper_name, _wrap_iterator.disqualified_arguments)
+            inputs, kwargs, op_name, wrapper_name, _iterator_op_factory.disqualified_arguments)
 
         op = _iterator_op_factory(op_class, wrapper_name, len(inputs),
                                   call_args.keys())(call_args, **init_args)
@@ -592,14 +744,39 @@ def wrapper(*inputs, **kwargs):
     return wrapper
 
 
-_wrap_iterator.disqualified_arguments = {
-    'bytes_per_sample_hint',
-    'preserve',
-}
+def _get_rng_state_target_module(submodules):
+    """ Returns target module of rng_state. If a module did not exist, creates it. """
+    from nvidia.dali.experimental import eager
+
+    last_module = eager.rng_state
+    for cur_module_name in submodules:
+        # If nonexistent registers rng_state's submodule.
+        cur_module = last_module._submodule(cur_module_name)
+        last_module = cur_module
+
+    return last_module
+
+
+def _get_eager_target_module(parent_module, submodules, make_hidden):
+    """ Returns target module inside ``parent_module`` if specified, otherwise inside eager. """
+    if parent_module is None:
+        # Exposing to nvidia.dali.experimental.eager module.
+        parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager')
+    else:
+        # Exposing to experimental.eager submodule of the specified parent module.
+        parent_module = _internal.get_submodule(
+            sys.modules[parent_module], 'experimental.eager')
+
+    if make_hidden:
+        op_module = _internal.get_submodule(parent_module, submodules[:-1])
+    else:
+        op_module = _internal.get_submodule(parent_module, submodules)
+
+    return op_module
 
 
-def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc, make_hidden):
-    """Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`).
+def _wrap_eager_op(op_class, submodules, parent_module, wrapper_name, wrapper_doc, make_hidden):
+    """ Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`).
     Uses ``op_class`` for preprocessing inputs and keyword arguments and filling OpSpec for backend
     eager operators.
 
@@ -612,36 +789,31 @@ def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc
         wrapper_doc (str): Documentation of the wrapper function.
         make_hidden (bool): If operator is hidden, we should extract it from hidden submodule.
     """
+
     op_name = op_class.schema_name
     op_schema = _b.TryGetSchema(op_name)
-    if op_schema.IsDeprecated() or op_name in _excluded_operators or op_name in _stateful_operators:
-        # TODO(ksztenderski): For now only exposing stateless and iterator operators.
-        return
-    elif op_name in _iterator_operators:
-        wrapper = _wrap_iterator(op_class, op_name, wrapper_name)
-    else:
-        # If operator is not stateful or a generator expose it as stateless.
-        wrapper = _wrap_stateless(op_class, op_name, wrapper_name)
 
-    if parent_module is None:
-        # Exposing to nvidia.dali.experimental.eager module.
-        parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager')
+    if op_schema.IsDeprecated() or op_name in _excluded_operators:
+        return
+    elif op_name in _stateful_operators:
+        wrapper = _wrap_stateful(op_class, op_name, wrapper_name)
+        op_module = _get_rng_state_target_module(submodules)
     else:
-        # Exposing to experimental.eager submodule of the specified parent module.
-        parent_module = _internal.get_submodule(sys.modules[parent_module], 'experimental.eager')
+        if op_name in _iterator_operators:
+            wrapper = _wrap_iterator(op_class, op_name, wrapper_name)
+        else:
+            # If operator is not stateful, generator, deprecated or excluded expose it as stateless.
+            wrapper = _wrap_stateless(op_class, op_name, wrapper_name)
 
-    if make_hidden:
-        op_module = _internal.get_submodule(parent_module, submodule[:-1])
-    else:
-        op_module = _internal.get_submodule(parent_module, submodule)
+        op_module = _get_eager_target_module(parent_module, submodules, make_hidden)
 
     if not hasattr(op_module, wrapper_name):
         wrapper.__name__ = wrapper_name
         wrapper.__qualname__ = wrapper_name
         wrapper.__doc__ = wrapper_doc
-        wrapper._schema_name = op_schema
+        wrapper._schema_name = op_name
 
-        if submodule:
+        if submodules:
             wrapper.__module__ = op_module.__name__
 
         setattr(op_module, wrapper_name, wrapper)
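Taken together, `_gen_cache_key` and `_wrap_stateful` give each `rng_state` instance its own operator cache keyed by operator name, input description, and init args, with every newly cached operator seeded from that state's generator. A simplified standalone model of the pattern (illustrative only; `RngStateModel`, `call_op`, and `make_operator` are stand-ins, not DALI APIs):

    import numpy as np

    SEED_UPPER_BOUND = (1 << 31) - 1

    class RngStateModel:
        """Simplified stand-in for rng_state: caches operator instances per call signature
        and seeds each new instance deterministically from a per-state generator."""

        def __init__(self, seed):
            self._operator_cache = {}
            self._seed_generator = np.random.default_rng(seed)

        def call_op(self, op_name, make_operator, inputs, **init_args):
            # Key mirrors _gen_cache_key: operator name + input description + init args.
            key = op_name + str([np.asarray(i).dtype for i in inputs]) + str(sorted(init_args.items()))
            if key not in self._operator_cache:
                # A new instance gets the next seed from this state's generator, so states
                # issuing the same calls in the same order produce the same results.
                seed = int(self._seed_generator.integers(SEED_UPPER_BOUND))
                self._operator_cache[key] = make_operator(seed=seed, **init_args)
            return self._operator_cache[key](inputs)

    # Any callable factory works as make_operator in this model.
    def make_operator(seed, scale=1.0):
        rng = np.random.default_rng(seed)
        return lambda inputs: [np.asarray(x) * scale + rng.normal() for x in inputs]

    state = RngStateModel(seed=0)
    out1 = state.call_op("noise", make_operator, [np.ones(4)], scale=2.0)
    out2 = state.call_op("noise", make_operator, [np.ones(4)], scale=2.0)  # cached instance, new draw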
