diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index ecf3c16b6404..4b1937ea4d10 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -4452,6 +4452,12 @@ def get_rl_module_spec( # If module_config_dict is not defined, set to our generic one. if rl_module_spec.model_config is None: rl_module_spec.model_config = self.model_config + # Otherwise we combine the two dictionaries where settings from the + # `RLModuleSpec` have higher priority. + else: + rl_module_spec.model_config = ( + self.model_config | rl_module_spec._get_model_config() + ) if inference_only is not None: rl_module_spec.inference_only = inference_only diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index b35c9a6572fa..dfb5b4f9dcf4 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -130,7 +130,19 @@ def set_state(self, state: StateDict) -> None: # these keys (strict=False). This is most likely due to `state` coming from # an `inference_only=False` RLModule, while `self` is an `inference_only=True` # RLModule. - self.load_state_dict(convert_to_torch_tensor(state), strict=False) + missing_keys, unexpected_keys = self.load_state_dict( + convert_to_torch_tensor(state), strict=False + ) + + # For inference_only modules, missing_keys should always be empty. + # If there are missing keys, it means the target module expects parameters + # that don't exist in the source, indicating an architecture mismatch. + if self.inference_only and missing_keys: + raise ValueError( + f"Updating the module's state is missing keys: {list(missing_keys)}. " + "This is most likely because the state has different layer names (or is missing layers). " + f"Complete list of state keys is {list(state.keys())}" + ) @OverrideToImplementCustomLogic @override(RLModule)