Skip to content

Commit

Permalink
Merge pull request #190 from didi/fe_op
Browse files Browse the repository at this point in the history
Fix doc and bugs of FE.
  • Loading branch information
zh794390558 authored Dec 20, 2019
2 parents 2e1b064 + 8e4bfa5 commit b85dfee
Show file tree
Hide file tree
Showing 64 changed files with 900 additions and 292 deletions.
2 changes: 2 additions & 0 deletions core/ops/.gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
gen/
cppjieba
*.so
!data/sm1_cln.wav
*.scp
!noiselist.scp
Expand Down
4 changes: 3 additions & 1 deletion core/ops/kernels/analyfiltbank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ int Analyfiltbank::proc_afb(const float* mic_buf) {
xcomplex* win = static_cast<xcomplex*>(malloc(sizeof(xcomplex) * i_FFTSiz));
xcomplex* fftwin =
static_cast<xcomplex*>(malloc(sizeof(xcomplex) * i_FFTSiz));
float* fft_buf = static_cast<float*>(malloc(sizeof(float) * 2 * i_FFTSiz));

/* generate window */
gen_window(pf_WINDOW, i_WinLen, s_WinTyp);
Expand All @@ -96,7 +97,7 @@ int Analyfiltbank::proc_afb(const float* mic_buf) {
}

/* fft */
dit_r2_fft(win, fftwin, i_FFTSiz, -1);
dit_r2_fft(win, fftwin, fft_buf, i_FFTSiz, -1);

for (k = 0; k < i_NumFrq; k++) {
pf_PowSpc[n * i_NumFrq + k] = complex_abs2(fftwin[k]);
Expand All @@ -106,6 +107,7 @@ int Analyfiltbank::proc_afb(const float* mic_buf) {

free(win);
free(fftwin);
free(fft_buf);

return 1;
}
Expand Down
2 changes: 2 additions & 0 deletions core/ops/kernels/fbank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Fbank::Fbank()
upper_frequency_limit_(kDefaultUpperFrequencyLimit),
filterbank_channel_count_(kDefaultFilterbankChannelCount) {}

// Out-of-line destructor; there is no explicit cleanup to perform here.
Fbank::~Fbank() = default;

bool Fbank::Initialize(int input_length, double input_sample_rate) {
if (input_length < 1) {
LOG(ERROR) << "Input length must be positive.";
Expand Down
1 change: 1 addition & 0 deletions core/ops/kernels/fbank.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace delta {
class Fbank {
public:
Fbank();
~Fbank();
bool Initialize(int input_length, double input_sample_rate);
// Input is a single squared-magnitude spectrogram frame. The input spectrum
// is converted to linear magnitude and weighted into bands using a
Expand Down
7 changes: 7 additions & 0 deletions core/ops/kernels/fbank_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class FbankOp : public OpKernel {
sample_rate_tensor.shape().DebugString(), " instead."));
const int32 sample_rate = sample_rate_tensor.scalar<int32>()();

if (upper_frequency_limit_ <= 0)
upper_frequency_limit_ = sample_rate / 2.0 + upper_frequency_limit_;
else if (upper_frequency_limit_ > sample_rate / 2.0 || upper_frequency_limit_ <= lower_frequency_limit_)
upper_frequency_limit_ = sample_rate / 2.0;

// shape [channels, time, bins]
const int spectrogram_channels = spectrogram.dim_size(2);
const int spectrogram_samples = spectrogram.dim_size(1);
Expand Down Expand Up @@ -94,6 +99,8 @@ class FbankOp : public OpKernel {
for (int i = 0; i < filterbank_channel_count_; ++i) {
output_data[i] = fbank_output[i];
}
std::vector<double>().swap(fbank_input);
std::vector<double>().swap(fbank_output);
}
}
}
Expand Down
39 changes: 34 additions & 5 deletions core/ops/kernels/framepow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ const float frame_length_sec = 0.010;
// Builds a FramePow using the file-level default window and frame durations,
// with snip_edges and remove_dc_offset enabled and no energy buffer allocated.
FramePow::FramePow() {
  window_length_sec_ = window_length_sec;
  frame_length_sec_ = frame_length_sec;
  i_snip_edges = true;
  i_remove_dc_offset = true;

  // init_eng() is what actually computes these; give them defined values so
  // they are never read while indeterminate if another method runs first.
  f_SamRat = 0.0f;
  i_WinLen = 0;
  i_FrmLen = 0;
  i_NumFrm = 0;

  pf_FrmEng = NULL;
}

Expand All @@ -40,27 +42,54 @@ void FramePow::set_frame_length_sec(float frame_length_sec) {
frame_length_sec_ = frame_length_sec;
}

void FramePow::set_snip_edges(bool snip_edges) { i_snip_edges = snip_edges; }

void FramePow::set_remove_dc_offset(bool remove_dc_offset) {
i_remove_dc_offset = remove_dc_offset;
}

// Derives the window length, frame shift and frame count for an input of
// `input_size` samples at `sample_rate`, then allocates the per-frame energy
// buffer.
//
// Returns 1 on success. Returns 0 when the configured durations produce a
// non-positive window length or frame shift, or when the buffer cannot be
// allocated.
int FramePow::init_eng(int input_size, float sample_rate) {
  f_SamRat = sample_rate;
  i_WinLen = static_cast<int>(window_length_sec_ * f_SamRat);
  i_FrmLen = static_cast<int>(frame_length_sec_ * f_SamRat);

  // A zero frame shift would divide by zero below, and a non-positive window
  // is meaningless; report failure instead of computing with them.
  if (i_WinLen <= 0 || i_FrmLen <= 0) {
    i_NumFrm = 0;
    return 0;
  }

  if (i_snip_edges)
    i_NumFrm = (input_size - i_WinLen) / i_FrmLen + 1;
  else
    i_NumFrm = (input_size + i_FrmLen / 2) / i_FrmLen;

  // Inputs shorter than one window can yield a negative count, which would
  // wrap to a huge size_t in the allocation below.
  if (i_NumFrm < 0) i_NumFrm = 0;

  // NOTE(review): any buffer already held in pf_FrmEng is overwritten here
  // without being released -- confirm callers run init_eng once per instance.
  pf_FrmEng = static_cast<float*>(malloc(sizeof(float) * i_NumFrm));
  if (pf_FrmEng == NULL && i_NumFrm > 0) {
    // Allocation failed: make sure later loops over i_NumFrm touch nothing.
    i_NumFrm = 0;
    return 0;
  }

  return 1;
}

int FramePow::proc_eng(const float* mic_buf) {
int n, k;
int FramePow::proc_eng(const float* mic_buf, int input_size) {
int i, n, k;
float* win = static_cast<float*>(malloc(sizeof(float) * i_WinLen));

for (n = 0; n < i_NumFrm; n++) {
pf_FrmEng[n] = 0.0;
float sum = 0.0;
float energy = 0.0;
for (k = 0; k < i_WinLen; k++) {
win[k] = mic_buf[n * i_FrmLen + k];
pf_FrmEng[n] = pf_FrmEng[n] + win[k] * win[k];
int index = n * i_FrmLen + k;
if (index < input_size)
win[k] = mic_buf[index];
else
win[k] = 0.0f;
sum += win[k];
}

if (i_remove_dc_offset == true) {
float mean = sum / i_WinLen;
for (int l = 0; l < i_WinLen; l++) win[l] -= mean;
}

for (i = 0; i < i_WinLen; i++) {
energy += win[i] * win[i];
}

pf_FrmEng[n] = log(energy);

}

free(win);
Expand Down
8 changes: 7 additions & 1 deletion core/ops/kernels/framepow.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class FramePow {
private:
float window_length_sec_;
float frame_length_sec_;
bool i_snip_edges;
bool i_remove_dc_offset;

float f_SamRat;
int i_WinLen;
Expand All @@ -44,9 +46,13 @@ class FramePow {

void set_frame_length_sec(float frame_length_sec);

void set_snip_edges(bool snip_edges);

void set_remove_dc_offset(bool remove_dc_offset);

int init_eng(int input_size, float sample_rate);

int proc_eng(const float* mic_buf);
int proc_eng(const float* mic_buf, int input_size);

int get_eng(float* output);

Expand Down
11 changes: 10 additions & 1 deletion core/ops/kernels/framepow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class FramePowOp : public OpKernel {
// Reads the "window_length", "frame_length", "snip_edges" and
// "remove_dc_offset" attributes into the corresponding kernel members.
// Each GetAttr call is wrapped in OP_REQUIRES_OK.
explicit FramePowOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("window_length", &window_length_));
OP_REQUIRES_OK(context, context->GetAttr("frame_length", &frame_length_));
OP_REQUIRES_OK(context, context->GetAttr("snip_edges", &snip_edges_));
OP_REQUIRES_OK(context,
context->GetAttr("remove_dc_offset", &remove_dc_offset_));
}

void Compute(OpKernelContext* context) override {
Expand All @@ -49,6 +52,8 @@ class FramePowOp : public OpKernel {
FramePow cls_eng;
cls_eng.set_window_length_sec(window_length_);
cls_eng.set_frame_length_sec(frame_length_);
cls_eng.set_snip_edges(snip_edges_);
cls_eng.set_remove_dc_offset(remove_dc_offset_);
OP_REQUIRES(context, cls_eng.init_eng(L, sample_rate),
errors::InvalidArgument(
"framepow_class initialization failed for length ", L,
Expand All @@ -58,20 +63,24 @@ class FramePowOp : public OpKernel {
int i_WinLen = static_cast<int>(window_length_ * sample_rate);
int i_FrmLen = static_cast<int>(frame_length_ * sample_rate);
int i_NumFrm = (L - i_WinLen) / i_FrmLen + 1;
if (snip_edges_ == false) i_NumFrm = (L + i_FrmLen / 2) / i_FrmLen;
if (i_NumFrm < 1) i_NumFrm = 1;
OP_REQUIRES_OK(context, context->allocate_output(
0, TensorShape({1, i_NumFrm}), &output_tensor));

const float* input_flat = input_tensor.flat<float>().data();
float* output_flat = output_tensor->flat<float>().data();

int ret;
ret = cls_eng.proc_eng(input_flat);
ret = cls_eng.proc_eng(input_flat, L);
ret = cls_eng.get_eng(output_flat);
}

private:
float window_length_;
float frame_length_;
bool snip_edges_;
bool remove_dc_offset_;
};

REGISTER_KERNEL_BUILDER(Name("FramePow").Device(DEVICE_CPU), FramePowOp);
Expand Down
23 changes: 9 additions & 14 deletions core/ops/kernels/mfcc_dct_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ class MfccDctOp : public OpKernel {
OP_REQUIRES(context, fbank.dims() == 3,
errors::InvalidArgument("Fbank must be 3-dimensional",
fbank.shape().DebugString()));
const Tensor& spectrum = context->input(1);
OP_REQUIRES(context, spectrum.dims() == 3,
errors::InvalidArgument("Spectrum must be 3-dimensional",
spectrum.shape().DebugString()));
const Tensor& framepow = context->input(1);
OP_REQUIRES(context, framepow.dims() == 1,
errors::InvalidArgument("Framepow must be 1-dimensional",
framepow.shape().DebugString()));
const Tensor& sample_rate_tensor = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()),
errors::InvalidArgument(
Expand All @@ -56,8 +56,6 @@ class MfccDctOp : public OpKernel {
const int fbank_channels = fbank.dim_size(2);
const int fbank_samples = fbank.dim_size(1);
const int audio_channels = fbank.dim_size(0);
const int spectrum_samples = spectrum.dim_size(1);
const int spectrum_channels = spectrum.dim_size(2);

MfccDct mfcc;
mfcc.set_coefficient_count(coefficient_count_);
Expand All @@ -77,7 +75,7 @@ class MfccDctOp : public OpKernel {
&output_tensor));

const float* fbank_flat = fbank.flat<float>().data();
const float* spectrum_flat = spectrum.flat<float>().data();
const float* framepow_flat = framepow.flat<float>().data();
float* output_flat = output_tensor->flat<float>().data();

for (int audio_channel = 0; audio_channel < audio_channels;
Expand All @@ -86,13 +84,10 @@ class MfccDctOp : public OpKernel {
const float* sample_data =
fbank_flat + (audio_channel * fbank_samples * fbank_channels) +
(fbank_sample * fbank_channels);
const float* spectrum_data =
spectrum_flat + (audio_channel * fbank_samples * spectrum_channels) +
(fbank_sample * spectrum_channels);
const float* framepow_data = framepow_flat + fbank_sample;
std::vector<double> mfcc_input(sample_data,
sample_data + fbank_channels);
std::vector<double> spectrum_input(spectrum_data,
spectrum_data + spectrum_channels);
std::vector<double> framepow_input(framepow_data, framepow_data + 1);
std::vector<double> mfcc_output;
mfcc.Compute(mfcc_input, &mfcc_output);
DCHECK_EQ(coefficient_count_, mfcc_output.size());
Expand All @@ -103,10 +98,10 @@ class MfccDctOp : public OpKernel {
output_data[i] = mfcc_output[i];
}
if (use_energy_)
output_data[0] = spectrum_input[0];
output_data[0] = framepow_input[0];

std::vector<double>().swap(mfcc_input);
std::vector<double>().swap(spectrum_input);
std::vector<double>().swap(framepow_input);
std::vector<double>().swap(mfcc_output);
}
}
Expand Down
6 changes: 6 additions & 0 deletions core/ops/kernels/mfcc_mel_filterbank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ namespace tensorflow {

MfccMelFilterbank::MfccMelFilterbank() : initialized_(false) {}

// center_frequencies_, weights_ and band_mapper_ are std::vector members and
// release their own storage when the object is destroyed, so no manual
// swap-with-empty is needed.
MfccMelFilterbank::~MfccMelFilterbank() = default;

bool MfccMelFilterbank::Initialize(int input_length, double input_sample_rate,
int output_channel_count,
double lower_frequency_limit,
Expand Down
1 change: 1 addition & 0 deletions core/ops/kernels/mfcc_mel_filterbank.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace tensorflow {
class MfccMelFilterbank {
public:
MfccMelFilterbank();
~MfccMelFilterbank();
bool Initialize(int input_length, // Number of unique FFT bins fftsize/2+1.
double input_sample_rate, int output_channel_count,
double lower_frequency_limit, double upper_frequency_limit);
Expand Down
3 changes: 1 addition & 2 deletions core/ops/kernels/resample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ LinearResample::LinearResample(int samp_rate_in_hz,
assert(samp_rate_in_hz > 0.0 &&
samp_rate_out_hz > 0.0 &&
filter_cutoff_hz > 0.0 &&
filter_cutoff_hz*2 <= samp_rate_in_hz &&
filter_cutoff_hz*2 <= samp_rate_out_hz &&
num_zeros > 0);

Expand All @@ -56,7 +55,7 @@ int LinearResample::GetNumOutputSamples(int input_num_samp,

// work out the number of ticks in the time interval
// [ 0, input_num_samp/samp_rate_in_ ).
int interval_length_in_ticks = input_num_samp * ticks_per_input_period;
long long interval_length_in_ticks = (long long)input_num_samp * (long long)ticks_per_input_period;
if (!flush) {
BaseFloat window_width = num_zeros_ / (2.0 * filter_cutoff_);
int window_width_ticks = floor(window_width * tick_freq);
Expand Down
4 changes: 4 additions & 0 deletions core/ops/kernels/resample.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ limitations under the License.
#include <vector>
#include <assert.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
using namespace tensorflow; // NOLINT

using namespace std;
#include "kernels/support_functions.h"

Expand Down
Loading

0 comments on commit b85dfee

Please sign in to comment.