Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CTC Beam Search Decoder (KenLM Lexicon) #2072

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ endif()

project(torchaudio)

set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")

# check and set CMAKE_CXX_STANDARD
string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard)
Expand Down Expand Up @@ -59,6 +60,7 @@ endif()
option(BUILD_SOX "Build libsox statically" ON)
option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_RNNT "Enable RNN transducer" ON)
option(BUILD_CTC_DECODER "Build Flashlight decoder" OFF)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_CUDA "Enable CUDA support" OFF)
option(USE_ROCM "Enable ROCM support" OFF)
Expand Down
95 changes: 95 additions & 0 deletions cmake/Findkenlm.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Try to find the KenLM library
#
# The following variables are optionally searched for defaults
# KENLM_ROOT: Base directory where all KENLM components are found
#
# The following are set after configuration is done:
# KENLM_FOUND
# KENLM_LIBRARIES
# KENLM_INCLUDE_DIRS
# KENLM_INCLUDE_DIRS_LM
#

message(STATUS "Looking for KenLM")

# Required for KenLM to read ARPA files in compressed format
find_package(LibLZMA REQUIRED)
find_package(BZip2 REQUIRED)
find_package(ZLIB REQUIRED)

find_library(
KENLM_LIB
kenlm
HINTS
${KENLM_ROOT}/lib
${KENLM_ROOT}/build/lib
PATHS
$ENV{KENLM_ROOT}/lib
$ENV{KENLM_ROOT}/build/lib
)

find_library(
KENLM_UTIL_LIB
kenlm_util
HINTS
${KENLM_ROOT}/lib
${KENLM_ROOT}/build/lib
PATHS
$ENV{KENLM_ROOT}/lib
$ENV{KENLM_ROOT}/build/lib
)

if(KENLM_LIB)
message(STATUS "Using kenlm library found in ${KENLM_LIB}")
else()
message(STATUS "kenlm library not found; if you already have kenlm installed, please set CMAKE_LIBRARY_PATH, KENLM_LIB or KENLM_ROOT environment variable")
endif()

if(KENLM_UTIL_LIB)
message(STATUS "Using kenlm utils library found in ${KENLM_UTIL_LIB}")
else()
message(STATUS "kenlm utils library not found; if you already have kenlm installed, please set CMAKE_LIBRARY_PATH, KENLM_UTIL_LIB or KENLM_ROOT environment variable")
endif()

# find a model header, then get the entire include directory. We need to do this because
# cmake consistently confuses other things along this path
find_path(KENLM_MODEL_HEADER
model.hh
PATH_SUFFIXES
kenlm/lm
include/kenlm/lm
HINTS
${KENLM_ROOT}/lm
${KENLM_ROOT}/include/kenlm/lm
PATHS
$ENV{KENLM_ROOT}/lm
$ENV{KENLM_ROOT}/include/kenlm/lm
)

if(KENLM_MODEL_HEADER)
message(STATUS "kenlm model.hh found in ${KENLM_MODEL_HEADER}")

get_filename_component(KENLM_INCLUDE_LM ${KENLM_MODEL_HEADER} DIRECTORY)
get_filename_component(KENLM_INCLUDE_DIR ${KENLM_INCLUDE_LM} DIRECTORY)
else()
message(STATUS "kenlm model.hh not found; if you already have kenlm installed, please set CMAKE_INCLUDE_PATH, KENLM_MODEL_HEADER or KENLM_ROOT environment variable")
endif()

set(KENLM_LIBRARIES
${KENLM_LIB}
${KENLM_UTIL_LIB}
${LIBLZMA_LIBRARIES}
${BZIP2_LIBRARIES}
${ZLIB_LIBRARIES}
)
# Some KenLM include paths are relative to [include dir]/kenlm, not just [include dir] (bad)
set(KENLM_INCLUDE_DIRS_LM ${KENLM_INCLUDE_LM})
set(KENLM_INCLUDE_DIRS ${KENLM_INCLUDE_DIR})

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(kenlm DEFAULT_MSG KENLM_INCLUDE_DIRS KENLM_LIBRARIES)

if (kenlm_FOUND)
message(STATUS "Found kenlm (include: ${KENLM_INCLUDE_DIRS}, library: ${KENLM_LIBRARIES})")
mark_as_advanced(KENLM_ROOT KENLM_INCLUDE_DIRS KENLM_LIBRARIES)
endif()
1 change: 1 addition & 0 deletions examples/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio")

SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio")
SET(BUILD_RNNT ON CACHE BOOL "Build RNN transducer into libtorchaudio")
SET(BUILD_CTC_DECODER OFF CACHE BOOL "Build Flashlight decoder into libtorchaudio")
SET(BUILD_TORCHAUDIO_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding")

find_package(Torch REQUIRED)
Expand Down
1 change: 1 addition & 0 deletions examples/libtorchaudio/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cmake -GNinja \
-DBUILD_SOX=ON \
-DBUILD_KALDI=OFF \
-DBUILD_RNNT=ON \
-DBUILD_CTC_DECODER=OFF \
..
cmake --build .
```
Expand Down
16 changes: 16 additions & 0 deletions third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,20 @@ if (BUILD_KALDI)
list(APPEND TORCHAUDIO_THIRD_PARTIES kaldi)
endif()

################################################################################
# KenLM
################################################################################
if (BUILD_CTC_DECODER)
find_package(kenlm)
if (NOT kenlm_FOUND)
message(FATAL_ERROR "KenLM not found - Please install KenLM and set KENLM_ROOT.")
endif()

add_library(kenlm INTERFACE)
add_subdirectory(kenlm)
target_include_directories(kenlm INTERFACE ${KENLM_INCLUDE_DIRS})
target_link_libraries(kenlm INTERFACE ${KENLM_LIBRARIES} -lbz2 -lz)
list(APPEND TORCHAUDIO_THIRD_PARTIES kenlm)
endif()

set_property(GLOBAL PROPERTY TORCHAUDIO_THIRD_PARTIES "${TORCHAUDIO_THIRD_PARTIES}")
31 changes: 31 additions & 0 deletions third_party/kenlm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
include (ExternalProject)

set(INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../install)
set(ARCHIVE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/archives)

set(KENLM_MAX_ORDER 6 CACHE STRING "Maximum ngram order for KenLM")

ExternalProject_Add(kenlm_
PREFIX kenlm
DOWNLOAD_DIR ${ARCHIVE_DIR}
GIT_REPOSITORY https://github.com/kpu/kenlm.git
GIT_TAG 4a277534fd33da323205e6ec256e8fd0ff6ee6fa
BUILD_IN_SOURCE 1
BUILD_COMMAND ${CMAKE_COMMAND} --build .
CMAKE_CACHE_ARGS
-DBUILD_SHARED_LIBS:BOOL=${BUILD_SHARED_LIBS}
-DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}
-DCMAKE_INSTALL_PREFIX:PATH=${INSTALL_DIR}
BUILD_BYPRODUCTS ${KENLM_LIBRARIES}
DOWNLOAD_NO_PROGRESS ON
LOG_DOWNLOAD ON
LOG_UPDATE ON
LOG_CONFIGURE ON
LOG_BUILD ON
LOG_INSTALL ON
LOG_MERGED_STDOUTERR ON
LOG_OUTPUT_ON_FAILURE ON
)
add_dependencies(kenlm kenlm_)
set(KENLM_INCLUDE_DIRS ${KENLM_INCLUDE_DIRS} PARENT_SCOPE)
set(KENLM_LIBRARIES ${KENLM_LIBRARIES} PARENT_SCOPE)
11 changes: 10 additions & 1 deletion tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _get_build(var, default=False):
_BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX", True)
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True)
_BUILD_RNNT = _get_build("BUILD_RNNT", True)
_BUILD_CTC_DECODER = _get_build("BUILD_CTC_DECODER", False)
_USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None)
_USE_OPENMP = _get_build("USE_OPENMP", True) and \
Expand All @@ -45,11 +46,18 @@ def _get_build(var, default=False):


def get_ext_modules():
return [
modules = [
Extension(name='torchaudio.lib.libtorchaudio', sources=[]),
Extension(name='torchaudio._torchaudio', sources=[]),
]

if _BUILD_CTC_DECODER:
modules.extend([
Extension(name='torchaudio.lib.libtorchaudio_decoder', sources=[]),
Extension(name='torchaudio._torchaudio_decoder', sources=[]),
])
return modules


# Based off of
# https://github.com/pybind/cmake_example/blob/580c5fd29d4651db99d8874714b07c0c49a53f8a/setup.py
Expand Down Expand Up @@ -89,6 +97,7 @@ def build_extension(self, ext):
f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}",
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
f"-DBUILD_RNNT:BOOL={'ON' if _BUILD_RNNT else 'OFF'}",
f"-DBUILD_CTC_DECODER:BOOL={'ON' if _BUILD_CTC_DECODER else 'OFF'}",
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
Expand Down
131 changes: 85 additions & 46 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,43 @@ define_library(
"${LIBTORCHAUDIO_COMPILE_DEFINITIONS}"
)


################################################################################
# libtorchaudio_decoder.so
################################################################################
if (BUILD_CTC_DECODER)
set(
LIBTORCHAUDIO_DECODER_SOURCES
decoder/src/decoder/LexiconDecoder.cpp
decoder/src/decoder/Trie.cpp
decoder/src/decoder/Utils.cpp
decoder/src/decoder/lm/KenLM.cpp
decoder/src/dictionary/String.cpp
decoder/src/dictionary/System.cpp
decoder/src/dictionary/Dictionary.cpp
decoder/src/dictionary/Utils.cpp
)
set(
LIBTORCHAUDIO_DECODER_DEFINITIONS
BUILD_CTC_DECODER
KENLM_MAX_ORDER=${KENLM_MAX_ORDER}
)
set(
LIBTORCHAUDIO_DECODER_DEPS
libtorchaudio
${KENLM_LIBRARIES}
)

define_library(
libtorchaudio_decoder
"${LIBTORCHAUDIO_DECODER_SOURCES}"
"${PROJECT_SOURCE_DIR};$ENV{KENLM_ROOT}/include;$ENV{KENLM_ROOT}/include/kenlm"
"${LIBTORCHAUDIO_DECODER_DEPS}"
"${LIBTORCHAUDIO_COMPILE_DEFINITIONS};${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
)
endif()

# TODO: Add libtorchaudio_decoder
if (APPLE)
set(TORCHAUDIO_LIBRARY libtorchaudio CACHE INTERNAL "")
else()
Expand All @@ -136,6 +173,39 @@ endif()
# _torchaudio.so
################################################################################
if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
# See https://github.com/pytorch/pytorch/issues/38122
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
if (WIN32)
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
set(ADDITIONAL_ITEMS Python3::Python)
endif()
function(define_extension name sources libraries definitions)
add_library(${name} SHARED ${sources})
target_compile_definitions(${name} PRIVATE "${definitions}")
target_include_directories(
${name} PRIVATE ${PROJECT_SOURCE_DIR} ${Python_INCLUDE_DIR})
target_link_libraries(
${name}
${libraries}
${TORCH_PYTHON_LIBRARY}
${ADDITIONAL_ITEMS}
)
set_target_properties(${name} PROPERTIES PREFIX "")
if (MSVC)
set_target_properties(${name} PROPERTIES SUFFIX ".pyd")
endif(MSVC)
if (APPLE)
# https://github.com/facebookarchive/caffe2/issues/854#issuecomment-364538485
# https://github.com/pytorch/pytorch/commit/73f6715f4725a0723d8171d3131e09ac7abf0666
set_target_properties(${name} PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif()
install(
TARGETS ${name}
LIBRARY DESTINATION .
RUNTIME DESTINATION . # For Windows
)
endfunction()

set(
EXTENSION_SOURCES
pybind/pybind.cpp
Expand All @@ -150,53 +220,22 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
pybind/sox/utils.cpp
)
endif()
add_library(
_torchaudio
SHARED
${EXTENSION_SOURCES}
)

target_compile_definitions(
_torchaudio
PRIVATE ${LIBTORCHAUDIO_COMPILE_DEFINITIONS}
)

set_target_properties(_torchaudio PROPERTIES PREFIX "")
if (MSVC)
set_target_properties(_torchaudio PROPERTIES SUFFIX ".pyd")
endif(MSVC)

if (APPLE)
# https://github.com/facebookarchive/caffe2/issues/854#issuecomment-364538485
# https://github.com/pytorch/pytorch/commit/73f6715f4725a0723d8171d3131e09ac7abf0666
set_target_properties(_torchaudio PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif()

target_include_directories(
_torchaudio
PRIVATE
${PROJECT_SOURCE_DIR}
${Python_INCLUDE_DIR}
)

# See https://github.com/pytorch/pytorch/issues/38122
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")

if (WIN32)
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
set(ADDITIONAL_ITEMS Python3::Python)
endif()

target_link_libraries(
define_extension(
_torchaudio
"${EXTENSION_SOURCES}"
libtorchaudio
${TORCH_PYTHON_LIBRARY}
${ADDITIONAL_ITEMS}
)

install(
TARGETS _torchaudio
LIBRARY DESTINATION .
RUNTIME DESTINATION . # For Windows
"${LIBTORCHAUDIO_COMPILE_DEFINITIONS}"
)
if(BUILD_CTC_DECODER)
set(
DECODER_EXTENSION_SOURCES
decoder/bindings/pybind.cpp
)
define_extension(
_torchaudio_decoder
"${DECODER_EXTENSION_SOURCES}"
"libtorchaudio;libtorchaudio_decoder"
"${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
)
endif()
endif()
Loading