@@ -212,7 +212,11 @@ def _arithm_op(name, *inputs):
     categories_idxs, inputs, integers, reals = _ops._group_inputs(
         inputs, edge_type=(_tensors.TensorListCPU, _tensors.TensorListGPU))
     input_desc = _ops._generate_input_desc(categories_idxs, integers, reals)
-    device = _ops._choose_device(inputs)
+
+    if any(isinstance(input, _tensors.TensorListGPU) for input in inputs):
+        device = 'gpu'
+    else:
+        device = 'cpu'

     if device == "gpu":
         inputs = list(input._as_gpu() if isinstance(
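The hunk above replaces the call to `_ops._choose_device` with an inline device-promotion rule. A minimal sketch of that rule, not part of the patch and using stand-in classes in place of `nvidia.dali.tensors`:

```python
# Illustrative only: mimics the promotion rule used in _arithm_op above.
class TensorListCPU:
    pass

class TensorListGPU:
    pass

def choose_arithm_device(inputs):
    # Any GPU input promotes the whole expression to the GPU backend;
    # remaining CPU inputs are then copied with _as_gpu(), as in the patch.
    return 'gpu' if any(isinstance(i, TensorListGPU) for i in inputs) else 'cpu'

assert choose_arithm_device([TensorListCPU(), TensorListCPU()]) == 'cpu'
assert choose_arithm_device([TensorListCPU(), TensorListGPU()]) == 'gpu'
```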
@@ -333,6 +337,87 @@ def _rxor(self, other):
 _stateless_operators_cache = {}


+def _create_backend_op(spec, device, num_inputs, num_outputs, call_args_names, op_name):
+    inp_device = 'cpu' if device == 'mixed' else device
+    out_device = 'gpu' if device == 'mixed' else device
+
+    for i in range(num_inputs):
+        spec.AddInput(op_name + f'[{i}]', inp_device)
+
+    for i in range(num_outputs):
+        spec.AddOutput(op_name + f'_out[{i}]', out_device)
+
+    for arg_name in call_args_names:
+        spec.AddArgumentInput(arg_name, '')
+
+    if device == 'cpu':
+        backend_op = _b.EagerOperatorCPU(spec)
+    elif device == 'gpu':
+        backend_op = _b.EagerOperatorGPU(spec)
+    elif device == 'mixed':
+        backend_op = _b.EagerOperatorMixed(spec)
+    else:
+        raise ValueError(
+            f"Incorrect device type '{device}' in eager operator '{op_name}'.")
+
+    return backend_op
+
+
+def _eager_op_object_factory(op_class, op_name):
+    """ Creates an eager operator class to use with an objective, ops-like API. Provided for
+    completeness, currently not used.
+    """
+    class EagerOperator(op_class):
+        def __init__(self, **kwargs):
+            self._batch_size = getattr(kwargs, 'batch_size', -1)
+
+            # Workaround for batch size deduction in _prep_args as we don't have inputs yet.
+            kwargs['batch_size'] = 0
+
+            _, init_args, _ = _prep_args(
+                [], kwargs, op_name, op_name, _callable_op_factory.disqualified_arguments)
+            device_id = init_args.pop('device_id')
+            init_args.pop('max_batch_size')
+
+            super().__init__(**init_args)
+
+            self._spec.AddArg('device_id', device_id)
+            self.built = False
+
+        def __call__(self, *inputs, **kwargs):
+            inputs, init_args, call_args = _prep_args(
+                inputs, kwargs, op_name, op_name, _callable_op_factory.disqualified_arguments)
+
+            if not self.built:
+                num_outputs = self.schema.CalculateOutputs(
+                    self._spec) + self.schema.CalculateAdditionalOutputs(self._spec)
+
+                self._spec.AddArg('max_batch_size', init_args['max_batch_size'])
+                self._backend_op = _create_backend_op(
+                    self._spec, self._device, len(inputs), num_outputs, call_args.keys(), op_name)
+                self.built = True
+
+            output = self._backend_op(inputs, kwargs)
+
+            if len(output) == 1:
+                return output[0]
+
+            return output
+
+    return EagerOperator
+
+
+def _expose_eager_op_as_object(op_class, submodule):
+    """ Exposes eager operators as objects. Can be used if we decide to change the eager API from
+    functional to objective.
+    """
+
+    op_name = op_class.schema_name
+    module = _internal.get_submodule('nvidia.dali.experimental.eager', submodule)
+    op = _eager_op_object_factory(op_class, op_name)
+    setattr(module, op_name, op)
+
+
 def _eager_op_base_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperatorBase(op_class):
         def __init__(self, *, max_batch_size, device_id, **kwargs):
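The `EagerOperator` class added above defers backend construction until the first call, when the number of inputs is finally known. A minimal sketch of that build-on-first-call pattern, with stand-in names rather than the DALI classes:

```python
# Illustrative sketch of the lazy-build pattern; not DALI APIs.
class LazilyBuiltOp:
    def __init__(self, **init_args):
        self._init_args = init_args
        self._backend = None            # mirrors `self.built = False` in the patch

    def _build(self, num_inputs):
        # Stand-in for _create_backend_op(); the real code builds a backend operator
        # from the accumulated OpSpec once the input count is known.
        self._backend = lambda *inputs: inputs

    def __call__(self, *inputs):
        if self._backend is None:       # build lazily on the first call
            self._build(len(inputs))
        out = self._backend(*inputs)
        return out[0] if len(out) == 1 else out

op = LazilyBuiltOp(device='cpu')
print(op(1))      # 1      -- a single output is unpacked
print(op(1, 2))   # (1, 2) -- multiple outputs are returned together
```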
@@ -341,26 +426,55 @@ def __init__(self, *, max_batch_size, device_id, **kwargs):
             self._spec.AddArg('device_id', device_id)
             self._spec.AddArg('max_batch_size', max_batch_size)

-            for i in range(num_inputs):
-                self._spec.AddInput(op_name + f'[{i}]', self._device)
-
-            for arg_name in call_args_names:
-                self._spec.AddArgumentInput(arg_name, '')
+            num_outputs = self.schema.CalculateOutputs(
+                self._spec) + self.schema.CalculateAdditionalOutputs(self._spec)

-            if self._device == 'cpu':
-                self._backend_op = _b.EagerOperatorCPU(self._spec)
-            elif self._device == 'gpu':
-                self._backend_op = _b.EagerOperatorGPU(self._spec)
-            elif self._device == 'mixed':
-                self._backend_op = _b.EagerOperatorMixed(self._spec)
-            else:
-                raise ValueError(
-                    f"Incorrect device type '{self._device}' in eager operator '{op_name}'.")
+            self._backend_op = _create_backend_op(
+                self._spec, self._device, num_inputs, num_outputs, call_args_names, op_name)

     return EagerOperatorBase


-def _stateless_op_factory(op_class, op_name, num_inputs, call_args_names):
+def _create_module_class():
+    """ Creates a class imitating a module. Used for `rng_state` so we can have nested methods,
+    e.g. `rng_state.random.normal`.
+    """
+    class Module:
+        @classmethod
+        def _submodule(cls, name):
+            """ Returns the submodule, creating it if it does not exist. """
+            if name not in cls._submodules:
+                # Register a new submodule class (the object representing the submodule will be
+                # created in rng_state's constructor).
+                cls._submodules[name] = _create_state_submodule(name)
+
+            return cls._submodules[name]
+
+        _submodules = {}
+
+    return Module
+
+
+def _create_state_submodule(name):
+    """ Creates a class imitating a submodule. It can contain methods and nested submodules.
+    Used for submodules of rng_state, e.g. `rng_state.random`, `rng_state.noise`.
+    """
+
+    class StateSubmodule(_create_module_class()):
+        def __init__(self, operator_cache, seed_generator):
+            self._operator_cache = operator_cache
+            self._seed_generator = seed_generator
+
+            for name, submodule_class in StateSubmodule._submodules.items():
+                # Add nested submodules.
+                setattr(self, name, submodule_class(self._operator_cache, self._seed_generator))
+
+        __name__ = name
+
+    return StateSubmodule
+
+
+def _callable_op_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperator(_eager_op_base_factory(op_class, op_name, num_inputs, call_args_names)):
         def __call__(self, inputs, kwargs):
             # Here all kwargs are supposed to be TensorLists.
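The two helpers added above build `rng_state`-style nesting: submodule classes are registered on the module-imitating class, and each state instance materializes its own submodule objects that share the parent's operator cache and seed generator. A self-contained sketch of that pattern, with illustrative names in place of the DALI classes:

```python
# Illustrative sketch of the nested "module" pattern used by rng_state above.
class FakeSubmodule:
    """Stand-in for a class produced by _create_state_submodule."""
    def __init__(self, operator_cache, seed_generator):
        self._operator_cache = operator_cache
        self._seed_generator = seed_generator

class FakeState:
    """Stand-in for rng_state: every instance materializes its own submodule objects."""
    _submodules = {'random': FakeSubmodule, 'noise': FakeSubmodule}

    def __init__(self):
        self._operator_cache = {}
        self._seed_generator = object()  # a real seed generator in the actual code
        for name, cls in self._submodules.items():
            # Each submodule shares the parent's cache and seed generator.
            setattr(self, name, cls(self._operator_cache, self._seed_generator))

state = FakeState()
# state.random.* and state.noise.* reuse operator instances keyed in the same cache.
assert state.random._operator_cache is state.noise._operator_cache
```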
@@ -374,6 +488,13 @@ def __call__(self, inputs, kwargs):
     return EagerOperator


+_callable_op_factory.disqualified_arguments = {
+    'bytes_per_sample_hint',
+    'preserve',
+    'seed'
+}
+
+
 def _iterator_op_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperator(_eager_op_base_factory(op_class, op_name, num_inputs, call_args_names)):
         def __init__(self, call_args, *, max_batch_size, **kwargs):
@@ -425,6 +546,12 @@ def __len__(self):
     return EagerOperator


+_iterator_op_factory.disqualified_arguments = {
+    'bytes_per_sample_hint',
+    'preserve',
+}
+
+
 def _choose_device(op_name, wrapper_name, inputs, device_param):
     """Returns device type and device_id based on inputs and device_param."""
@@ -541,32 +668,57 @@ def _desc_call_args(inputs, args):
         [(key, value.dtype, value.layout(), len(value[0].shape())) for key, value in args.items()]))


+def _gen_cache_key(op_name, inputs, init_args, call_args):
+    """ Creates a cache key consisting of the operator name, a description of the inputs and call
+    arguments, and the init args. Each call arg is described by dtype, layout and dim.
+    """
+    return op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))
+
+
 def _wrap_stateless(op_class, op_name, wrapper_name):
     """Wraps stateless Eager Operator in a function. Callable the same way as functions in fn API,
     but directly with TensorLists.
     """
     def wrapper(*inputs, **kwargs):
         inputs, init_args, call_args = _prep_args(
-            inputs, kwargs, op_name, wrapper_name, _wrap_stateless.disqualified_arguments)
+            inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments)

-        # Creating cache key consisting of operator name, description of inputs, input arguments
-        # and init args. Each call arg is described by dtype, layout and dim.
-        key = op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))
+        key = _gen_cache_key(op_name, inputs, init_args, call_args)

         if key not in _stateless_operators_cache:
-            _stateless_operators_cache[key] = _stateless_op_factory(
+            _stateless_operators_cache[key] = _callable_op_factory(
                 op_class, wrapper_name, len(inputs), call_args.keys())(**init_args)

         return _stateless_operators_cache[key](inputs, call_args)

     return wrapper


-_wrap_stateless.disqualified_arguments = {
-    'bytes_per_sample_hint',
-    'preserve',
-    'seed'
-}
+def _wrap_stateful(op_class, op_name, wrapper_name):
+    """Wraps stateful Eager Operator as a method of a class. Callable the same way as functions in
+    fn API, but directly with TensorLists.
+    """
+
+    def wrapper(self, *inputs, **kwargs):
+        inputs, init_args, call_args = _prep_args(
+            inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments)
+
+        key = _gen_cache_key(op_name, inputs, init_args, call_args)
+
+        if key not in self._operator_cache:
+            # Create a new operator instance with a deterministically generated seed, so that if
+            # we preserve the order of operator calls in different instances of rng_state, they
+            # return the same results.
+            seed = self._seed_generator.integers(_wrap_stateful.seed_upper_bound)
+            self._operator_cache[key] = _callable_op_factory(
+                op_class, wrapper_name, len(inputs), call_args.keys())(**init_args, seed=seed)
+
+        return self._operator_cache[key](inputs, call_args)
+
+    return wrapper
+
+
+_wrap_stateful.seed_upper_bound = (1 << 31) - 1


 def _wrap_iterator(op_class, op_name, wrapper_name):
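The comment in `_wrap_stateful` above explains why two `rng_state` instances agree when their operators are called in the same order: each newly cached operator draws its seed deterministically from the state's generator. A hedged sketch of that behaviour, assuming a NumPy `Generator` purely for illustration (the actual generator type is not shown in this diff):

```python
# Illustrative: deterministic per-operator seeding.
import numpy as np

SEED_UPPER_BOUND = (1 << 31) - 1  # mirrors _wrap_stateful.seed_upper_bound

def next_op_seed(state_generator):
    # Each newly cached operator draws one seed from the state's generator, so the
    # same call order in two states created with the same seed yields the same seeds.
    return int(state_generator.integers(SEED_UPPER_BOUND))

g1 = np.random.default_rng(42)
g2 = np.random.default_rng(42)
assert [next_op_seed(g1) for _ in range(3)] == [next_op_seed(g2) for _ in range(3)]
```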
@@ -582,7 +734,7 @@ def wrapper(*inputs, **kwargs):
             raise ValueError("Iterator type eager operators should not receive any inputs.")

         inputs, init_args, call_args = _prep_args(
-            inputs, kwargs, op_name, wrapper_name, _wrap_iterator.disqualified_arguments)
+            inputs, kwargs, op_name, wrapper_name, _iterator_op_factory.disqualified_arguments)

         op = _iterator_op_factory(op_class, wrapper_name, len(inputs),
                                   call_args.keys())(call_args, **init_args)
@@ -592,14 +744,39 @@ def wrapper(*inputs, **kwargs):
     return wrapper


-_wrap_iterator.disqualified_arguments = {
-    'bytes_per_sample_hint',
-    'preserve',
-}
+def _get_rng_state_target_module(submodules):
+    """ Returns the target module of rng_state. If a module does not exist, it is created. """
+    from nvidia.dali.experimental import eager
+
+    last_module = eager.rng_state
+    for cur_module_name in submodules:
+        # Registers the rng_state submodule if it does not exist yet.
+        cur_module = last_module._submodule(cur_module_name)
+        last_module = cur_module
+
+    return last_module
+
+
+def _get_eager_target_module(parent_module, submodules, make_hidden):
+    """ Returns the target module inside ``parent_module`` if specified, otherwise inside eager. """
+    if parent_module is None:
+        # Exposing to nvidia.dali.experimental.eager module.
+        parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager')
+    else:
+        # Exposing to experimental.eager submodule of the specified parent module.
+        parent_module = _internal.get_submodule(
+            sys.modules[parent_module], 'experimental.eager')
+
+    if make_hidden:
+        op_module = _internal.get_submodule(parent_module, submodules[:-1])
+    else:
+        op_module = _internal.get_submodule(parent_module, submodules)
+
+    return op_module


-def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc, make_hidden):
-    """Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`).
+def _wrap_eager_op(op_class, submodules, parent_module, wrapper_name, wrapper_doc, make_hidden):
+    """ Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`).
     Uses ``op_class`` for preprocessing inputs and keyword arguments and filling OpSpec for backend
     eager operators.

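`_get_eager_target_module` above resolves where an operator is exposed: under `experimental.eager` of the given parent (or of `nvidia.dali` by default), with the trailing submodule dropped for hidden operators. A minimal sketch of that path resolution, only joining strings instead of creating real modules via `_internal.get_submodule`:

```python
# Illustrative: how a submodule list maps to the public module path.
def target_path(parent, submodules, make_hidden):
    # Hidden operators drop the trailing (hidden) component, mirroring submodules[:-1] above.
    parts = submodules[:-1] if make_hidden else submodules
    return '.'.join([parent, 'experimental.eager', *parts])

assert target_path('nvidia.dali', ['random'], False) == 'nvidia.dali.experimental.eager.random'
assert target_path('nvidia.dali', ['random', 'hidden'], True) == 'nvidia.dali.experimental.eager.random'
```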
@@ -612,36 +789,31 @@ def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc
         wrapper_doc (str): Documentation of the wrapper function.
         make_hidden (bool): If operator is hidden, we should extract it from hidden submodule.
     """
+
     op_name = op_class.schema_name
     op_schema = _b.TryGetSchema(op_name)
-    if op_schema.IsDeprecated() or op_name in _excluded_operators or op_name in _stateful_operators:
-        # TODO(ksztenderski): For now only exposing stateless and iterator operators.
-        return
-    elif op_name in _iterator_operators:
-        wrapper = _wrap_iterator(op_class, op_name, wrapper_name)
-    else:
-        # If operator is not stateful or a generator expose it as stateless.
-        wrapper = _wrap_stateless(op_class, op_name, wrapper_name)

-    if parent_module is None:
-        # Exposing to nvidia.dali.experimental.eager module.
-        parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager')
+    if op_schema.IsDeprecated() or op_name in _excluded_operators:
+        return
+    elif op_name in _stateful_operators:
+        wrapper = _wrap_stateful(op_class, op_name, wrapper_name)
+        op_module = _get_rng_state_target_module(submodules)
     else:
-        # Exposing to experimental.eager submodule of the specified parent module.
-        parent_module = _internal.get_submodule(sys.modules[parent_module], 'experimental.eager')
+        if op_name in _iterator_operators:
+            wrapper = _wrap_iterator(op_class, op_name, wrapper_name)
+        else:
+            # If the operator is not stateful, a generator, deprecated or excluded, expose it as stateless.
+            wrapper = _wrap_stateless(op_class, op_name, wrapper_name)

-    if make_hidden:
-        op_module = _internal.get_submodule(parent_module, submodule[:-1])
-    else:
-        op_module = _internal.get_submodule(parent_module, submodule)
+        op_module = _get_eager_target_module(parent_module, submodules, make_hidden)


     if not hasattr(op_module, wrapper_name):
         wrapper.__name__ = wrapper_name
         wrapper.__qualname__ = wrapper_name
         wrapper.__doc__ = wrapper_doc
-        wrapper._schema_name = op_schema
+        wrapper._schema_name = op_name

-        if submodule:
+        if submodules:
             wrapper.__module__ = op_module.__name__

         setattr(op_module, wrapper_name, wrapper)
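To summarize the reworked dispatch in `_wrap_eager_op`: deprecated or excluded schemas are skipped, stateful operators become `rng_state` methods, iterator operators keep their dedicated wrapper, and everything else is exposed as a stateless function. A small restatement of that routing with stand-in operator names and sets (the real `_excluded_operators`, `_stateful_operators`, and `_iterator_operators` contents are not shown in this diff):

```python
# Illustrative re-statement of the dispatch above; names are hypothetical stand-ins.
_excluded = {'OpA'}
_stateful = {'NormalDistribution'}
_iterators = {'FileReader'}

def classify(op_name, deprecated=False):
    if deprecated or op_name in _excluded:
        return 'skipped'
    if op_name in _stateful:
        return 'rng_state method'     # _wrap_stateful + _get_rng_state_target_module
    if op_name in _iterators:
        return 'iterator wrapper'     # _wrap_iterator
    return 'stateless function'       # _wrap_stateless, cached in _stateless_operators_cache

assert classify('OpA') == 'skipped'
assert classify('NormalDistribution') == 'rng_state method'
assert classify('FileReader') == 'iterator wrapper'
assert classify('Flip') == 'stateless function'
```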