Skip to content

Commit

Permalink
Merge branch 'main' into mistral_7b_support/akoumparouli
Browse files Browse the repository at this point in the history
  • Loading branch information
ericharper authored Jan 24, 2024
2 parents 035a228 + aeb9799 commit 81372b0
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 7 deletions.
6 changes: 6 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,12 @@ model:
## Flash Attention
use_flash_attention: False # Use flash attention in self-attention module, this config does nothing when transformer_engine=True

##Offloading Activations/Weights to CPU
cpu_offloading: False
cpu_offloading_num_layers: 11 #This value should be between [1,num_layers-1] as we don't want to offload the final layer's activations and expose any offloading duration for the final layer
cpu_offloading_activations: True
cpu_offloading_weights: True

## Network
sharp: False # Enable the use of SHARP for NCCL data-parallel communications. This is going to be ignored if the network doesn't support SHARP.

Expand Down
48 changes: 41 additions & 7 deletions nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck', 'PretrainedModelInfo']

_TYPECHECK_ENABLED = True
_TYPECHECK_SEMANTIC_CHECK_ENABLED = True
# TODO @blisc: Remove _HAS_HYDRA
_HAS_HYDRA = True

Expand All @@ -54,6 +55,13 @@ def is_typecheck_enabled():
return _TYPECHECK_ENABLED


def is_semantic_typecheck_enabled():
"""
Getter method for typechecking semantics state.
"""
return _TYPECHECK_SEMANTIC_CHECK_ENABLED


@dataclass
class TypecheckMetadata:
"""
Expand Down Expand Up @@ -178,7 +186,6 @@ def _validate_input_types(self, input_types=None, ignore_collections=False, **kw
kwargs: Dictionary of argument_name:argument_value pairs passed to the wrapped
function upon call.
"""
# TODO: Properly implement this
if input_types is not None:
# Precompute metadata
metadata = TypecheckMetadata(original_types=input_types, ignore_collections=ignore_collections)
Expand All @@ -202,9 +209,11 @@ def _validate_input_types(self, input_types=None, ignore_collections=False, **kw
)

# Perform neural type check
if hasattr(value, 'neural_type') and not metadata.base_types[key].compare(value.neural_type) in (
NeuralTypeComparisonResult.SAME,
NeuralTypeComparisonResult.GREATER,
if (
hasattr(value, 'neural_type')
and is_semantic_typecheck_enabled()
and not metadata.base_types[key].compare(value.neural_type)
in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER,)
):
error_msg = [
f"{input_types[key].compare(value.neural_type)} :",
Expand Down Expand Up @@ -379,9 +388,11 @@ def __check_neural_type(self, obj, metadata: TypecheckMetadata, depth: int, name
f"Expected nested depth : {metadata.container_depth[name]}"
)

if hasattr(obj, 'neural_type') and not type_val.compare(obj.neural_type) in (
NeuralTypeComparisonResult.SAME,
NeuralTypeComparisonResult.GREATER,
if (
hasattr(obj, 'neural_type')
and is_semantic_typecheck_enabled()
and not type_val.compare(obj.neural_type)
in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER,)
):
raise TypeError(
f"{type_val.compare(obj.neural_type)} : \n"
Expand Down Expand Up @@ -1114,3 +1125,26 @@ def disable_checks():
yield
finally:
typecheck.set_typecheck_enabled(enabled=True)

@staticmethod
def set_semantic_check_enabled(enabled: bool = True):
"""
Global method to enable/disable semantic typechecking.
Args:
enabled: bool, when True will enable semantic typechecking.
"""
global _TYPECHECK_SEMANTIC_CHECK_ENABLED
_TYPECHECK_SEMANTIC_CHECK_ENABLED = enabled

@staticmethod
@contextmanager
def disable_semantic_checks():
"""
Context manager that temporarily disables semantic type checking within its context.
"""
typecheck.set_semantic_check_enabled(enabled=False)
try:
yield
finally:
typecheck.set_semantic_check_enabled(enabled=True)
38 changes: 38 additions & 0 deletions tests/core/test_typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,3 +1152,41 @@ def __call__(self, x):
assert len(outA[0]) == 3
for i in range(len(outA)):
assert outA[0][i].neural_type.compare(NeuralType(('B', 'D'), LogitsType()))

@pytest.mark.unit
def test_disable_semantic_types_input_output(self):
class InputOutputTypes(Typing):
@property
def input_types(self):
return {"x": NeuralType(('B',), LogprobsType())}

@property
def output_types(self):
return {"y": NeuralType(('B',), LabelsType())}

@typecheck()
def __call__(self, x):
x += 1
return x

obj = InputOutputTypes()
result = obj(x=torch.zeros(10))

assert result.sum() == torch.tensor(10.0)
assert result.neural_type.compare(NeuralType(('B',), LabelsType())) == NeuralTypeComparisonResult.SAME

# Test that input is provided with wrong type and semantic checks are not disabled
with pytest.raises(TypeError):
input_data = torch.zeros(10)
input_data.neural_type = NeuralType(('B',), LabelsType())
_ = obj(x=input_data)

# Provide input with wrong type after disabling semantic type checks
with typecheck.disable_semantic_checks():
input_data = torch.zeros(10)
input_data.neural_type = NeuralType(('B',), LabelsType()) # Should be LogprobsType()
result = obj(x=input_data)

# assert that even if semantic types are disabled, output is attached with appropriate types
assert result.sum() == torch.tensor(10.0)
assert result.neural_type.compare(NeuralType(('B',), LabelsType())) == NeuralTypeComparisonResult.SAME

0 comments on commit 81372b0

Please sign in to comment.