Skip to content

Commit

Permalink
Add ToVecVec to ragged, implement get labels from fsa
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Aug 2, 2022
1 parent 0562097 commit 096f316
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 2 deletions.
30 changes: 28 additions & 2 deletions k2/csrc/fsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,11 +392,37 @@ Tensor WeightsOfArcsAsTensor(const Array1<Arc> &arcs);
// memory location because Array1 does not support a stride. However
// it would be possible to get it as an Array2.
inline Array1<float> WeightsOfArcsAsArray1(const Array1<Arc> &arcs) {
return Array1<float>(WeightsOfArcsAsTensor(arcs));
ContextPtr c = arcs.Context();
const Arc *arc_data = arcs.Data();
Array1<float> weights(c, arcs.Dim());
const float *ptr = reinterpret_cast<const float *>(arc_data);
float *weights_data = weights.Data();
int32_t stride = 4;
K2_EVAL(
c, arcs.Dim(), lambda_get_weights,
(int32_t i)->void { weights_data[i] = ptr[i * stride + 3]; });
return weights;
}

inline Array1<float> WeightsOfFsaAsArray1(const Ragged<Arc> &fsa) {
return Array1<float>(WeightsOfArcsAsTensor(fsa.values));
return WeightsOfArcsAsArray1(fsa.values);
}

inline Array1<int32_t> LabelsOfArcsAsArray1(const Array1<Arc> &arcs) {
ContextPtr c = arcs.Context();
const Arc *arc_data = arcs.Data();
Array1<int32_t> labels(c, arcs.Dim());
const int32_t *ptr = reinterpret_cast<const int32_t *>(arc_data);
int32_t *labels_data = labels.Data();
int32_t stride = 4;
K2_EVAL(
c, arcs.Dim(), lambda_get_labels,
(int32_t i)->void { labels_data[i] = ptr[i * stride + 2]; });
return labels;
}

inline Array1<int32_t> LabelsOfFsaAsArray1(const Ragged<Arc> &fsa) {
return LabelsOfArcsAsArray1(fsa.values);
}

} // namespace k2
Expand Down
33 changes: 33 additions & 0 deletions k2/csrc/fsa_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <gtest/gtest.h>

#include <vector>
#include "k2/csrc/fsa.h"
#include "k2/csrc/fsa_utils.h"
#include "k2/csrc/test_utils.h"
Expand Down Expand Up @@ -120,4 +121,36 @@ TEST(FsaVecIO, FromAndToTensor) {
}
}

TEST(FsaArcs, GetLabelsAndWeights) {
// src_state dst_state label cost
std::string s1 = R"(0 1 1 1
0 2 2 2
1 3 -1 1
1 2 2 2
2 3 -1 3
3
)";

std::string s2 = R"(0 1 1 1.5
1 2 2 2.5
2 3 -1 3.5
3
)";
for (auto &context : {GetCpuContext(), GetCudaContext()}) {
Fsa fsa1 = FsaFromString(s1);
Fsa fsa2 = FsaFromString(s2);

Fsa *fsa_array[] = {&fsa1, &fsa2};
FsaVec fsa_vec = CreateFsaVec(2, &fsa_array[0]);
fsa_vec = fsa_vec.To(context);
auto labels = LabelsOfFsaAsArray1(fsa_vec);
std::vector<int32_t> expected_labels = {1, 2, -1, 2, -1, 1, 2, -1};
CheckArrayData(labels, expected_labels);

auto weights = WeightsOfFsaAsArray1(fsa_vec);
std::vector<float> expected_weights = {1, 2, 1, 2, 3, 1.5, 2.5, 3.5};
CheckArrayData(weights, expected_weights);
}
}

} // namespace k2
23 changes: 23 additions & 0 deletions k2/csrc/ragged.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#ifndef K2_CSRC_RAGGED_H_
#define K2_CSRC_RAGGED_H_

#include <algorithm>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -506,6 +507,28 @@ ToType(int64_t, Long)
// that Array1's that are the row_ids or row_splits of a Ragged object are
// not mutable so they can be re-used.
Ragged<T> Clone() const { return Ragged<T>(shape, values.Clone()); }

// Convert a ragged tensor with 2 axes into a vector of vector.
//
// CAUTION: this->NumAxes() must be 2.
std::vector<std::vector<T>> ToVecVec() const {
K2_CHECK_EQ(NumAxes(), 2);
if (Context()->GetDeviceType() == kCuda) {
return this->To(GetCpuContext()).ToVecVec();
}
int32_t dim0 = this->Dim0();
std::vector<std::vector<T>> ans(dim0);
const int32_t *row_splits_data = RowSplits(1).Data();
const T *values_data = values.Data();
for (int32_t i = 0; i != dim0; ++i) {
int32_t len = row_splits_data[i + 1] - row_splits_data[i];
ans[i].resize(len);
std::copy(values_data + row_splits_data[i],
values_data + row_splits_data[i + 1], ans[i].begin());
}
return ans;
}

};

// e.g. will produce something like "[ [ 3 4 ] [ 1 ] ]".
Expand Down

0 comments on commit 096f316

Please sign in to comment.