Skip to content

Commit

Permalink
fix(pt): optimize createNlistTensor (#4403)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced tensor creation process for improved performance and
efficiency.
  
- **Bug Fixes**
- Improved error handling for PyTorch-related exceptions, providing
clearer error messages.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Nov 23, 2024
1 parent 5a93798 commit 2303ff0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
22 changes: 11 additions & 11 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ void DeepPotPT::translate_error(std::function<void()> f) {
}

torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
std::vector<torch::Tensor> row_tensors;

size_t total_size = 0;
for (const auto& row : data) {
torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0);
row_tensors.push_back(row_tensor);
total_size += row.size();
}

torch::Tensor tensor;
if (row_tensors.size() > 0) {
tensor = torch::cat(row_tensors, 0).unsqueeze(0);
} else {
tensor = torch::empty({1, 0, 0}, torch::kInt32);
std::vector<int> flat_data;
flat_data.reserve(total_size);
for (const auto& row : data) {
flat_data.insert(flat_data.end(), row.begin(), row.end());
}
return tensor;

torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32);
int nloc = data.size();
int nnei = nloc > 0 ? total_size / nloc : 0;
return flat_tensor.view({1, nloc, nnei});
}
DeepPotPT::DeepPotPT() : inited(false) {}
DeepPotPT::DeepPotPT(const std::string& model,
Expand Down
22 changes: 11 additions & 11 deletions source/api_cc/src/DeepSpinPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ void DeepSpinPT::translate_error(std::function<void()> f) {
}

torch::Tensor createNlistTensor2(const std::vector<std::vector<int>>& data) {
std::vector<torch::Tensor> row_tensors;

size_t total_size = 0;
for (const auto& row : data) {
torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0);
row_tensors.push_back(row_tensor);
total_size += row.size();
}

torch::Tensor tensor;
if (row_tensors.size() > 0) {
tensor = torch::cat(row_tensors, 0).unsqueeze(0);
} else {
tensor = torch::empty({1, 0, 0}, torch::kInt32);
std::vector<int> flat_data;
flat_data.reserve(total_size);
for (const auto& row : data) {
flat_data.insert(flat_data.end(), row.begin(), row.end());
}
return tensor;

torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32);
int nloc = data.size();
int nnei = nloc > 0 ? total_size / nloc : 0;
return flat_tensor.view({1, nloc, nnei});
}
DeepSpinPT::DeepSpinPT() : inited(false) {}
DeepSpinPT::DeepSpinPT(const std::string& model,
Expand Down

0 comments on commit 2303ff0

Please sign in to comment.