diff --git a/include/huffman.hpp b/include/huffman.hpp index 7233f6e6a..a943256d4 100644 --- a/include/huffman.hpp +++ b/include/huffman.hpp @@ -166,11 +166,11 @@ template class HuffmanCode { HuffmanCode(const std::pair &endpoints, const It begin, const It end); - //! Smallest and largest symbols (inclusive) to receive codewords. + //! Smallest and largest symbols (inclusive) eligible for codewords. std::pair endpoints; - //! Number of symbols that will be assigned codewords (including one for the - //! 'missed' symbol). + //! Number of symbols eligible for codewords (including one for the 'missed' + //! symbol). std::size_t ncodewords; //! Frequencies of the symbols in the input stream. @@ -179,10 +179,16 @@ template class HuffmanCode { //! Codewords associated to the symbols. std::vector codewords; + //! Report the number of symbols in the stream. + std::size_t nsymbols() const; + //! Report the number of out-of-range symbols encountered in the stream or //! given in the frequency table pairs. std::size_t nmissed() const; + //! Report the size in bits of the encoded stream. + std::size_t nbits_hit() const; + //! Check whether a symbol is eligible for a codeword. bool out_of_range(const Symbol symbol) const; diff --git a/include/huffman.tpp b/include/huffman.tpp index 64e506051..29239ac0f 100644 --- a/include/huffman.tpp +++ b/include/huffman.tpp @@ -149,10 +149,23 @@ HuffmanCode::HuffmanCode(const std::pair &endpoints, recursively_set_codewords(queue.top(), {}); } +template std::size_t HuffmanCode::nsymbols() const { + return std::accumulate(frequencies.begin(), frequencies.end(), + static_cast(0)); +} + template std::size_t HuffmanCode::nmissed() const { return frequencies.at(0); } +template std::size_t HuffmanCode::nbits_hit() const { + std::size_t nbits = 0; + for (std::size_t i = 0; i < ncodewords; ++i) { + nbits += frequencies.at(i) * codewords.at(i).length; + } + return nbits; +} + template bool HuffmanCode::out_of_range(const Symbol symbol) const { return symbol < endpoints.first or symbol > endpoints.second; @@ -403,13 +416,7 @@ MemoryBuffer huffman_encode(Symbol const *const begin, const std::size_t n) { const HuffmanCode code(begin, begin + n); - std::vector lengths; - for (const HuffmanCodeword &codeword : code.codewords) { - lengths.push_back(codeword.length); - } - const std::size_t nbits = - std::inner_product(code.frequencies.begin(), code.frequencies.end(), - lengths.begin(), static_cast(0)); + const std::size_t nbits = code.nbits_hit(); const std::size_t nbytes_hit = (nbits + CHAR_BIT - 1) / CHAR_BIT; pb::HuffmanHeader header; @@ -568,11 +575,8 @@ MemoryBuffer huffman_decode(const MemoryBuffer &buffer) { const HuffmanCode code(endpoints, chained_frequency_supertable.begin(), chained_frequency_supertable.end()); - // TODO: Maybe add a member function for this. - const std::size_t nout = - std::accumulate(code.frequencies.begin(), code.frequencies.end(), - static_cast(0)); - MemoryBuffer out(nout); + const std::size_t nsymbols = code.nsymbols(); + MemoryBuffer out(nsymbols); Symbol *q = out.data.get(); const Bits bits(window.current, window.current + nbits / CHAR_BIT, @@ -581,7 +585,7 @@ MemoryBuffer huffman_decode(const MemoryBuffer &buffer) { const typename HuffmanCode::Node root = code.queue.top(); assert(root); Bits::iterator b = bits.begin(); - for (std::size_t i = 0; i < nout; ++i) { + for (std::size_t i = 0; i < nsymbols; ++i) { typename HuffmanCode::Node node; for (node = root; node->left; node = *b++ ? node->right : node->left, ++nbits_read) diff --git a/src/huffman.cpp b/src/huffman.cpp index a9f9fbca1..2f771c396 100644 --- a/src/huffman.cpp +++ b/src/huffman.cpp @@ -248,13 +248,7 @@ HuffmanEncodedStream huffman_encoding(long int const *const quantized_data, const HuffmanCode code(nql_endpoints, quantized_data, quantized_data + n); - std::vector lengths; - for (const HuffmanCodeword &codeword : code.codewords) { - lengths.push_back(codeword.length); - } - const std::size_t nbits = - std::inner_product(code.frequencies.begin(), code.frequencies.end(), - lengths.begin(), static_cast(0)); + const std::size_t nbits = code.nbits_hit(); const std::size_t nnz = code.ncodewords - std::count(code.frequencies.begin(), code.frequencies.end(), 0);