diff --git a/fx2ait/fx2ait/csrc/AITModelImpl.cpp b/fx2ait/fx2ait/csrc/AITModelImpl.cpp index 8530ae706..e2ebd4c70 100644 --- a/fx2ait/fx2ait/csrc/AITModelImpl.cpp +++ b/fx2ait/fx2ait/csrc/AITModelImpl.cpp @@ -366,7 +366,7 @@ std::vector AITModelImpl::processOutputs( output.unsafeGetTensorImpl()->set_sizes_contiguous(size); } - if (floating_point_output_dtype_ != c10::nullopt && + if (floating_point_output_dtype_ != std::nullopt && output.is_floating_point()) { outputs.emplace_back(output.to(*floating_point_output_dtype_)); } else { @@ -394,7 +394,7 @@ std::vector AITModelImpl::processInputs( auto input_name = input_names_[python_input_idx]; const auto ait_input_idx = input_name_to_index_.at(input_name); auto& input = inputs[python_input_idx]; - if (floating_point_input_dtype_ != c10::nullopt && + if (floating_point_input_dtype_ != std::nullopt && input.is_floating_point()) { // Need to keep input alive; cannot just stash result of to() // call in a local!