Skip to content

Commit 34a149b

Browse files
authored
Merge pull request #20 from ewanwm/feature_trig_functions
add trig functions in Tensor
2 parents b5eafcb + dbb02c1 commit 34a149b

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

nuTens/tensors/tensor.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,19 @@ class Tensor{
214214

215215
/// @}
216216

217+
218+
/// @name Trigonometric
219+
/// @{
220+
221+
/// @brief Get element-wise sin of a tensor
222+
/// @param t The tensor
223+
static Tensor sin(const Tensor &t);
224+
225+
/// @brief Get element-wise cosine of a tensor
226+
/// @param t The tensor
227+
static Tensor cos(const Tensor &t);
228+
229+
/// @}
217230

218231
/// @brief Overwrite the << operator to print this tensor out to the command line
219232
friend std::ostream &operator<< (std::ostream &stream, const Tensor &tensor) {

nuTens/tensors/torch-tensor.cpp

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11

22
#include <nuTens/tensors/tensor.hpp>
3-
#include <torch/torch.h>
43

54

65
// map between the data types used in nuTens and those used by pytorch
@@ -312,7 +311,7 @@ Tensor Tensor::operator- () const {
312311

313312
Tensor Tensor::cumsum(int dim) const {
314313
Tensor ret;
315-
ret._tensor = _tensor.cumsum(dim);
314+
ret._tensor = torch::cumsum(_tensor, dim);
316315
return ret;
317316
}
318317

@@ -332,6 +331,19 @@ Tensor Tensor::grad() const {
332331
return ret;
333332
}
334333

334+
335+
Tensor Tensor::sin(const Tensor &t) {
336+
Tensor ret;
337+
ret._tensor = torch::sin(t._tensor);
338+
return ret;
339+
}
340+
341+
Tensor Tensor::cos(const Tensor &t) {
342+
Tensor ret;
343+
ret._tensor = torch::cos(t._tensor);
344+
return ret;
345+
}
346+
335347
std::string Tensor::toString() const {
336348
std::ostringstream stream;
337349
stream << _tensor;

0 commit comments

Comments
 (0)