Skip to content

Commit a8f101f

Browse files
authored
cherrypicks from dev-0.7.0 (#83)
* new API :: device_type * more constructors
1 parent f3d26d6 commit a8f101f

File tree

6 files changed

+89
-12
lines changed

6 files changed

+89
-12
lines changed

include/ttl/bits/flat_tensor_mixin.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class flat_tensor_mixin
4444
public:
4545
using value_type = R;
4646
using shape_type = S;
47+
using device_type = D;
4748

4849
size_t data_size() const { return shape_.size() * sizeof(R); }
4950

include/ttl/bits/std_tensor.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class basic_tensor<R, basic_shape<0, Dim>, D, readwrite>
4141
using mixin::mixin;
4242

4343
public:
44+
basic_tensor(R *data) : mixin(data) {}
45+
4446
basic_tensor(const basic_tensor<R, basic_shape<0, Dim>, D, owner> &t)
4547
: mixin(t.data())
4648
{
@@ -67,6 +69,8 @@ class basic_tensor<R, basic_shape<0, Dim>, D, readonly>
6769
using mixin::mixin;
6870

6971
public:
72+
basic_tensor(const R *data) : mixin(data) {}
73+
7074
basic_tensor(const basic_tensor<R, basic_shape<0, Dim>, D, owner> &t)
7175
: mixin(t.data())
7276
{

include/ttl/bits/std_tensor_mixin.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class basic_scalar_mixin
3030
public:
3131
using value_type = R;
3232
using shape_type = S;
33+
using device_type = D;
3334

3435
static constexpr auto rank = S::rank; // == 0
3536

@@ -112,6 +113,7 @@ class basic_tensor_mixin
112113
public:
113114
using value_type = R;
114115
using shape_type = S;
116+
using device_type = D;
115117

116118
using slice_type = basic_tensor<R, S, D, typename trait::Access>;
117119

tests/_test_loc.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include <cstdlib>
2+
#include <filesystem>
3+
#include <iostream>
4+
5+
#include "testing.hpp"
6+
7+
namespace fs = std::filesystem;
8+
9+
int loc(const char *filename)
10+
{
11+
FILE *fp = std::fopen(filename, "r");
12+
if (fp == nullptr) { return 0; }
13+
constexpr int max_line = 1 << 16;
14+
char line[max_line];
15+
int ln = 0;
16+
while (std::fgets(line, max_line - 1, fp)) { ++ln; }
17+
std::fclose(fp);
18+
return ln;
19+
}
20+
21+
TEST(test_loc, test1)
22+
{
23+
std::string path = "/path/to/directory";
24+
int tot = 0;
25+
int n = 0;
26+
for (const auto &entry : fs::directory_iterator("include/ttl/bits")) {
27+
const int ln = loc(entry.path().c_str());
28+
printf("%4d %s\n", ln, entry.path().c_str());
29+
ASSERT_TRUE(ln <= 200);
30+
tot += ln;
31+
++n;
32+
}
33+
printf("total: %d lines in %d files\n", tot, n);
34+
ASSERT_TRUE(tot <= 2000);
35+
}

tests/test_scalar.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "testing.hpp"
2+
3+
#include <ttl/tensor>
4+
5+
TEST(scalar_test, test_constructor)
6+
{
7+
{
8+
ttl::tensor<int, 0> x;
9+
ttl::tensor_ref<int, 0> r(x);
10+
ttl::tensor_view<int, 0> v(x);
11+
}
12+
{
13+
int value = 0;
14+
ttl::tensor_ref<int, 0> r(&value);
15+
ttl::tensor_view<int, 0> v(&value);
16+
}
17+
}

tests/test_tensor.cpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,24 @@ TEST(tensor_test, test1)
4040
ASSERT_EQ(sum, n * (n + 1) / 2);
4141
}
4242

43-
template <bool write, typename T> struct test_assign_ {
43+
template <bool write, typename T>
44+
struct test_assign_ {
4445
void operator()(T &x, int v) { x = v; }
4546
};
4647

47-
template <typename T> struct test_assign_<false, T> {
48+
template <typename T>
49+
struct test_assign_<false, T> {
4850
void operator()(T &x, int v) {}
4951
};
5052

51-
template <bool write, typename T> void test_assign(T &&x, int v)
53+
template <bool write, typename T>
54+
void test_assign(T &&x, int v)
5255
{
5356
test_assign_<write, T>()(x, v);
5457
}
5558

56-
template <typename T, bool write = true> struct test_5d_array {
59+
template <typename T, bool write = true>
60+
struct test_5d_array {
5761
void operator()(const T &t)
5862
{
5963
using R = typename T::value_type;
@@ -118,9 +122,13 @@ TEST(tensor_test, test3)
118122
test_5d_array<decltype(v), false>()(v);
119123
}
120124

121-
template <typename R, uint8_t r> void ref_func(const tensor_ref<R, r> &x) {}
125+
template <typename R, uint8_t r>
126+
void ref_func(const tensor_ref<R, r> &x)
127+
{
128+
}
122129

123-
template <typename R, uint8_t r> void test_auto_ref()
130+
template <typename R, uint8_t r>
131+
void test_auto_ref()
124132
{
125133
static_assert(std::is_convertible<tensor<R, r>, tensor_ref<R, r>>::value,
126134
"can't convert to ref");
@@ -142,9 +150,13 @@ TEST(tensor_test, auto_ref)
142150
// f(t); // NOT possible
143151
}
144152

145-
template <typename R, uint8_t r> void view_func(const tensor_view<R, r> &x) {}
153+
template <typename R, uint8_t r>
154+
void view_func(const tensor_view<R, r> &x)
155+
{
156+
}
146157

147-
template <typename R, uint8_t r> void test_auto_view()
158+
template <typename R, uint8_t r>
159+
void test_auto_view()
148160
{
149161
static_assert(std::is_convertible<tensor<R, r>, tensor_view<R, r>>::value,
150162
"can't convert to view");
@@ -177,7 +189,8 @@ auto create_tensor_func()
177189

178190
TEST(tensor_test, return_tensor) { auto t = create_tensor_func(); }
179191

180-
template <typename R> R read_tensor_func(const tensor<R, 2> &t, int i, int j)
192+
template <typename R>
193+
R read_tensor_func(const tensor<R, 2> &t, int i, int j)
181194
{
182195
const R x = t.at(i, j);
183196
return x;
@@ -256,6 +269,8 @@ void test_static_properties(const ttl::internal::basic_tensor<R, S, D, A> &x)
256269
using T = ttl::internal::basic_tensor<R, S, D, A>;
257270
static_assert(std::is_same<typename T::value_type, R>::value,
258271
"invalid value_type");
272+
static_assert(std::is_same<typename T::device_type, D>::value,
273+
"invalid device_type");
259274
static_assert(T::rank == r, "invalid rank");
260275
auto x_shape = x.shape();
261276
static_assert(decltype(x_shape)::rank == r, "invalid rank of shape");
@@ -316,7 +331,8 @@ TEST(tensor_test, test_const_properties)
316331
"");
317332
}
318333

319-
template <typename T> void test_slice_57_52_53_slice_19_38(const T &t)
334+
template <typename T>
335+
void test_slice_57_52_53_slice_19_38(const T &t)
320336
{
321337
const auto t1 = t.slice(0, 19);
322338
const auto t2 = t.slice(19, 57);
@@ -349,7 +365,8 @@ TEST(tensor_test, test_slice)
349365
}
350366
}
351367

352-
template <typename T> void test_data_end(const T &t)
368+
template <typename T>
369+
void test_data_end(const T &t)
353370
{
354371
ASSERT_EQ(t.data_end(), t.data() + t.shape().size());
355372
{
@@ -359,7 +376,8 @@ template <typename T> void test_data_end(const T &t)
359376
}
360377
}
361378

362-
template <typename R> void test_data_end_all()
379+
template <typename R>
380+
void test_data_end_all()
363381
{
364382
// using ttl::experimental::raw_ref;
365383
// using ttl::experimental::raw_view;

0 commit comments

Comments
 (0)