From 947bd287b94fba3a7dbe8cfb17028cf7c8986b76 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Sat, 7 Aug 2021 11:51:32 -0700 Subject: [PATCH] [microTVM] Project API infrastructure (#8380) * Initial commit of API server impl. * initial commit of api client * Add TVM-side glue code to use Project API * Change tvm.micro.Session to use Project API * Rework how crt_config.h is used on the host. * use template crt_config.h for host test runtime; delete src/runtime/crt/host/crt_config.h so that it doesn't diverge from the template * bring template crt_config.h inline with the one actually in use * rename to MAX_STRLEN_DLTYPE * Create a dedicated TVM-side host crt_config.h in src/runtime/micro * Modify Transport infrastructure to work with Project API * Add host microTVM API server * Zephyr implementation of microTVM API server * move all zephyr projects to apps/microtvm/zephyr/template_project * consolidate CcompilerAnnotator * Allow model library format with c backend, add test. * Update unit tests * fix incorrect doc * Delete old Zephyr build infrastructure * Delete old build abstractions * Delete old Transport implementations and simplify module * lint * ASF header * address gromero comments * final fixes? * fix is_shutdown * fix user-facing API * fix TempDirectory / operator * Update micro_tflite tutorial * lint * fix test_crt and test_link_params * undo global micro import, hopefully fix fixture * lint * fix more tests * Address tmoreau89 comments and mehrdadh comments * fix random number generator prj.conf for physical hw * uncomment proper aot option --- apps/bundle_deploy/crt_config/crt_config.h | 4 +- apps/microtvm/zephyr/aot_demo/CMakeLists.txt | 27 - apps/microtvm/zephyr/aot_demo/README.md | 20 - .../zephyr/aot_demo/boards/mps2_an521.conf | 28 - .../boards/nrf5340dk_nrf5340_cpuapp.conf | 34 - .../zephyr/aot_demo/boards/nucleo_l4r5zi.conf | 31 - .../aot_demo/boards/qemu_cortex_r5.conf | 25 - .../zephyr/aot_demo/boards/qemu_x86.conf | 28 - .../microtvm/zephyr/aot_demo/crt/crt_config.h | 62 -- apps/microtvm/zephyr/aot_demo/prj.conf | 32 - apps/microtvm/zephyr/aot_demo/qemu-hack | 1 - .../zephyr/host_driven/CMakeLists.txt | 26 - .../zephyr/host_driven/boards/mps2_an521.conf | 28 - .../boards/nrf5340dk_nrf5340_cpuapp.conf | 34 - .../host_driven/boards/nucleo_f746zg.conf | 33 - .../host_driven/boards/nucleo_l4r5zi.conf | 31 - .../host_driven/boards/qemu_cortex_r5.conf | 25 - .../host_driven/boards/qemu_riscv32.conf | 32 - .../host_driven/boards/qemu_riscv64.conf | 28 - .../zephyr/host_driven/boards/qemu_x86.conf | 25 - .../host_driven/boards/stm32f746g_disco.conf | 31 - apps/microtvm/zephyr/host_driven/prj.conf | 32 - apps/microtvm/zephyr/host_driven/qemu-hack | 1 - .../template_project/CMakeLists.txt.template | 49 ++ .../README.md | 0 .../crt_config}/crt_config.h | 2 +- .../template_project/microtvm_api_server.py | 716 ++++++++++++++++ .../qemu-hack/qemu-system-arm | 0 .../qemu-hack/qemu-system-i386 | 4 +- .../qemu-hack/qemu-system-riscv32 | 0 .../qemu-hack/qemu-system-riscv64 | 0 .../qemu-hack/qemu-system-xilinx-aarch64 | 0 .../src/aot_demo}/main.c | 3 +- .../src/aot_demo}/zephyr_uart.c | 0 .../src/aot_demo}/zephyr_uart.h | 0 .../src/host_driven}/main.c | 0 cmake/modules/StandaloneCrt.cmake | 7 +- include/tvm/runtime/crt/rpc_common/framing.h | 2 +- python/tvm/contrib/utils.py | 2 +- python/tvm/micro/__init__.py | 11 +- python/tvm/micro/artifact.py | 295 ------- python/tvm/micro/build.py | 210 ----- python/tvm/micro/compiler.py | 361 -------- python/tvm/micro/contrib/__init__.py | 16 - 
python/tvm/micro/contrib/base.py | 67 -- python/tvm/micro/contrib/zephyr.py | 789 ------------------ python/tvm/micro/interface_api.py | 8 +- python/tvm/micro/micro_binary.py | 65 -- python/tvm/micro/micro_library.py | 93 --- python/tvm/micro/model_library_format.py | 9 +- python/tvm/micro/project.py | 151 ++++ python/tvm/micro/project_api/client.py | 235 ++++++ python/tvm/micro/project_api/server.py | 776 +++++++++++++++++ python/tvm/micro/session.py | 24 +- .../micro/{transport/base.py => transport.py} | 50 +- python/tvm/micro/transport/__init__.py | 27 - python/tvm/micro/transport/debug.py | 64 -- python/tvm/micro/transport/file_descriptor.py | 119 --- python/tvm/micro/transport/serial.py | 135 --- python/tvm/micro/transport/subprocess.py | 67 -- python/tvm/micro/transport/wakeup.py | 79 -- python/tvm/relay/testing/byoc.py | 76 ++ src/runtime/crt/crt_config-template.h | 11 +- .../crt/graph_executor/graph_executor.c | 27 +- src/runtime/crt/host/Makefile | 76 ++ src/runtime/crt/host/microtvm_api_server.py | 200 +++++ .../crt/microtvm_rpc_common/framing.cc | 20 + src/runtime/{crt/host => micro}/crt_config.h | 10 +- src/runtime/micro/micro_session.cc | 15 +- tests/lint/check_file_type.py | 35 +- tests/micro/zephyr/conftest.py | 29 +- tests/micro/zephyr/test_zephyr.py | 256 +++--- tests/micro/zephyr/test_zephyr_aot.py | 241 +++--- tests/python/relay/aot/aot_test.mk | 26 +- tests/python/relay/aot/aot_test_utils.py | 88 +- tests/python/relay/aot/test_crt_aot.py | 77 +- .../python/relay/test_pass_partition_graph.py | 56 +- tests/python/unittest/test_crt.py | 69 +- tests/python/unittest/test_link_params.py | 37 +- tests/python/unittest/test_micro_artifact.py | 149 ---- .../test_micro_model_library_format.py | 65 ++ .../python/unittest/test_micro_project_api.py | 424 ++++++++++ tests/python/unittest/test_micro_transport.py | 12 +- tutorials/micro/micro_tflite.py | 121 ++- 84 files changed, 3369 insertions(+), 3805 deletions(-) delete mode 100644 apps/microtvm/zephyr/aot_demo/CMakeLists.txt delete mode 100644 apps/microtvm/zephyr/aot_demo/README.md delete mode 100644 apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf delete mode 100644 apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf delete mode 100644 apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf delete mode 100644 apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf delete mode 100644 apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf delete mode 100644 apps/microtvm/zephyr/aot_demo/crt/crt_config.h delete mode 100644 apps/microtvm/zephyr/aot_demo/prj.conf delete mode 120000 apps/microtvm/zephyr/aot_demo/qemu-hack delete mode 100644 apps/microtvm/zephyr/host_driven/CMakeLists.txt delete mode 100644 apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf delete mode 100644 apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf delete mode 100644 apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf delete mode 100644 apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf delete mode 100644 apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf delete mode 100644 apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf delete mode 100644 apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf delete mode 100644 apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf delete mode 100644 apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf delete mode 100644 apps/microtvm/zephyr/host_driven/prj.conf delete mode 120000 apps/microtvm/zephyr/host_driven/qemu-hack create mode 100644 
apps/microtvm/zephyr/template_project/CMakeLists.txt.template rename apps/microtvm/zephyr/{host_driven => template_project}/README.md (100%) rename apps/microtvm/zephyr/{host_driven/crt => template_project/crt_config}/crt_config.h (97%) create mode 100644 apps/microtvm/zephyr/template_project/microtvm_api_server.py rename apps/microtvm/zephyr/{ => template_project}/qemu-hack/qemu-system-arm (100%) rename apps/microtvm/zephyr/{ => template_project}/qemu-hack/qemu-system-i386 (91%) rename apps/microtvm/zephyr/{ => template_project}/qemu-hack/qemu-system-riscv32 (100%) rename apps/microtvm/zephyr/{ => template_project}/qemu-hack/qemu-system-riscv64 (100%) rename apps/microtvm/zephyr/{ => template_project}/qemu-hack/qemu-system-xilinx-aarch64 (100%) rename apps/microtvm/zephyr/{aot_demo/src => template_project/src/aot_demo}/main.c (97%) rename apps/microtvm/zephyr/{aot_demo/src => template_project/src/aot_demo}/zephyr_uart.c (100%) rename apps/microtvm/zephyr/{aot_demo/include => template_project/src/aot_demo}/zephyr_uart.h (100%) rename apps/microtvm/zephyr/{host_driven/src => template_project/src/host_driven}/main.c (100%) delete mode 100644 python/tvm/micro/artifact.py delete mode 100644 python/tvm/micro/compiler.py delete mode 100644 python/tvm/micro/contrib/__init__.py delete mode 100644 python/tvm/micro/contrib/base.py delete mode 100644 python/tvm/micro/contrib/zephyr.py delete mode 100644 python/tvm/micro/micro_binary.py delete mode 100644 python/tvm/micro/micro_library.py create mode 100644 python/tvm/micro/project.py create mode 100644 python/tvm/micro/project_api/client.py create mode 100644 python/tvm/micro/project_api/server.py rename python/tvm/micro/{transport/base.py => transport.py} (84%) delete mode 100644 python/tvm/micro/transport/__init__.py delete mode 100644 python/tvm/micro/transport/debug.py delete mode 100644 python/tvm/micro/transport/file_descriptor.py delete mode 100644 python/tvm/micro/transport/serial.py delete mode 100644 python/tvm/micro/transport/subprocess.py delete mode 100644 python/tvm/micro/transport/wakeup.py create mode 100644 python/tvm/relay/testing/byoc.py create mode 100644 src/runtime/crt/host/Makefile create mode 100644 src/runtime/crt/host/microtvm_api_server.py rename src/runtime/{crt/host => micro}/crt_config.h (90%) delete mode 100644 tests/python/unittest/test_micro_artifact.py create mode 100644 tests/python/unittest/test_micro_project_api.py diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h index 11086c0e9a15..58f923512d2e 100644 --- a/apps/bundle_deploy/crt_config/crt_config.h +++ b/apps/bundle_deploy/crt_config/crt_config.h @@ -35,9 +35,9 @@ /*! Maximum supported arguments in generated functions */ #define TVM_CRT_MAX_ARGS 10 /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ -#define TVM_CRT_STRLEN_DLTYPE 10 +#define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_STRLEN_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 /*! Maximum number of registered modules. 
*/ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/microtvm/zephyr/aot_demo/CMakeLists.txt b/apps/microtvm/zephyr/aot_demo/CMakeLists.txt deleted file mode 100644 index d7ec2a25db14..000000000000 --- a/apps/microtvm/zephyr/aot_demo/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -cmake_minimum_required(VERSION 3.13.1) - -set(ENV{QEMU_BIN_PATH} "${CMAKE_SOURCE_DIR}/qemu-hack") - -set(QEMU_PIPE "\${QEMU_PIPE}") # QEMU_PIPE is set by the calling TVM instance. - -find_package(Zephyr HINTS $ENV{ZEPHYR_BASE}) -project(microtvm_zephyr_runtime) - -set(CMAKE_VERBOSE_MAKEFILE ON) - -target_sources(app PRIVATE src/zephyr_uart.c) -target_sources(app PRIVATE src/main.c) - -foreach(tvm_lib ${TVM_LIBS}) - string(LENGTH ${tvm_lib} tvm_lib_length) - math(EXPR tvm_lib_cut "${tvm_lib_length} - 2") - string(SUBSTRING ${tvm_lib} 3 ${tvm_lib_cut} tvm_lib_name) - add_library(${tvm_lib_name} STATIC IMPORTED) - set_target_properties(${tvm_lib_name} PROPERTIES - IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/${tvm_lib}) - target_link_libraries(app PRIVATE ${tvm_lib_name}) -endforeach(tvm_lib ${TVM_LIBS}) - -target_include_directories(app PRIVATE ${TVM_INCLUDE_DIRS}) diff --git a/apps/microtvm/zephyr/aot_demo/README.md b/apps/microtvm/zephyr/aot_demo/README.md deleted file mode 100644 index a718da65e2fa..000000000000 --- a/apps/microtvm/zephyr/aot_demo/README.md +++ /dev/null @@ -1,20 +0,0 @@ - - - - - - - - - - - - - - - - - -This directory contains a Zephyr-based ahead of time (AOT) "demo" runtime environment that -pulls together the microTVM runtime dependencies into a single application -that can run TVM on a microTVM device without the need to a host. diff --git a/apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf b/apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf deleted file mode 100644 index 3916b17c49cf..000000000000 --- a/apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the MPS2-AN512 board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=n diff --git a/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf b/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf deleted file mode 100644 index 6c588c86b0d5..000000000000 --- a/apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the nRF5340 DK board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For AOT runtime which requires lots of function call. -CONFIG_MAIN_STACK_SIZE=2000 - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=y - -# For models with floating point. -CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf b/apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf deleted file mode 100644 index 52a6753c733b..000000000000 --- a/apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the STM32L4R5ZI Nucleo board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For AOT runtime which requires lots of function call. -CONFIG_MAIN_STACK_SIZE=3000 - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=y diff --git a/apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf b/apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf deleted file mode 100644 index 267589ba8f0c..000000000000 --- a/apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# This file is specific to the QEMU-emulated Cortex R5 microTVM board. - -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default stack size is 1k, this is required for debug mode. -CONFIG_MAIN_STACK_SIZE=2000 diff --git a/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf b/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf deleted file mode 100644 index 4b0e494068fa..000000000000 --- a/apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file is specific to the QEMU-emulated microTVM board. - -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default stack size is 1k, this is required for debug mode. -CONFIG_MAIN_STACK_SIZE=2000 - -# For models with floating point. -CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/aot_demo/crt/crt_config.h b/apps/microtvm/zephyr/aot_demo/crt/crt_config.h deleted file mode 100644 index 9ee315aa1763..000000000000 --- a/apps/microtvm/zephyr/aot_demo/crt/crt_config.h +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/runtime/crt_config.h.template - * \brief Template for CRT configuration, to be modified on each target. - */ -#ifndef TVM_RUNTIME_CRT_CONFIG_H_ -#define TVM_RUNTIME_CRT_CONFIG_H_ - -#include - -/*! Log level of the CRT runtime */ -#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG - -/*! Maximum supported dimension in NDArray */ -#define TVM_CRT_MAX_NDIM 6 - -/*! Maximum supported arguments in generated functions */ -#define TVM_CRT_MAX_ARGS 10 - -/*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 - -/*! Maximum number of registered modules. */ -#define TVM_CRT_MAX_REGISTERED_MODULES 2 - -/*! 
Maximum packet size, in bytes, including the length header. */ -#define TVM_CRT_MAX_PACKET_SIZE_BYTES (1 * 1024) - -/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ -#define TVM_CRT_MAX_STRLEN_DLTYPE 10 - -/*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 - -/*! \brief Maximum length of a PackedFunc function name. */ -#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 - -/*! \brief Log2 of the page size (bytes) for a virtual memory page. */ -#define TVM_CRT_PAGE_BITS 10 // 1 kB - -/*! \brief Number of pages on device. */ -#define TVM_CRT_MAX_PAGES 300 - -#endif // TVM_RUNTIME_CRT_CONFIG_H_ diff --git a/apps/microtvm/zephyr/aot_demo/prj.conf b/apps/microtvm/zephyr/aot_demo/prj.conf deleted file mode 100644 index c6ab10e9d86e..000000000000 --- a/apps/microtvm/zephyr/aot_demo/prj.conf +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# The settings in this file are generic for all boards, and are merged -# with the settings in the file boards/.conf by the Zephyr build -# process. - -# For UART implementation in main(). -CONFIG_RING_BUFFER=y -CONFIG_UART_CONSOLE=n -CONFIG_UART_INTERRUPT_DRIVEN=y - -# For RPC server C++ bindings. -CONFIG_CPLUSPLUS=y -CONFIG_NEWLIB_LIBC=y - -# For TVMPlatformAbort(). -CONFIG_REBOOT=y diff --git a/apps/microtvm/zephyr/aot_demo/qemu-hack b/apps/microtvm/zephyr/aot_demo/qemu-hack deleted file mode 120000 index b4810f2aab6e..000000000000 --- a/apps/microtvm/zephyr/aot_demo/qemu-hack +++ /dev/null @@ -1 +0,0 @@ -../qemu-hack \ No newline at end of file diff --git a/apps/microtvm/zephyr/host_driven/CMakeLists.txt b/apps/microtvm/zephyr/host_driven/CMakeLists.txt deleted file mode 100644 index f04a792086cb..000000000000 --- a/apps/microtvm/zephyr/host_driven/CMakeLists.txt +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -cmake_minimum_required(VERSION 3.13.1) - -set(ENV{QEMU_BIN_PATH} "${CMAKE_SOURCE_DIR}/qemu-hack") - -set(QEMU_PIPE "\${QEMU_PIPE}") # QEMU_PIPE is set by the calling TVM instance. 
- -find_package(Zephyr HINTS $ENV{ZEPHYR_BASE}) -project(microtvm_zephyr_runtime) - -set(CMAKE_VERBOSE_MAKEFILE ON) - -target_sources(app PRIVATE src/main.c) - -foreach(tvm_lib ${TVM_LIBS}) - string(LENGTH ${tvm_lib} tvm_lib_length) - math(EXPR tvm_lib_cut "${tvm_lib_length} - 2") - string(SUBSTRING ${tvm_lib} 3 ${tvm_lib_cut} tvm_lib_name) - add_library(${tvm_lib_name} STATIC IMPORTED) - set_target_properties(${tvm_lib_name} PROPERTIES - IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/${tvm_lib}) - target_link_libraries(app PRIVATE ${tvm_lib_name}) -endforeach(tvm_lib ${TVM_LIBS}) - -target_include_directories(app PRIVATE ${TVM_INCLUDE_DIRS}) diff --git a/apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf b/apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf deleted file mode 100644 index 3916b17c49cf..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the MPS2-AN512 board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=n diff --git a/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf b/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf deleted file mode 100644 index 511ff0121d32..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the nRF5340 DK board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# Required for Cortex-M33 devices. -CONFIG_MAIN_STACK_SIZE=1536 - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=y - -# For models with floating point. 
-CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf b/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf deleted file mode 100644 index 33b08032c32e..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the STM32F746 Nucleo board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For operations that stack allocates a large float array. -CONFIG_MAIN_STACK_SIZE=1536 - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y - -# For debugging. -CONFIG_LED=y - -# For models with floating point. -CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf b/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf deleted file mode 100644 index b87206019026..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the STM32L4R5ZI Nucleo board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For operations that stack allocates a large float array. -CONFIG_MAIN_STACK_SIZE=1536 - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=y diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf deleted file mode 100644 index 4097f7ec5487..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file is specific to the QEMU-emulated Cortex R5 microTVM board. - -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default stack size is 1k, this is required for debug mode. -CONFIG_MAIN_STACK_SIZE=1536 diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf deleted file mode 100644 index b94d96b11fba..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file is specific to the QEMU-emulated RISCV32 microTVM board. - -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default is 512, raised here for operations with large floating point data. -CONFIG_MAIN_STACK_SIZE=2048 - -# For models with floating point. -CONFIG_FPU=y - -# For floating point operations. It has exception on floating point operations -# without this flag. -CONFIG_FPU_SHARING=y diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf deleted file mode 100644 index 1da5f054da46..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file is specific to the QEMU-emulated RISCV64 microTVM board. - -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default 512, for operations with large floating point data. -CONFIG_MAIN_STACK_SIZE=2048 - -# For models with floating point. -CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf b/apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf deleted file mode 100644 index f314f59a597a..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# This file is specific to the QEMU-emulated microTVM board. - -# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random. -CONFIG_TEST_RANDOM_GENERATOR=y -CONFIG_TIMER_RANDOM_GENERATOR=y - -# Default stack size is 1k, this is required for debug mode. -CONFIG_MAIN_STACK_SIZE=1536 diff --git a/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf b/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf deleted file mode 100644 index 542faf28cd67..000000000000 --- a/apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# This file is specific to the STM32F746G Discovery board. - -# For intrinsics used by generated optimized operators. -CONFIG_CMSIS_DSP=y - -# For random number generation. -CONFIG_ENTROPY_GENERATOR=y -CONFIG_TEST_RANDOM_GENERATOR=y - -# For debugging. -CONFIG_LED=n - -# For models with floating point. 
-CONFIG_FPU=y diff --git a/apps/microtvm/zephyr/host_driven/prj.conf b/apps/microtvm/zephyr/host_driven/prj.conf deleted file mode 100644 index c6ab10e9d86e..000000000000 --- a/apps/microtvm/zephyr/host_driven/prj.conf +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# The settings in this file are generic for all boards, and are merged -# with the settings in the file boards/.conf by the Zephyr build -# process. - -# For UART implementation in main(). -CONFIG_RING_BUFFER=y -CONFIG_UART_CONSOLE=n -CONFIG_UART_INTERRUPT_DRIVEN=y - -# For RPC server C++ bindings. -CONFIG_CPLUSPLUS=y -CONFIG_NEWLIB_LIBC=y - -# For TVMPlatformAbort(). -CONFIG_REBOOT=y diff --git a/apps/microtvm/zephyr/host_driven/qemu-hack b/apps/microtvm/zephyr/host_driven/qemu-hack deleted file mode 120000 index b4810f2aab6e..000000000000 --- a/apps/microtvm/zephyr/host_driven/qemu-hack +++ /dev/null @@ -1 +0,0 @@ -../qemu-hack \ No newline at end of file diff --git a/apps/microtvm/zephyr/template_project/CMakeLists.txt.template b/apps/microtvm/zephyr/template_project/CMakeLists.txt.template new file mode 100644 index 000000000000..17e9d75c76e8 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/CMakeLists.txt.template @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.13.1) + +set(ENV{QEMU_BIN_PATH} "${CMAKE_SOURCE_DIR}/qemu-hack") + +set(QEMU_PIPE "\${QEMU_PIPE}") # QEMU_PIPE is set by the calling TVM instance. 
+
+find_package(Zephyr HINTS $ENV{ZEPHYR_BASE})
+project(microtvm_autogenerated_project)
+
+set(CRT_LIBS <API_SERVER_CRT_LIBS>)
+set(CRT_LIB_BASE crt/src/runtime/crt)
+foreach(crt_lib_name ${CRT_LIBS})
+  zephyr_library_named(${crt_lib_name})
+  file(GLOB_RECURSE crt_lib_srcs ${CRT_LIB_BASE}/${crt_lib_name}/*.c ${CRT_LIB_BASE}/${crt_lib_name}/*.cc)
+  target_sources(${crt_lib_name} PRIVATE ${crt_lib_srcs})
+  zephyr_library_include_directories(${crt_lib_name} PRIVATE crt_config crt/include)
+  target_link_libraries(app PRIVATE ${crt_lib_name})
+endforeach(crt_lib_name ${CRT_LIBS})
+
+# define a library for the model sources.
+zephyr_library_named(tvm_model)
+file(GLOB_RECURSE tvm_model_srcs model/codegen/host/src/*.c model/codegen/host/lib/*.o)
+target_sources(tvm_model PRIVATE ${tvm_model_srcs})
+target_include_directories(tvm_model PRIVATE ${CMAKE_SOURCE_DIR}/include crt_config crt/include)
+target_compile_options(tvm_model PRIVATE -Wno-unused-variable)  # TVM-generated code tends to include lots of these.
+target_link_libraries(app PRIVATE tvm_model)
+
+file(GLOB_RECURSE app_srcs src/**.c)
+target_sources(app PRIVATE ${app_srcs})
+target_include_directories(app PRIVATE crt_config ${CMAKE_SOURCE_DIR}/include crt/include)
diff --git a/apps/microtvm/zephyr/host_driven/README.md b/apps/microtvm/zephyr/template_project/README.md
similarity index 100%
rename from apps/microtvm/zephyr/host_driven/README.md
rename to apps/microtvm/zephyr/template_project/README.md
diff --git a/apps/microtvm/zephyr/host_driven/crt/crt_config.h b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h
similarity index 97%
rename from apps/microtvm/zephyr/host_driven/crt/crt_config.h
rename to apps/microtvm/zephyr/template_project/crt_config/crt_config.h
index 658b97e267ba..f8fc7514a28d 100644
--- a/apps/microtvm/zephyr/host_driven/crt/crt_config.h
+++ b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h
@@ -42,7 +42,7 @@
 #define TVM_CRT_MAX_REGISTERED_MODULES 2
 
 /*! Maximum packet size, in bytes, including the length header. */
-#define TVM_CRT_MAX_PACKET_SIZE_BYTES (4 * 1024)
+#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8192
 
 /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */
 #define TVM_CRT_MAX_STRLEN_DLTYPE 10
diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py
new file mode 100644
index 000000000000..8ab6381e73c5
--- /dev/null
+++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py
@@ -0,0 +1,716 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
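+"""microTVM Project API server for the Zephyr template project.
+
+Implements the Project API server_info_query, generate_project, build, flash, and
+transport methods for physical Zephyr boards and for boards emulated under QEMU.
+"""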
+
+import atexit
+import collections
+import collections.abc
+import copy
+import enum
+import fcntl
+import logging
+import os
+import os.path
+import pathlib
+import queue
+import re
+import select
+import shlex
+import shutil
+import subprocess
+import sys
+import tarfile
+import tempfile
+import threading
+import time
+
+import serial
+import serial.tools.list_ports
+import yaml
+
+from tvm.micro.project_api import server
+
+
+_LOG = logging.getLogger(__name__)
+
+
+API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.getcwd())
+
+
+BUILD_DIR = API_SERVER_DIR / "build"
+
+
+MODEL_LIBRARY_FORMAT_RELPATH = "model.tar"
+
+
+IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists()
+
+
+def check_call(cmd_args, *args, **kwargs):
+    cwd_str = "" if "cwd" not in kwargs else f" (in cwd: {kwargs['cwd']})"
+    _LOG.info("run%s: %s", cwd_str, " ".join(shlex.quote(a) for a in cmd_args))
+    return subprocess.check_call(cmd_args, *args, **kwargs)
+
+
+CACHE_ENTRY_RE = re.compile(r"(?P<name>[^:]+):(?P<type>[^=]+)=(?P<value>.*)")
+
+
+CMAKE_BOOL_MAP = dict(
+    [(k, True) for k in ("1", "ON", "YES", "TRUE", "Y")]
+    + [(k, False) for k in ("0", "OFF", "NO", "FALSE", "N", "IGNORE", "NOTFOUND", "")]
+)
+
+
+class CMakeCache(collections.abc.Mapping):
+    def __init__(self, path):
+        self._path = path
+        self._dict = None
+
+    def __iter__(self):
+        return iter(self._dict)
+
+    def __getitem__(self, key):
+        if self._dict is None:
+            self._dict = self._read_cmake_cache()
+
+        return self._dict[key]
+
+    def __len__(self):
+        return len(self._dict)
+
+    def _read_cmake_cache(self):
+        """Read a CMakeCache.txt-like file and return a dictionary of values."""
+        entries = collections.OrderedDict()
+        with open(self._path, encoding="utf-8") as f:
+            for line in f:
+                m = CACHE_ENTRY_RE.match(line.rstrip("\n"))
+                if not m:
+                    continue
+
+                if m.group("type") == "BOOL":
+                    value = CMAKE_BOOL_MAP[m.group("value").upper()]
+                else:
+                    value = m.group("value")
+
+                entries[m.group("name")] = value
+
+        return entries
+
+
+CMAKE_CACHE = CMakeCache(BUILD_DIR / "CMakeCache.txt")
+
+
+class BoardError(Exception):
+    """Raised when an attached board cannot be opened (i.e. missing /dev nodes, etc)."""
+
+
+class BoardAutodetectFailed(Exception):
+    """Raised when no attached hardware is found matching the board= given to ZephyrCompiler."""
+
+
+class FlashRunnerNotSupported(Exception):
+    """Raised when the CMake cache names a flash runner this server cannot drive."""
+
+
+def _get_flash_runner():
+    flash_runner = CMAKE_CACHE.get("ZEPHYR_BOARD_FLASH_RUNNER")
+    if flash_runner is not None:
+        return flash_runner
+
+    with open(CMAKE_CACHE["ZEPHYR_RUNNERS_YAML"]) as f:
+        doc = yaml.load(f, Loader=yaml.FullLoader)
+    return doc["flash-runner"]
+
+
+def _get_device_args(options):
+    flash_runner = _get_flash_runner()
+
+    if flash_runner == "nrfjprog":
+        return _get_nrf_device_args(options)
+
+    if flash_runner == "openocd":
+        return _get_openocd_device_args(options)
+
+    raise BoardError(
+        f"Don't know how to find serial terminal for board {CMAKE_CACHE['BOARD']} with flash "
+        f"runner {flash_runner}"
+    )
+
+
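+# Illustration: a CMakeCache.txt entry has the form NAME:TYPE=VALUE, so a line such as
+#   ZEPHYR_BOARD_FLASH_RUNNER:STRING=nrfjprog
+# parses under CACHE_ENTRY_RE to name="ZEPHYR_BOARD_FLASH_RUNNER", type="STRING",
+# value="nrfjprog"; BOOL-typed values ("ON", "OFF", ...) are normalized to Python
+# booleans through CMAKE_BOOL_MAP. _get_flash_runner() reads this cache to choose
+# between the nrfjprog and openocd flash runners.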
+# kwargs passed to usb.core.find to find attached boards for the openocd flash runner.
+BOARD_USB_FIND_KW = {
+    "nucleo_l4r5zi": {"idVendor": 0x0483, "idProduct": 0x374B},
+    "nucleo_f746zg": {"idVendor": 0x0483, "idProduct": 0x374B},
+    "stm32f746g_disco": {"idVendor": 0x0483, "idProduct": 0x374B},
+}
+
+
+def openocd_serial(options):
+    """Find the serial port to use for a board with OpenOCD flash strategy."""
+    if "openocd_serial" in options:
+        return options["openocd_serial"]
+
+    import usb  # pylint: disable=import-outside-toplevel
+
+    find_kw = BOARD_USB_FIND_KW[CMAKE_CACHE["BOARD"]]
+    boards = usb.core.find(find_all=True, **find_kw)
+    serials = []
+    for b in boards:
+        serials.append(b.serial_number)
+
+    if len(serials) == 0:
+        raise BoardAutodetectFailed(f"No attached USB devices matching: {find_kw!r}")
+    serials.sort()
+
+    autodetected_openocd_serial = serials[0]
+    _LOG.debug("zephyr openocd driver: autodetected serial %s", serials[0])
+
+    return autodetected_openocd_serial
+
+
+def _get_openocd_device_args(options):
+    return ["--serial", openocd_serial(options)]
+
+
+def _get_nrf_device_args(options):
+    nrfjprog_args = ["nrfjprog", "--ids"]
+    nrfjprog_ids = subprocess.check_output(nrfjprog_args, encoding="utf-8")
+    if not nrfjprog_ids.strip("\n"):
+        raise BoardAutodetectFailed(f'No attached boards recognized by {" ".join(nrfjprog_args)}')
+
+    boards = nrfjprog_ids.split("\n")[:-1]
+    if len(boards) > 1:
+        if options["nrfjprog_snr"] is None:
+            raise BoardError(
+                "Multiple boards connected; specify one with nrfjprog_snr=<snr>: "
+                f'{", ".join(boards)}'
+            )
+
+        if str(options["nrfjprog_snr"]) not in boards:
+            raise BoardError(
+                f"nrfjprog_snr ({options['nrfjprog_snr']}) not found in {nrfjprog_args}: {boards}"
+            )
+
+        return ["--snr", options["nrfjprog_snr"]]
+
+    if not boards:
+        return []
+
+    return ["--snr", boards[0]]
+
+
+PROJECT_TYPES = []
+if IS_TEMPLATE:
+    for d in (API_SERVER_DIR / "src").iterdir():
+        if d.is_dir():
+            PROJECT_TYPES.append(d.name)
+
+
+PROJECT_OPTIONS = [
+    server.ProjectOption(
+        "extra_files_tar",
+        help="If given, during generate_project, uncompress the tarball at this path into the project dir",
+    ),
+    server.ProjectOption(
+        "gdbserver_port", help=("If given, port number to use when running the local gdbserver")
+    ),
+    server.ProjectOption(
+        "nrfjprog_snr",
+        help=("When used with nRF targets, serial # of the attached board to use, from nrfjprog"),
+    ),
+    server.ProjectOption(
+        "openocd_serial",
+        help=("When used with OpenOCD targets, serial # of the attached board to use"),
+    ),
+    server.ProjectOption(
+        "project_type",
+        help="Type of project to generate.",
+        choices=tuple(PROJECT_TYPES),
+    ),
+    server.ProjectOption("verbose", help="Run build with verbose output"),
+    server.ProjectOption(
+        "west_cmd",
+        help=(
+            "Path to the west tool. If given, supersedes both the zephyr_base "
+            "option and ZEPHYR_BASE environment variable."
+        ),
+    ),
+    server.ProjectOption("zephyr_base", help="Path to the zephyr base directory."),
+    server.ProjectOption("zephyr_board", help="Name of the Zephyr board to build for"),
+]
+
+
+class Handler(server.ProjectAPIHandler):
+    def __init__(self):
+        super(Handler, self).__init__()
+        self._proc = None
+
+    def server_info_query(self, tvm_version):
+        return server.ServerInfo(
+            platform_name="zephyr",
+            is_template=IS_TEMPLATE,
+            model_library_format_path=""
+            if IS_TEMPLATE
+            else (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH),
+            project_options=PROJECT_OPTIONS,
+        )
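+    # Illustration: a typical options dict passed to the methods below, with keys drawn
+    # from PROJECT_OPTIONS above (values depend on the attached hardware):
+    #   {"project_type": "host_driven", "zephyr_board": "qemu_x86",
+    #    "west_cmd": "west", "verbose": True}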
+
+    # These files and directories will be recursively copied into generated projects from the CRT.
+    CRT_COPY_ITEMS = ("include", "Makefile", "src")
+
+    # Maps extra line added to prj.conf to a tuple or list of zephyr_board for which it is needed.
+    EXTRA_PRJ_CONF_DIRECTIVES = {
+        "CONFIG_TIMER_RANDOM_GENERATOR=y": (
+            "qemu_x86",
+            "qemu_riscv32",
+            "qemu_cortex_r5",
+            "qemu_riscv64",
+        ),
+        "CONFIG_ENTROPY_GENERATOR=y": (
+            "mps2_an521",
+            "nrf5340dk_nrf5340_cpuapp",
+            "nucleo_f746zg",
+            "nucleo_l4r5zi",
+            "stm32f746g_disco",
+        ),
+    }
+
+    def _create_prj_conf(self, project_dir, options):
+        with open(project_dir / "prj.conf", "w") as f:
+            f.write(
+                "# For UART used from main().\n"
+                "CONFIG_RING_BUFFER=y\n"
+                "CONFIG_UART_CONSOLE=n\n"
+                "CONFIG_UART_INTERRUPT_DRIVEN=y\n"
+                "\n"
+            )
+            f.write("# For TVMPlatformAbort().\n" "CONFIG_REBOOT=y\n" "\n")
+
+            if options["project_type"] == "host_driven":
+                f.write("# For RPC server C++ bindings.\n" "CONFIG_CPLUSPLUS=y\n" "\n")
+
+            f.write("# For math routines\n" "CONFIG_NEWLIB_LIBC=y\n" "\n")
+
+            if self._has_fpu(options["zephyr_board"]):
+                f.write("# For models with floating point.\n" "CONFIG_FPU=y\n" "\n")
+
+            main_stack_size = None
+            if self._is_qemu(options) and options["project_type"] == "host_driven":
+                main_stack_size = 1536
+
+            # Set main stack size, if needed.
+            if main_stack_size is not None:
+                f.write(f"CONFIG_MAIN_STACK_SIZE={main_stack_size}\n")
+
+            f.write("# For random number generation.\n" "CONFIG_TEST_RANDOM_GENERATOR=y\n")
+
+            f.write("\n# Extra prj.conf directives\n")
+            for line, board_list in self.EXTRA_PRJ_CONF_DIRECTIVES.items():
+                if options["zephyr_board"] in board_list:
+                    f.write(f"{line}\n")
+
+            f.write("\n")
+
+    API_SERVER_CRT_LIBS_TOKEN = "<API_SERVER_CRT_LIBS>"
+
+    CRT_LIBS_BY_PROJECT_TYPE = {
+        "host_driven": "microtvm_rpc_server microtvm_rpc_common common",
+        "aot_demo": "aot_executor memory microtvm_rpc_common common",
+    }
+
+    def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
+        project_dir = pathlib.Path(project_dir)
+        # Make project directory.
+        project_dir.mkdir()
+
+        # Copy ourselves to the generated project. TVM may perform further build steps on the
+        # generated project by launching the copy.
+        shutil.copy2(__file__, project_dir / os.path.basename(__file__))
+
+        # Place Model Library Format tarball in the special location, which this script uses to decide
+        # whether it's being invoked in a template or generated project.
+        project_model_library_format_tar_path = project_dir / MODEL_LIBRARY_FORMAT_RELPATH
+        shutil.copy2(model_library_format_path, project_model_library_format_tar_path)
+
+        # Extract Model Library Format tarball into <project_dir>/model.
+        extract_path = os.path.splitext(project_model_library_format_tar_path)[0]
+        with tarfile.TarFile(project_model_library_format_tar_path) as tf:
+            os.makedirs(extract_path)
+            tf.extractall(path=extract_path)
+
+        if self._is_qemu(options):
+            shutil.copytree(API_SERVER_DIR / "qemu-hack", project_dir / "qemu-hack")
+
+        # Populate CRT.
+        crt_path = project_dir / "crt"
+        crt_path.mkdir()
+        for item in self.CRT_COPY_ITEMS:
+            src_path = os.path.join(standalone_crt_dir, item)
+            dst_path = crt_path / item
+            if os.path.isdir(src_path):
+                shutil.copytree(src_path, dst_path)
+            else:
+                shutil.copy2(src_path, dst_path)
+
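+        # Illustration: with project_type="host_driven", the token substitution below
+        # rewrites the template line
+        #     set(CRT_LIBS <API_SERVER_CRT_LIBS>)
+        # into
+        #     set(CRT_LIBS microtvm_rpc_server microtvm_rpc_common common)
+        # per CRT_LIBS_BY_PROJECT_TYPE.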
+        # Populate CMakeLists.txt.
+        with open(API_SERVER_DIR / "CMakeLists.txt.template", "r") as cmake_template_f:
+            with open(project_dir / "CMakeLists.txt", "w") as cmake_f:
+                for line in cmake_template_f:
+                    if self.API_SERVER_CRT_LIBS_TOKEN in line:
+                        crt_libs = self.CRT_LIBS_BY_PROJECT_TYPE[options["project_type"]]
+                        line = line.replace(self.API_SERVER_CRT_LIBS_TOKEN, crt_libs)
+
+                    cmake_f.write(line)
+
+        self._create_prj_conf(project_dir, options)
+
+        # Populate crt_config.h.
+        crt_config_dir = project_dir / "crt_config"
+        crt_config_dir.mkdir()
+        shutil.copy2(
+            API_SERVER_DIR / "crt_config" / "crt_config.h", crt_config_dir / "crt_config.h"
+        )
+
+        # Populate src/.
+        src_dir = project_dir / "src"
+        shutil.copytree(API_SERVER_DIR / "src" / options["project_type"], src_dir)
+
+        # Populate extra_files.
+        if options.get("extra_files_tar"):
+            with tarfile.open(options["extra_files_tar"], mode="r:*") as tf:
+                tf.extractall(project_dir)
+
+    def build(self, options):
+        BUILD_DIR.mkdir()
+
+        cmake_args = ["cmake", ".."]
+        if options.get("verbose"):
+            cmake_args.append("-DCMAKE_VERBOSE_MAKEFILE:BOOL=TRUE")
+
+        if options.get("zephyr_base"):
+            cmake_args.append(f"-DZEPHYR_BASE:STRING={options['zephyr_base']}")
+
+        cmake_args.append(f"-DBOARD:STRING={options['zephyr_board']}")
+
+        check_call(cmake_args, cwd=BUILD_DIR)
+
+        args = ["make", "-j2"]
+        if options.get("verbose"):
+            args.append("VERBOSE=1")
+        check_call(args, cwd=BUILD_DIR)
+
+    # A list of all zephyr_board values which are known to launch using QEMU. Many platforms which
+    # launch through QEMU by default include "qemu" in their name. However, not all do. This list
+    # includes those tested platforms which do not include qemu.
+    _KNOWN_QEMU_ZEPHYR_BOARDS = ("mps2_an521",)
+
+    @classmethod
+    def _is_qemu(cls, options):
+        return (
+            "qemu" in options["zephyr_board"]
+            or options["zephyr_board"] in cls._KNOWN_QEMU_ZEPHYR_BOARDS
+        )
+
+    _KNOWN_FPU_ZEPHYR_BOARDS = (
+        "nucleo_f746zg",
+        "nucleo_l4r5zi",
+        "nrf5340dk_nrf5340_cpuapp",
+        "qemu_cortex_r5",
+        "qemu_riscv32",
+        "qemu_riscv64",
+        "qemu_x86",
+        "stm32f746g_disco",
+    )
+
+    @classmethod
+    def _has_fpu(cls, zephyr_board):
+        return zephyr_board in cls._KNOWN_FPU_ZEPHYR_BOARDS
+
+    def flash(self, options):
+        if self._is_qemu(options):
+            return  # NOTE: qemu requires no flash step--it is launched from open_transport.
+
+        zephyr_board = options["zephyr_board"]
+
+        # The nRF5340DK requires an additional `nrfjprog --recover` before each flash cycle.
+        # This is because readback protection is enabled by default when this device is flashed.
+        # Otherwise, flashing may fail with an error such as the following:
+        #   ERROR: The operation attempted is unavailable due to readback protection in
+        #   ERROR: your device. Please use --recover to unlock the device.
+ if zephyr_board.startswith("nrf5340dk") and _get_flash_runner() == "nrfjprog": + recover_args = ["nrfjprog", "--recover"] + recover_args.extend(_get_nrf_device_args(options)) + check_call(recover_args, cwd=API_SERVER_DIR / "build") + + check_call(["make", "flash"], cwd=API_SERVER_DIR / "build") + + def open_transport(self, options): + if self._is_qemu(options): + transport = ZephyrQemuTransport(options) + else: + transport = ZephyrSerialTransport(options) + + to_return = transport.open() + self._transport = transport + atexit.register(lambda: self.close_transport()) + return to_return + + def close_transport(self): + if self._transport is not None: + self._transport.close() + self._transport = None + + def read_transport(self, n, timeout_sec): + if self._transport is None: + raise server.TransportClosedError() + + return self._transport.read(n, timeout_sec) + + def write_transport(self, data, timeout_sec): + if self._transport is None: + raise server.TransportClosedError() + + return self._transport.write(data, timeout_sec) + + +def _set_nonblock(fd): + flag = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flag | os.O_NONBLOCK) + new_flag = fcntl.fcntl(fd, fcntl.F_GETFL) + assert (new_flag & os.O_NONBLOCK) != 0, "Cannot set file descriptor {fd} to non-blocking" + + +class ZephyrSerialTransport: + @classmethod + def _lookup_baud_rate(cls, options): + zephyr_base = options.get("zephyr_base", os.environ["ZEPHYR_BASE"]) + sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts")) + try: + import dtlib # pylint: disable=import-outside-toplevel + finally: + sys.path.pop(0) + + dt_inst = dtlib.DT(BUILD_DIR / "zephyr" / "zephyr.dts") + uart_baud = ( + dt_inst.get_node("/chosen") + .props["zephyr,console"] + .to_path() + .props["current-speed"] + .to_num() + ) + _LOG.debug("zephyr transport: found UART baudrate from devicetree: %d", uart_baud) + + return uart_baud + + @classmethod + def _find_nrf_serial_port(cls, options): + com_ports = subprocess.check_output( + ["nrfjprog", "--com"] + _get_device_args(options), encoding="utf-8" + ) + ports_by_vcom = {} + for line in com_ports.split("\n")[:-1]: + parts = line.split() + ports_by_vcom[parts[2]] = parts[1] + + return ports_by_vcom["VCOM2"] + + @classmethod + def _find_openocd_serial_port(cls, options): + serial_number = openocd_serial(options) + ports = [p for p in serial.tools.list_ports.grep(serial_number)] + if len(ports) != 1: + raise Exception( + f"_find_openocd_serial_port: expected 1 port to match {serial_number}, " + f"found: {ports!r}" + ) + + return ports[0].device + + @classmethod + def _find_serial_port(cls, options): + flash_runner = _get_flash_runner() + + if flash_runner == "nrfjprog": + return cls._find_nrf_serial_port(options) + + if flash_runner == "openocd": + return cls._find_openocd_serial_port(options) + + raise FlashRunnerNotSupported( + f"Don't know how to deduce serial port for flash runner {flash_runner}" + ) + + def __init__(self, options): + self._options = options + self._port = None + + def open(self): + port_path = self._find_serial_port(self._options) + self._port = serial.Serial(port_path, baudrate=self._lookup_baud_rate(self._options)) + return server.TransportTimeouts( + session_start_retry_timeout_sec=2.0, + session_start_timeout_sec=5.0, + session_established_timeout_sec=5.0, + ) + + def close(self): + self._port.close() + self._port = None + + def read(self, n, timeout_sec): + self._port.timeout = timeout_sec + to_return = self._port.read(n) + if not to_return: + raise 
+
+
+class ZephyrQemuMakeResult(enum.Enum):
+    QEMU_STARTED = "qemu_started"
+    MAKE_FAILED = "make_failed"
+    EOF = "eof"
+
+
+class ZephyrQemuTransport:
+    """The user-facing Zephyr QEMU transport class."""
+
+    def __init__(self, options):
+        self.options = options
+        self.proc = None
+        self.pipe_dir = None
+        self.read_fd = None
+        self.write_fd = None
+        self._queue = queue.Queue()
+
+    def open(self):
+        self.pipe_dir = pathlib.Path(tempfile.mkdtemp())
+        self.pipe = self.pipe_dir / "fifo"
+        self.write_pipe = self.pipe_dir / "fifo.in"
+        self.read_pipe = self.pipe_dir / "fifo.out"
+        os.mkfifo(self.write_pipe)
+        os.mkfifo(self.read_pipe)
+
+        # Pass a customized environment to QEMU only when a GDB server is requested.
+        env = None
+        if "gdbserver_port" in self.options:
+            env = os.environ.copy()
+            env["TVM_QEMU_GDBSERVER_PORT"] = str(self.options["gdbserver_port"])
+
+        self.proc = subprocess.Popen(
+            ["make", "run", f"QEMU_PIPE={self.pipe}"],
+            cwd=BUILD_DIR,
+            env=env,
+            stdout=subprocess.PIPE,
+        )
+        self._wait_for_qemu()
+
+        # NOTE: although each pipe is unidirectional, open both as RDWR to work around a select
+        # limitation on linux. Without this, non-blocking I/O can't use timeouts because named
+        # FIFO are always considered ready to read when no one has opened them for writing.
+        self.read_fd = os.open(self.read_pipe, os.O_RDWR | os.O_NONBLOCK)
+        self.write_fd = os.open(self.write_pipe, os.O_RDWR | os.O_NONBLOCK)
+        _set_nonblock(self.read_fd)
+        _set_nonblock(self.write_fd)
+
+        return server.TransportTimeouts(
+            session_start_retry_timeout_sec=2.0,
+            session_start_timeout_sec=5.0,
+            session_established_timeout_sec=5.0,
+        )
+
+    def close(self):
+        did_write = False
+        if self.write_fd is not None:
+            try:
+                server.write_with_timeout(
+                    self.write_fd, b"\x01x", 1.0
+                )  # Use a short timeout since we will kill the process
+                did_write = True
+            except server.IoTimeoutError:
+                pass
+            os.close(self.write_fd)
+            self.write_fd = None
+
+        if self.proc:
+            if not did_write:
+                self.proc.terminate()
+            try:
+                self.proc.wait(5.0)
+            except subprocess.TimeoutExpired:
+                self.proc.kill()
+
+        if self.read_fd:
+            os.close(self.read_fd)
+            self.read_fd = None
+
+        if self.pipe_dir is not None:
+            shutil.rmtree(self.pipe_dir)
+            self.pipe_dir = None
+
+    def read(self, n, timeout_sec):
+        return server.read_with_timeout(self.read_fd, n, timeout_sec)
+
+    def write(self, data, timeout_sec):
+        # 0x01 (Ctrl-A) is the QEMU escape character; double it so the emulated UART sees a
+        # literal 0x01, then exclude the escape bytes from the reported write count.
+        to_write = bytearray()
+        escape_pos = []
+        for i, b in enumerate(data):
+            if b == 0x01:
+                to_write.append(b)
+                escape_pos.append(i)
+            to_write.append(b)
+
+        num_written = server.write_with_timeout(self.write_fd, to_write, timeout_sec)
+        num_written -= sum(1 if x < num_written else 0 for x in escape_pos)
+        return num_written
+
+    def _qemu_check_stdout(self):
+        # Runs on a daemon thread; parses QEMU/make output and feeds results into self._queue.
+        for line in self.proc.stdout:
+            line = str(line)
+            _LOG.info("%s", line)
+            if "[QEMU] CPU" in line:
+                self._queue.put(ZephyrQemuMakeResult.QEMU_STARTED)
+            else:
+                line = re.sub("[^a-zA-Z0-9 \n]", "", line)
+                pattern = r"recipe for target (\w*) failed"
+                if re.search(pattern, line, re.IGNORECASE):
+                    self._queue.put(ZephyrQemuMakeResult.MAKE_FAILED)
+        self._queue.put(ZephyrQemuMakeResult.EOF)
+
+    def _wait_for_qemu(self):
+        threading.Thread(target=self._qemu_check_stdout, daemon=True).start()
+        while True:
+            try:
+                item = self._queue.get(timeout=120)
+            except queue.Empty:
+                raise TimeoutError("QEMU setup timeout.")
+
+            if item == ZephyrQemuMakeResult.QEMU_STARTED:
+                break
+
+            if item in [ZephyrQemuMakeResult.MAKE_FAILED, ZephyrQemuMakeResult.EOF]:
+                raise RuntimeError("QEMU setup failed.")
+
+            raise ValueError(f"{item} not expected.")
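+
+
+# NOTE: this server is not normally run by hand; the TVM-side Project API client launches
+# it and drives it over stdin/stdout. A rough sketch of the client side (names from
+# python/tvm/micro/project.py in this patch; see there for exact signatures):
+#   project = tvm.micro.generate_project(template_dir, module, project_dir, options)
+#   project.build()
+#   project.flash()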
+
+
+if __name__ == "__main__":
+    server.main(Handler())
diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-arm b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm
similarity index 100%
rename from apps/microtvm/zephyr/qemu-hack/qemu-system-arm
rename to apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm
diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-i386 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-i386
similarity index 91%
rename from apps/microtvm/zephyr/qemu-hack/qemu-system-i386
rename to apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-i386
index a30605204d31..6871efbc8b6f 100755
--- a/apps/microtvm/zephyr/qemu-hack/qemu-system-i386
+++ b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-i386
@@ -31,8 +31,8 @@ while [ "$#" -gt 0 ]; do
 done
 
 # For debugging
-if [ "${TVM_QEMU_DEBUG}" != "" ]; then
-    ARGS=( "${ARGS[@]}" -s -S )
+if [ "${TVM_QEMU_GDBSERVER_PORT}" != "" ]; then
+    ARGS=( "${ARGS[@]}" -gdb "tcp::${TVM_QEMU_GDBSERVER_PORT}" -S )
 fi
 
 "${ARGS[@]}"
diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32
similarity index 100%
rename from apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32
rename to apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32
diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64
similarity index 100%
rename from apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64
rename to apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64
diff --git a/apps/microtvm/zephyr/qemu-hack/qemu-system-xilinx-aarch64 b/apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64
similarity index 100%
rename from apps/microtvm/zephyr/qemu-hack/qemu-system-xilinx-aarch64
rename to apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64
diff --git a/apps/microtvm/zephyr/aot_demo/src/main.c b/apps/microtvm/zephyr/template_project/src/aot_demo/main.c
similarity index 97%
rename from apps/microtvm/zephyr/aot_demo/src/main.c
rename to apps/microtvm/zephyr/template_project/src/aot_demo/main.c
index 0c16572fc744..276c23161fd7 100644
--- a/apps/microtvm/zephyr/aot_demo/src/main.c
+++ b/apps/microtvm/zephyr/template_project/src/aot_demo/main.c
@@ -46,8 +46,7 @@
 extern tvm_model_t tvmgen_default_network;
 tvm_workspace_t app_workspace;
 
 // Wakeup sequence used to wake up QEMU on the host.
-const unsigned char g_wakeup_sequence[12] = {0xfe, 0xff, 0xfd, 0x03, 0x00, 0x00,
-                                             0x00, 0x00, 0x00, 0x02, 0x66, 0x77};
+const unsigned char g_wakeup_sequence[] = "#wakeup\n";
 const char g_start_cmd[] = "start\n";
 
 size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt,
diff --git a/apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c b/apps/microtvm/zephyr/template_project/src/aot_demo/zephyr_uart.c
similarity index 100%
rename from apps/microtvm/zephyr/aot_demo/src/zephyr_uart.c
rename to apps/microtvm/zephyr/template_project/src/aot_demo/zephyr_uart.c
diff --git a/apps/microtvm/zephyr/aot_demo/include/zephyr_uart.h b/apps/microtvm/zephyr/template_project/src/aot_demo/zephyr_uart.h
similarity index 100%
rename from apps/microtvm/zephyr/aot_demo/include/zephyr_uart.h
rename to apps/microtvm/zephyr/template_project/src/aot_demo/zephyr_uart.h
diff --git a/apps/microtvm/zephyr/host_driven/src/main.c b/apps/microtvm/zephyr/template_project/src/host_driven/main.c
similarity index 100%
rename from apps/microtvm/zephyr/host_driven/src/main.c
rename to apps/microtvm/zephyr/template_project/src/host_driven/main.c
diff --git a/cmake/modules/StandaloneCrt.cmake b/cmake/modules/StandaloneCrt.cmake
index 09f2ccc95d85..32d7e59f2058 100644
--- a/cmake/modules/StandaloneCrt.cmake
+++ b/cmake/modules/StandaloneCrt.cmake
@@ -46,8 +46,9 @@ if(USE_MICRO)
          "src/runtime/crt/graph_executor *.c -> src/runtime/crt/graph_executor"
          "src/runtime/crt/aot_executor *.c -> src/runtime/crt/aot_executor"
          "src/runtime/crt/graph_executor_module *.c -> src/runtime/crt/graph_executor_module"
-         "src/runtime/crt/host crt_config.h -> template/host"
          "src/runtime/crt/host *.cc -> template/host"
+         "src/runtime/crt/host *.py -> template/host"
+         "src/runtime/crt/host Makefile -> template/host"
          "src/runtime/crt/memory *.c -> src/runtime/crt/memory"
          "src/runtime/crt/microtvm_rpc_common *.cc -> src/runtime/crt/microtvm_rpc_common"
          "src/runtime/crt/microtvm_rpc_server *.cc -> src/runtime/crt/microtvm_rpc_server"
@@ -104,7 +105,7 @@ if(USE_MICRO)
     endforeach()
 
     set(make_common_args
-      "CRT_CONFIG=template/host/crt_config.h"
+      "CRT_CONFIG=${CMAKE_SOURCE_DIR}/src/runtime/micro/crt_config.h"
       "BUILD_DIR=${host_build_dir_abspath}"
       "EXTRA_CFLAGS=-fPIC"
       "EXTRA_CXXFLAGS=-fPIC"
@@ -145,7 +146,7 @@ if(USE_MICRO)
         string(REPLACE ".cc" "" __execname ${__srcname})
         add_executable(${__execname} ${__srcpath})
         list(APPEND TEST_EXECS ${__execname})
-        target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/crt/host)
+        target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/micro)
         target_compile_options(${__execname} PRIVATE -pthread)
         target_link_libraries(${__execname} ${cmake_crt_libraries} ${GTEST_LIB} pthread)
         set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1)
diff --git a/include/tvm/runtime/crt/rpc_common/framing.h b/include/tvm/runtime/crt/rpc_common/framing.h
index 32a0f56dab11..33f37a0af03f 100644
--- a/include/tvm/runtime/crt/rpc_common/framing.h
+++ b/include/tvm/runtime/crt/rpc_common/framing.h
@@ -134,7 +134,7 @@ class Unframer {
   /*! \brief number of bytes in buffer that are currently valid. */
   size_t num_buffer_bytes_valid_;
 
-  /*! \brief number of payload bytes left to write before the CRC begins. */
+  /*! \brief number of payload bytes left to receive before the CRC begins. */
   size_t num_payload_bytes_remaining_;
 
   /*! \brief Running CRC value. */
diff --git a/python/tvm/contrib/utils.py b/python/tvm/contrib/utils.py
index 68c6b3d5bf6b..e2ca182779c6 100644
--- a/python/tvm/contrib/utils.py
+++ b/python/tvm/contrib/utils.py
@@ -124,7 +124,7 @@ def remove(self):
     def path(self):
         return pathlib.Path(self.temp_dir)
 
-    def __div__(self, other):
+    def __truediv__(self, other):
         if not isinstance(other, (str, pathlib.Path)):
             raise TypeError(
                 "TempDirectory / operator: must supply str or pathlib.Path; got %r" % (other,)
diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py
index a70cb96d9b13..88dcde8ceaf0 100644
--- a/python/tvm/micro/__init__.py
+++ b/python/tvm/micro/__init__.py
@@ -16,18 +16,13 @@
 # under the License.
 """MicroTVM module for bare-metal backends"""
 
-from .artifact import Artifact
-from .build import build_static_runtime, default_options, get_standalone_crt_dir
-from .build import get_standalone_crt_lib, Workspace
-from .compiler import Compiler, DefaultCompiler, Flasher
-from .debugger import GdbRemoteDebugger
-from .micro_library import MicroLibrary
-from .micro_binary import MicroBinary
+from .build import get_standalone_crt_dir
 from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError
+from .project import generate_project, GeneratedProject, TemplateProject
 from .session import (
     create_local_graph_executor,
     create_local_debug_executor,
     Session,
     SessionTerminatedError,
 )
-from .transport import TransportLogger, DebugWrapperTransport, SubprocessTransport
+from .transport import TransportLogger
diff --git a/python/tvm/micro/artifact.py b/python/tvm/micro/artifact.py
deleted file mode 100644
index c8faccb3f512..000000000000
--- a/python/tvm/micro/artifact.py
+++ /dev/null
@@ -1,295 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
- -""""Defines abstractions around compiler artifacts produced in compiling micro TVM binaries.""" - -import hashlib -import io -import os -import json -import shutil -import tarfile - - -class ArtifactFileNotFoundError(Exception): - """Raised when an artifact file cannot be found on disk.""" - - -class ArtifactBadSymlinkError(Exception): - """Raised when an artifact symlink points outside the base directory.""" - - -class ArtifactBadArchiveError(Exception): - """Raised when an artifact archive is malformed.""" - - -class ImmobileArtifactError(Exception): - """Raised when an artifact is declared immobile and thus cannot be archived.""" - - -class ArchiveModifiedError(Exception): - """Raised when the underlying files in a metadata-only archive were modified after archiving.""" - - -def sha256_hexdigest(path): - with open(path, "rb") as path_fd: - h = hashlib.sha256() - chunk = path_fd.read(1 * 1024 * 1024) - while chunk: - h.update(chunk) - chunk = path_fd.read(1 * 1024 * 1024) - - return h.hexdigest() - - -def _validate_metadata_only(metadata): - """Validate that the files in a metadata-only archive have not changed.""" - problems = [] - for files in metadata["labelled_files"].values(): - for f in files: - disk_path = os.path.join(metadata["base_dir"], f) - try: - sha = sha256_hexdigest(disk_path) - except FileNotFoundError: - problems.append(f"{f}: original file not found") - continue - - expected_sha = metadata["file_digests"][f] - if sha != expected_sha: - problems.append(f"{f}: sha256 mismatch: expected {expected_sha}, got {sha}") - - if problems: - raise ArchiveModifiedError( - "Files in metadata-only archive have been modified:\n" - + "\n".join([f" * {p}" for p in problems]) - ) - - -class Artifact: - """Describes a compiler artifact and defines common logic to archive it for transport.""" - - # A version number written to the archive. - ENCODING_VERSION = 2 - - # A unique string identifying the type of artifact in an archive. Subclasses must redefine this - # variable. - ARTIFACT_TYPE = None - - @classmethod - def unarchive(cls, archive_path, base_dir): - """Unarchive an artifact into base_dir. - - Parameters - ---------- - archive_path : str - Path to the archive file. - base_dir : str - Path to a non-existent, empty directory under which the artifact will live. If working - with a metadata-only archive, this directory will just hold the metadata.json. - - Returns - ------- - Artifact : - The unarchived artifact. 
- """ - if os.path.exists(base_dir): - raise ValueError(f"base_dir exists: {base_dir}") - - base_dir_parent, base_dir_name = os.path.split(base_dir) - temp_dir = os.path.join(base_dir_parent, f"__tvm__{base_dir_name}") - os.mkdir(temp_dir) - try: - with tarfile.open(archive_path) as tar_f: - tar_f.extractall(temp_dir) - - temp_dir_contents = os.listdir(temp_dir) - if len(temp_dir_contents) != 1: - raise ArtifactBadArchiveError( - "Expected exactly 1 subdirectory at root of archive, got " - f"{temp_dir_contents!r}" - ) - - metadata_path = os.path.join(temp_dir, temp_dir_contents[0], "metadata.json") - if not metadata_path: - raise ArtifactBadArchiveError("No metadata.json found in archive") - - with open(metadata_path) as metadata_f: - metadata = json.load(metadata_f) - - version = metadata.get("version") - if version != cls.ENCODING_VERSION: - raise ArtifactBadArchiveError( - f"archive version: expect {cls.EXPECTED_VERSION}, found {version}" - ) - - metadata_only = metadata.get("metadata_only") - if metadata_only: - _validate_metadata_only(metadata) - - os.rename(os.path.join(temp_dir, temp_dir_contents[0]), base_dir) - - artifact_cls = cls - for sub_cls in cls.__subclasses__(): - if sub_cls.ARTIFACT_TYPE is not None and sub_cls.ARTIFACT_TYPE == metadata.get( - "artifact_type" - ): - artifact_cls = sub_cls - break - - return artifact_cls.from_unarchived( - base_dir if not metadata_only else metadata["base_dir"], - metadata["labelled_files"], - metadata["metadata"], - immobile=metadata.get("immobile"), - ) - finally: - shutil.rmtree(temp_dir) - - @classmethod - def from_unarchived(cls, base_dir, labelled_files, metadata, immobile): - return cls(base_dir, labelled_files, metadata, immobile) - - def __init__(self, base_dir, labelled_files, metadata, immobile=False): - """Create a new artifact. - - Parameters - ---------- - base_dir : str - The path to a directory on disk which contains all the files in this artifact. - labelled_files : Dict[str, str] - A dict mapping a file label to the relative paths of the files that carry that label. - metadata : Dict - A dict containing artitrary JSON-serializable key-value data describing the artifact. - immobile : bool - True when this artifact can't be used after being moved out of its current location on - disk. This can happen when artifacts contain absolute paths or when it's not feasible to - include enough files in the artifact to reliably re-run commands in arbitrary locations. - Setting this flag will cause archive() to raise ImmboileArtifactError. 
- """ - self.base_dir = os.path.realpath(base_dir) - self.labelled_files = labelled_files - self.metadata = metadata - self.immobile = immobile - - for label, files in labelled_files.items(): - for f in files: - f_path = os.path.join(self.base_dir, f) - if not os.path.lexists(f_path): - raise ArtifactFileNotFoundError(f"{f} (label {label}): not found at {f_path}") - - if os.path.islink(f_path): - link_path = os.path.readlink(f_path) - if os.path.isabs(link_path): - link_fullpath = link_path - else: - link_fullpath = os.path.join(os.path.dirname(f_path), link_path) - - link_fullpath = os.path.realpath(link_fullpath) - if not link_fullpath.startswith(self.base_dir): - raise ArtifactBadSymlinkError( - f"{f} (label {label}): symlink points outside artifact tree" - ) - - def abspath(self, rel_path): - """Return absolute path to the member with the given relative path.""" - return os.path.join(self.base_dir, rel_path) - - def label(self, label): - """Return a list of relative paths to files with the given label.""" - return self.labelled_files[label] - - def label_abspath(self, label): - return [self.abspath(p) for p in self.labelled_files[label]] - - def archive(self, archive_path, metadata_only=False): - """Create a relocatable tar archive of the artifacts. - - Parameters - ---------- - archive_path : str - Path to the tar file to create. Or, path to a directory, under which a tar file will be - created named {base_dir}.tar. - metadata_only : bool - If true, don't archive artifacts; instead, just archive metadata plus original - base_path. A metadata-only archive can be unarchived and used like a regular archive - provided none of the files have changed in their original locations on-disk. - - Returns - ------- - str : - The value of archive_path, after potentially making the computation describe above. - - Raises - ------ - ImmboileArtifactError : - When immobile=True was passed to the constructor. 
- """ - if self.immobile and not metadata_only: - raise ImmobileArtifactError("This artifact can't be moved") - - if os.path.isdir(archive_path): - archive_path = os.path.join(archive_path, f"{os.path.basename(self.base_dir)}.tar") - - archive_name = os.path.splitext(os.path.basename(archive_path))[0] - with tarfile.open(archive_path, "w") as tar_f: - - def _add_file(name, data, f_type): - tar_info = tarfile.TarInfo(name=name) - tar_info.type = f_type - data_bytes = bytes(data, "utf-8") - tar_info.size = len(data) - tar_f.addfile(tar_info, io.BytesIO(data_bytes)) - - metadata = { - "version": self.ENCODING_VERSION, - "labelled_files": self.labelled_files, - "metadata": self.metadata, - "metadata_only": False, - } - if metadata_only: - metadata["metadata_only"] = True - metadata["base_dir"] = self.base_dir - metadata["immobile"] = self.immobile - metadata["file_digests"] = {} - for files in self.labelled_files.values(): - for f in files: - metadata["file_digests"][f] = sha256_hexdigest(self.abspath(f)) - - _add_file( - f"{archive_name}/metadata.json", - json.dumps(metadata, indent=2, sort_keys=True), - tarfile.REGTYPE, - ) - for dir_path, _, files in os.walk(self.base_dir): - for f in files: - file_path = os.path.join(dir_path, f) - archive_file_path = os.path.join( - archive_name, os.path.relpath(file_path, self.base_dir) - ) - if not os.path.islink(file_path): - tar_f.add(file_path, archive_file_path, recursive=False) - continue - - link_path = os.readlink(file_path) - if not os.path.isabs(link_path): - tar_f.add(file_path, archive_file_path, recursive=False) - continue - - relpath = os.path.relpath(link_path, os.path.dirname(file_path)) - _add_file(archive_file_path, relpath, tarfile.LNKTYPE) - - return archive_path diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index a83ccaa47cda..16e7ed24cb4f 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -17,42 +17,15 @@ """Defines top-level glue functions for building microTVM artifacts.""" -import copy import logging import os -import re -import typing -from tvm.contrib import utils -from .micro_library import MicroLibrary from .._ffi import libinfo _LOG = logging.getLogger(__name__) -class Workspace: - """Defines helper functions for manipulating temporary compilation workspaces.""" - - def __init__(self, root=None, debug=False): - if debug or root is not None: - with utils.TempDirectory.set_keep_for_debug(): - self.tempdir = utils.tempdir(custom_path=root) - _LOG.info("Created debug mode workspace at: %s", self.tempdir.temp_dir) - else: - self.tempdir = utils.tempdir() - - def relpath(self, path): - return self.tempdir.relpath(path) - - def listdir(self): - return self.tempdir.listdir() - - @property - def path(self): - return self.tempdir.temp_dir - - STANDALONE_CRT_DIR = None @@ -84,186 +57,3 @@ def get_standalone_crt_dir() -> str: raise CrtNotFoundError() return STANDALONE_CRT_DIR - - -def get_standalone_crt_lib(name: str) -> str: - """Find a source library directory in the standalone_crt. - - The standalone C runtime is split into various libraries (one per directory underneath - src/runtime/crt). This convenience function returns the full path to one of those libraries - located in get_standalone_crt_dir(). - - Parameters - ---------- - name : str - Name of the library subdirectory underneath src/runtime/crt. - - Returns - ------- - str : - The full path to the the library. 
- """ - return os.path.join(get_standalone_crt_dir(), "src", "runtime", "crt", name) - - -def get_runtime_libs(executor: str) -> str: - """Return abspath to all CRT directories in link order which contain - source (i.e. not header) files. - """ - if executor == "host-driven": - crt_runtime_lib_names = ["microtvm_rpc_server", "microtvm_rpc_common", "common"] - elif executor == "aot": - crt_runtime_lib_names = ["aot_executor", "common"] - else: - raise ValueError(f"Incorrect executor: {executor}") - return [get_standalone_crt_lib(n) for n in crt_runtime_lib_names] - - -RUNTIME_SRC_REGEX = re.compile(r"^.*\.cc?$", re.IGNORECASE) - - -_COMMON_CFLAGS = ["-Wall", "-Werror", "-DDMLC_USE_LOGGING_LIBRARY="] - - -def _build_default_compiler_options(standalone_crt_dir: typing.Optional[str] = None) -> str: - """Return a dict containing base compile flags for the CRT under gcc common to . - - Parameters - ---------- - standalone_crt_dir : Optional[str] - If given, the path to the standalone_crt - """ - if standalone_crt_dir is None: - standalone_crt_dir = get_standalone_crt_dir() - return { - "cflags": ["-std=c11"] + _COMMON_CFLAGS, - "ccflags": ["-std=c++11"] + _COMMON_CFLAGS, - "ldflags": ["-std=c++11"], - "include_dirs": [os.path.join(standalone_crt_dir, "include")], - } - - -def default_options(crt_config_include_dir, standalone_crt_dir=None): - """Return default opts passed to Compile commands. - - Parameters - ---------- - crt_config_include_dir : str - Path to a directory containing crt_config.h for the target. This will be appended - to the include path for cflags and ccflags. - standalone_crt_dir : Optional[str] - - Returns - ------- - Dict : - A dictionary containing 3 subkeys, each whose value is _build_default_compiler_options() - plus additional customization. - - - "bin_opts" - passed as "options" to Compiler.binary() when building MicroBinary. - - "lib_opts" - passed as "options" to Compiler.library() when building bundled CRT - libraries (or otherwise, non-generated libraries). - - "generated_lib_opts" - passed as "options" to Compiler.library() when building the - generated library. - """ - bin_opts = _build_default_compiler_options(standalone_crt_dir) - bin_opts["include_dirs"].append(crt_config_include_dir) - - lib_opts = _build_default_compiler_options(standalone_crt_dir) - lib_opts["cflags"] = ["-Wno-error=incompatible-pointer-types"] - lib_opts["include_dirs"].append(crt_config_include_dir) - - generated_lib_opts = copy.copy(lib_opts) - - # Disable due to limitation in the TVM C codegen, which generates lots of local variable - # declarations at the top of generated code without caring whether they're used. - # Example: - # void* arg0 = (((TVMValue*)args)[0].v_handle); - # int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)]; - generated_lib_opts["cflags"].append("-Wno-unused-variable") - generated_lib_opts["ccflags"].append("-Wno-unused-variable") - - # Many TVM-intrinsic operators (i.e. expf, in particular) - generated_lib_opts["cflags"].append("-fno-builtin") - - return {"bin_opts": bin_opts, "lib_opts": lib_opts, "generated_lib_opts": generated_lib_opts} - - -def build_static_runtime( - workspace, - compiler, - module, - compiler_options, - executor=None, - extra_libs=None, -): - """Build the on-device runtime, statically linking the given modules. - - Parameters - ---------- - compiler : tvm.micro.Compiler - Compiler instance used to build the runtime. - - module : IRModule - Module to statically link. 
- - compiler_options : dict - The return value of tvm.micro.default_options(), with any keys overridden to inject - compiler options specific to this build. If not given, tvm.micro.default_options() is - used. This dict contains the `options` parameter passed to Compiler.library() and - Compiler.binary() at various stages in the compilation process. - - executor : Optional[str] - Executor used for runtime. Based on this we determine the libraries that need to be - linked with runtime. - - extra_libs : Optional[List[MicroLibrary|str]] - If specified, extra libraries to be compiled into the binary. If a MicroLibrary, it is - included into the binary directly. If a string, the path to a directory; all direct children - of this directory matching RUNTIME_SRC_REGEX are built into a library. These libraries are - placed before any common CRT libraries in the link order. - - Returns - ------- - MicroBinary : - The compiled runtime. - """ - mod_build_dir = workspace.relpath(os.path.join("build", "module")) - os.makedirs(mod_build_dir) - mod_src_dir = workspace.relpath(os.path.join("src", "module")) - - if not executor: - executor = "host-driven" - - libs = [] - for mod_or_src_dir in (extra_libs or []) + get_runtime_libs(executor): - if isinstance(mod_or_src_dir, MicroLibrary): - libs.append(mod_or_src_dir) - continue - - lib_src_dir = mod_or_src_dir - lib_name = os.path.basename(lib_src_dir) - lib_build_dir = workspace.relpath(f"build/{lib_name}") - os.makedirs(lib_build_dir) - - lib_srcs = [] - for p in os.listdir(lib_src_dir): - if RUNTIME_SRC_REGEX.match(p): - lib_srcs.append(os.path.join(lib_src_dir, p)) - - libs.append(compiler.library(lib_build_dir, lib_srcs, compiler_options["lib_opts"])) - - mod_src_dir = workspace.relpath(os.path.join("src", "module")) - os.makedirs(mod_src_dir) - libs.append( - module.export_library( - mod_build_dir, - workspace_dir=mod_src_dir, - fcompile=lambda bdir, srcs, **kwargs: compiler.library( - bdir, srcs, compiler_options["generated_lib_opts"] - ), - ) - ) - - runtime_build_dir = workspace.relpath(f"build/runtime") - os.makedirs(runtime_build_dir) - return compiler.binary(runtime_build_dir, libs, compiler_options["bin_opts"]) diff --git a/python/tvm/micro/compiler.py b/python/tvm/micro/compiler.py deleted file mode 100644 index 5bc5aba8a1be..000000000000 --- a/python/tvm/micro/compiler.py +++ /dev/null @@ -1,361 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines interfaces and default implementations for compiling and flashing code.""" - -import abc -import glob -import os -import re -import subprocess - -import tvm.target -from . import class_factory -from . import debugger -from . import transport - - -def run_cmd(cmd): - """Runs `cmd` in a subprocess and awaits its completion. 
- - Parameters - ---------- - cmd : List[str] - list of command-line arguments - - Returns - ------- - output : str - resulting stdout capture from the subprocess - """ - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - (output, _) = proc.communicate() - output = output.decode("utf-8") - if proc.returncode != 0: - cmd_str = " ".join(cmd) - msg = f'error while running command "{cmd_str}":\n{output}' - raise RuntimeError(msg) - - -class DetectTargetError(Exception): - """Raised when no target comment was detected in the sources given.""" - - -class NoDefaultToolchainMatchedError(Exception): - """Raised when no default toolchain matches the target string.""" - - -class Compiler(metaclass=abc.ABCMeta): - """The compiler abstraction used with micro TVM.""" - - TVM_TARGET_RE = re.compile(r"^// tvm target: (.*)$") - - @classmethod - def _target_from_sources(cls, sources): - """Determine the target used to generate the given source files. - - Parameters - ---------- - sources : List[str] - The paths to source files to analyze. - - Returns - ------- - tvm.target.Target : - A Target instance reconstructed from the target string listed in the source files. - """ - target_strs = set() - - for obj in sources: - if os.path.splitext(obj)[1] not in (".cc", ".c"): - continue - - with open(obj) as obj_f: - for line in obj_f: - m = cls.TVM_TARGET_RE.match(line) - if m: - target_strs.add(m.group(1)) - - if len(target_strs) != 1: - raise DetectTargetError( - "autodetecting cross-compiler: could not extract TVM target from C source; regex " - f"{cls.TVM_TARGET_RE.pattern} does not match any line in sources: " - f'{", ".join(sources)}' - ) - - target_str = next(iter(target_strs)) - return tvm.target.Target(target_str) - - # Maps regexes identifying CPUs to the default toolchain prefix for that CPU. - TOOLCHAIN_PREFIX_BY_CPU_REGEX = { - r"cortex-[am].*": "arm-none-eabi-", - "x86[_-]64": "", - "native": "", - } - - def _autodetect_toolchain_prefix(self, target): - # Treat absence of -mcpu as if -mcpu=native is specified. The gcc shipped with OS X - # complains if -mcpu=native is given, so this approach allows model targets to avoid - # specifying this flag e.g. for tutorials. - if "mcpu" not in target.attrs: - return self.TOOLCHAIN_PREFIX_BY_CPU_REGEX["native"] - - matches = [] - for regex, prefix in self.TOOLCHAIN_PREFIX_BY_CPU_REGEX.items(): - if re.match(regex, target.attrs["mcpu"]): - matches.append(prefix) - - if matches: - if len(matches) != 1: - raise NoDefaultToolchainMatchedError( - f'{opt} matched more than 1 default toolchain prefix: {", ".join(matches)}. ' - "Specify cc.cross_compiler to create_micro_library()" - ) - - return matches[0] - - raise NoDefaultToolchainMatchedError( - f"target {str(target)} did not match any default toolchains" - ) - - def _defaults_from_target(self, target): - """Determine the default compiler options from the target specified. - - Parameters - ---------- - target : tvm.target.Target - - Returns - ------- - List[str] : - Default options used the configure the compiler for that target. - """ - opts = [] - # TODO use march for arm(https://gcc.gnu.org/onlinedocs/gcc/ARM-Options.html)? 
- if target.attrs.get("mcpu"): - opts.append(f'-mcpu={target.attrs["mcpu"]}') - if target.attrs.get("mfpu"): - opts.append(f'-mfpu={target.attrs["mfpu"]}') - if target.attrs.get("march"): - opts.append(f'-march={target.attrs["march"]}') - - return opts - - @abc.abstractmethod - def library(self, output, sources, options=None): - """Build a library from the given source files. - - Parameters - ---------- - output : str - The path to the library that should be created. The containing directory - is guaranteed to be empty and should be the base_dir for the returned - Artifact. - sources : List[str] - A list of paths to source files that should be compiled. - options : Optional[List[str]] - If given, additional command-line flags to pass to the compiler. - - Returns - ------- - MicroLibrary : - The compiled library, as a MicroLibrary instance. - """ - raise NotImplementedError() - - @abc.abstractmethod - def binary(self, output, objects, options=None, link_main=True, main_options=None): - """Link a binary from the given object and/or source files. - - Parameters - ---------- - output : str - The path to the binary that should be created. The containing directory - is guaranteed to be empty and should be the base_dir for the returned - Artifact. - objects : List[MicroLibrary] - A list of paths to source files or libraries that should be compiled. The final binary - should be statically-linked. - options: Optional[List[str]] - If given, additional command-line flags to pass to the compiler. - link_main: Optional[bool] - True if the standard main entry point for this Compiler should be included in the - binary. False if a main entry point is provided in one of `objects`. - main_options: Optional[List[str]] - If given, additional command-line flags to pass to the compiler when compiling the - main() library. In some cases, the main() may be compiled directly into the final binary - along with `objects` for logistical reasons. In those cases, specifying main_options is - an error and ValueError will be raised. - - Returns - ------- - MicroBinary : - The compiled binary, as a MicroBinary instance. 
- """ - raise NotImplementedError() - - @property - def flasher_factory(self): - """Produce a FlasherFactory for a Flasher instance suitable for this Compiler.""" - raise NotImplementedError("The Compiler base class doesn't define a flasher.") - - def flasher(self, **kw): - """Return a Flasher that can be used to program a produced MicroBinary onto the target.""" - return self.flasher_factory.override_kw(**kw).instantiate() - - -class IncompatibleTargetError(Exception): - """Raised when source files specify a target that differs from the compiler target.""" - - -class DefaultCompiler(Compiler): - """A Compiler implementation that attempts to use the system-installed GCC.""" - - def __init__(self, target=None): - super(DefaultCompiler, self).__init__() - self.target = target - if isinstance(target, str): - self.target = tvm.target.create(target) - - def library(self, output, sources, options=None): - options = options if options is not None else {} - try: - target = self._target_from_sources(sources) - except DetectTargetError: - assert self.target is not None, ( - "Must specify target= to constructor when compiling sources which don't specify a " - "target" - ) - - target = self.target - - if self.target is not None and str(self.target) != str(target): - raise IncompatibleTargetError( - f"auto-detected target {target} differs from configured {self.target}" - ) - - prefix = self._autodetect_toolchain_prefix(target) - outputs = [s for s in sources if os.path.splitext(s)[1] == ".o"] - sources = [s for s in sources if s not in outputs] - for src in sources: - src_base, src_ext = os.path.splitext(os.path.basename(src)) - - compiler_name = {".c": "gcc", ".cc": "g++", ".cpp": "g++"}[src_ext] - args = [prefix + compiler_name, "-g"] - args.extend(self._defaults_from_target(target)) - - args.extend(options.get(f"{src_ext[1:]}flags", [])) - - for include_dir in options.get("include_dirs", []): - args.extend(["-I", include_dir]) - - output_filename = f"{src_base}.o" - output_abspath = os.path.join(output, output_filename) - run_cmd(args + ["-c", "-o", output_abspath, src]) - outputs.append(output_abspath) - - output_filename = f"{os.path.basename(output)}.a" - output_abspath = os.path.join(output, output_filename) - run_cmd([prefix + "ar", "-r", output_abspath] + outputs) - run_cmd([prefix + "ranlib", output_abspath]) - - return tvm.micro.MicroLibrary(output, [output_filename]) - - def binary(self, output, objects, options=None, link_main=True, main_options=None): - assert self.target is not None, ( - "must specify target= to constructor, or compile sources which specify the target " - "first" - ) - - args = [self._autodetect_toolchain_prefix(self.target) + "g++"] - args.extend(self._defaults_from_target(self.target)) - if options is not None: - args.extend(options.get("ldflags", [])) - - for include_dir in options.get("include_dirs", []): - args.extend(["-I", include_dir]) - - output_filename = os.path.basename(output) - output_abspath = os.path.join(output, output_filename) - args.extend(["-g", "-o", output_abspath]) - - if link_main: - host_main_srcs = glob.glob( - os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host", "*.cc") - ) - if main_options: - main_lib = self.library(os.path.join(output, "host"), host_main_srcs, main_options) - for lib_name in main_lib.library_files: - args.append(main_lib.abspath(lib_name)) - else: - args.extend(host_main_srcs) - - for obj in objects: - for lib_name in obj.library_files: - args.append(obj.abspath(lib_name)) - - run_cmd(args) - return 
tvm.micro.MicroBinary(output, output_filename, []) - - @property - def flasher_factory(self): - return FlasherFactory(HostFlasher, [], {}) - - -class Flasher(metaclass=abc.ABCMeta): - """An interface for flashing binaries and returning a transport factory.""" - - @abc.abstractmethod - def flash(self, micro_binary): - """Flash a binary onto the device. - - Parameters - ---------- - micro_binary : MicroBinary - A MicroBinary instance. - - Returns - ------- - transport.TransportContextManager : - A ContextManager that can be used to create and tear down an RPC transport layer between - this TVM instance and the newly-flashed binary. - """ - raise NotImplementedError() - - -class FlasherFactory(class_factory.ClassFactory): - """A ClassFactory for Flasher instances.""" - - SUPERCLASS = Flasher - - -class HostFlasher(Flasher): - """A Flasher implementation that spawns a subprocess on the host.""" - - def __init__(self, debug=False): - self.debug = debug - - def flash(self, micro_binary): - if self.debug: - gdb_wrapper = debugger.GdbTransportDebugger( - [micro_binary.abspath(micro_binary.binary_file)] - ) - return transport.DebugWrapperTransport( - debugger=gdb_wrapper, transport=gdb_wrapper.transport() - ) - - return transport.SubprocessTransport([micro_binary.abspath(micro_binary.binary_file)]) diff --git a/python/tvm/micro/contrib/__init__.py b/python/tvm/micro/contrib/__init__.py deleted file mode 100644 index 13a83393a912..000000000000 --- a/python/tvm/micro/contrib/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/python/tvm/micro/contrib/base.py b/python/tvm/micro/contrib/base.py deleted file mode 100644 index 9c4f4863e3bc..000000000000 --- a/python/tvm/micro/contrib/base.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -"""Defines common helper functions useful for integrating custom compiler toolchains.""" - -import glob -import os -import shutil - - -GLOB_PATTERNS = ["__tvm_*", "libtvm__*"] - - -def populate_tvm_objs(dest_dir, objs): - """Replace tvm-prefixed files in a build worktree. - - This function is intended to be used to place TVM source files and libraries into a - template on-device runtime project. - - Parameters - ---------- - dest_dir : str - Path to the destination directory. - - objs : List[MicroLibrary] - List of MicroLibrary to place in the project directory. - - Returns - ------- - List[str] : - List of paths, each relative to `dest_dir` to the newly-copied MicroLibrary files. - """ - copied = [] - for p in GLOB_PATTERNS: - for f in glob.glob(os.path.join(dest_dir, p)): - if os.path.isdir(f): - shutil.rmtree(f) - else: - os.unlink(f) - - for obj in objs: - for lib_file in obj.library_files: - obj_base = os.path.basename(lib_file) - if obj_base.endswith(".a"): - dest_basename = f"libtvm__{obj_base}" - else: - dest_basename = f"__tvm_{obj_base}" - - copied.append(dest_basename) - dest = os.path.join(dest_dir, dest_basename) - shutil.copy(obj.abspath(lib_file), dest) - - return copied diff --git a/python/tvm/micro/contrib/zephyr.py b/python/tvm/micro/contrib/zephyr.py deleted file mode 100644 index 77cfb8d09bf2..000000000000 --- a/python/tvm/micro/contrib/zephyr.py +++ /dev/null @@ -1,789 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines a compiler integration that uses an externally-supplied Zephyr project.""" - -import collections -import copy -import logging -import multiprocessing -import os -import pathlib -import re -import tempfile -import textwrap -import shlex -import shutil -import subprocess -import sys -import threading -import queue -import enum - -import yaml - -import tvm.micro -from . import base -from .. import compiler -from .. 
import debugger -from ..transport import debug -from ..transport import file_descriptor - -from ..transport import serial -from ..transport import Transport, TransportClosedError, TransportTimeouts -from ..transport import wakeup - - -_LOG = logging.getLogger(__name__) - - -class SubprocessEnv(object): - def __init__(self, default_overrides): - self.default_overrides = default_overrides - - def run(self, cmd, **kw): - env = dict(os.environ) - for k, v in self.default_overrides.items(): - env[k] = v - - return subprocess.check_output(cmd, env=env, **kw, universal_newlines=True) - - -class ProjectNotFoundError(Exception): - """Raised when the project_dir supplied to ZephyrCompiler does not exist.""" - - -class FlashRunnerNotSupported(Exception): - """Raised when the FLASH_RUNNER for a project isn't supported by this Zephyr adapter.""" - - -class ZephyrCompiler(tvm.micro.Compiler): - """A Compiler instance that builds against a pre-existing zephyr project.""" - - def __init__( - self, - project_dir=None, - board=None, - west_cmd=None, - zephyr_base=None, - zephyr_toolchain_variant=None, - env_vars=None, - ): - """Configure the compiler for use. - - Parameters - ---------- - project_dir : str - Path to the pre-existing Zephyr project. - board : str - Name of the Zephyr board to build for (i.e. passed to `west build -b`) - west_cmd : Optional[list] - If given, argv that invoke the west build tool. Used only for flashing. - zephyr_base : Optional[str] - If given, path to Zephyr, as would normally be present in the ZEPHYR_BASE environment - variable. If not given, consults this environment variable. This value must be set in - one of those two places. - zephyr_toolchain_variant: Optional[str] - If given, overrides the toolchain used by Zephyr. If not given, uses the default - zephyr toolchain. When running on OS X outside of docker, you need to specify this. - env_vars : Optional[Dict[str,str]] - If given, additional environment variables present when invoking west, cmake, or make. - """ - self._project_dir = project_dir - if not os.path.exists(project_dir): - # Raise this error instead of a potentially-more-cryptic compiler error due to a missing - # prj.conf. - raise ProjectNotFoundError( - f"project_dir supplied to ZephyrCompiler does not exist: {project_dir}" - ) - - self._qemu = "qemu" in board - - # For Zephyr boards that run emulated by default but don't have the prefix "qemu_" in their - # board names, a suffix "-qemu" is added by users of microTVM when specifying the board - # name to inform that the QEMU transporter must be used just like for the boards with - # the prefix. Zephyr does not recognize the suffix, so we trim it off before passing it. 
- if "-qemu" in board: - board = board.replace("-qemu", "") - - self._board = board - - if west_cmd is None: - self._west_cmd = [sys.executable, "-mwest.app.main"] - elif isinstance(west_cmd, str): - self._west_cmd = [west_cmd] - elif isinstance(west_cmd, list): - self._west_cmd = west_cmd - else: - raise TypeError("west_cmd: expected string, list, or None; got %r" % (west_cmd,)) - - env = {} - if zephyr_toolchain_variant is not None: - env["ZEPHYR_TOOLCHAIN_VARIANT"] = zephyr_toolchain_variant - - self._zephyr_base = zephyr_base or os.environ["ZEPHYR_BASE"] - assert ( - self._zephyr_base is not None - ), f"Must specify zephyr_base=, or ZEPHYR_BASE must be in environment variables" - env["ZEPHYR_BASE"] = self._zephyr_base - - if env_vars: - env.update(env_vars) - - self._subprocess_env = SubprocessEnv(env) - - OPT_KEY_TO_CMAKE_DEFINE = { - "cflags": "CFLAGS", - "ccflags": "CXXFLAGS", - "ldflags": "LDFLAGS", - } - - @classmethod - def _options_to_cmake_args(cls, options): - args = [] - for key, define in cls.OPT_KEY_TO_CMAKE_DEFINE.items(): - if key in options: - quoted_opts = [shlex.quote(o).replace(";", "\\;") for o in options[key]] - args.append(f'-DEXTRA_{define}={" ".join(quoted_opts)}') - - if "cmake_args" in options: - args.extend(options["cmake_args"]) - - return args - - def library(self, output, sources, options=None): - project_name = os.path.basename(output) - if project_name.startswith("lib"): - project_name = project_name[3:] - - lib_prj_conf = os.path.join(output, "prj.conf") - if self._project_dir is not None: - project_dir_conf = os.path.join(self._project_dir, "prj.conf") - if os.path.exists(project_dir_conf): - shutil.copy(project_dir_conf, lib_prj_conf) - - # Copy board-specific Zephyr config file from the project_dir to - # the build lib dir so board-specific configs can be found and used by - # Zephyr's build system in conjunction with the generic prj.conf configs. 
- board_conf = os.path.join("boards", self._board + ".conf") - project_dir_board_conf = os.path.join(self._project_dir, board_conf) - if os.path.exists(project_dir_board_conf): - os.mkdir(os.path.join(output, "boards")) - lib_dir_board_conf = os.path.join(output, board_conf) - shutil.copy(project_dir_board_conf, lib_dir_board_conf) - - else: - with open(lib_prj_conf, "w") as prj_conf_f: - prj_conf_f.write("CONFIG_CPLUSPLUS=y\n") - - cmakelists_path = os.path.join(output, "CMakeLists.txt") - with open(cmakelists_path, "w") as cmake_f: - sources = " ".join(f'"{o}"' for o in sources) - cmake_f.write( - textwrap.dedent( - f"""\ - cmake_minimum_required(VERSION 3.13.1) - - find_package(Zephyr HINTS $ENV{{ZEPHYR_BASE}}) - project({project_name}_prj) - target_sources(app PRIVATE) - zephyr_library_named({project_name}) - target_sources({project_name} PRIVATE {sources}) - target_sources(app PRIVATE main.c) - target_link_libraries(app PUBLIC {project_name}) - """ - ) - ) - if "include_dirs" in options: - cmake_f.write( - f"target_include_directories({project_name} PRIVATE " - f'{" ".join(os.path.abspath(d) for d in options["include_dirs"])})\n' - ) - - with open(os.path.join(output, "main.c"), "w"): - pass - - # expected not to exist after populate_tvm_libs - build_dir = os.path.join(output, "__tvm_build") - os.mkdir(build_dir) - self._subprocess_env.run( - ["cmake", "..", f"-DBOARD={self._board}"] + self._options_to_cmake_args(options), - cwd=build_dir, - ) - num_cpus = multiprocessing.cpu_count() - self._subprocess_env.run( - ["make", f"-j{num_cpus}", "VERBOSE=1", project_name], cwd=build_dir - ) - return tvm.micro.MicroLibrary(build_dir, [f"lib{project_name}.a"]) - - def _print_make_statistics(self, output): - output = output.splitlines() - lines = iter(output) - for line in lines: - if line.startswith("Memory region"): - # print statistics header - _LOG.info(line) - _LOG.info("--------------------- ---------- ------------ ---------") - line = next(lines) - # while there is a region print it - try: - while ":" in line: - _LOG.info(line) - line = next(lines) - else: - break - except StopIteration: - pass - - def binary(self, output, objects, options=None, link_main=True, main_options=None): - assert link_main, "Must pass link_main=True" - assert self._project_dir is not None, "Must supply project_dir= to build binaries" - - copied_libs = base.populate_tvm_objs(self._project_dir, objects) - - # expected not to exist after populate_tvm_objs - cmake_args = [ - "cmake", - os.path.abspath(self._project_dir), - f"-DBOARD={self._board}", - ] + self._options_to_cmake_args(options) - if "include_dirs" in options: - cmake_args.append( - "-DTVM_INCLUDE_DIRS=" - f'{";".join(os.path.abspath(d) for d in options["include_dirs"])}' - ) - cmake_args.append(f'-DTVM_LIBS={";".join(copied_libs)}') - self._subprocess_env.run(cmake_args, cwd=output) - - make_output = self._subprocess_env.run(["make"], cwd=output) - - self._print_make_statistics(make_output) - - return tvm.micro.MicroBinary( - output, - binary_file=os.path.join("zephyr", "zephyr.elf"), - debug_files=[os.path.join("zephyr", "zephyr.elf")], - labelled_files={ - "cmake_cache": ["CMakeCache.txt"], - "device_tree": [os.path.join("zephyr", "zephyr.dts")], - }, - immobile=bool(self._qemu), - ) - - @property - def flasher_factory(self): - return compiler.FlasherFactory( - ZephyrFlasher, - ( - self._board, - self._qemu, - ), - dict( - zephyr_base=self._zephyr_base, - project_dir=self._project_dir, - subprocess_env=self._subprocess_env.default_overrides, - 
west_cmd=self._west_cmd, - ), - ) - - -CACHE_ENTRY_RE = re.compile(r"(?P[^:]+):(?P[^=]+)=(?P.*)") - - -CMAKE_BOOL_MAP = dict( - [(k, True) for k in ("1", "ON", "YES", "TRUE", "Y")] - + [(k, False) for k in ("0", "OFF", "NO", "FALSE", "N", "IGNORE", "NOTFOUND", "")] -) - - -def read_cmake_cache(file_name): - """Read a CMakeCache.txt-like file and return a dictionary of values.""" - entries = collections.OrderedDict() - with open(file_name, encoding="utf-8") as f: - for line in f: - m = CACHE_ENTRY_RE.match(line.rstrip("\n")) - if not m: - continue - - if m.group("type") == "BOOL": - value = CMAKE_BOOL_MAP[m.group("value").upper()] - else: - value = m.group("value") - - entries[m.group("name")] = value - - return entries - - -class BoardError(Exception): - """Raised when an attached board cannot be opened (i.e. missing /dev nodes, etc).""" - - -class BoardAutodetectFailed(Exception): - """Raised when no attached hardware is found matching the board= given to ZephyrCompiler.""" - - -class ZephyrFlasher(tvm.micro.compiler.Flasher): - """A Flasher implementation that delegates to Zephyr/west.""" - - def __init__( - self, - board, - qemu, - zephyr_base=None, - project_dir=None, - subprocess_env=None, - nrfjprog_snr=None, - openocd_serial=None, - flash_args=None, - debug_rpc_session=None, - serial_timeouts=None, - west_cmd=None, - ): - zephyr_base = zephyr_base or os.environ["ZEPHYR_BASE"] - sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts")) - try: - import dtlib # pylint: disable=import-outside-toplevel - - self._dtlib = dtlib - finally: - sys.path.pop(0) - - self._board = board - self._qemu = qemu - self._zephyr_base = zephyr_base - self._project_dir = project_dir - self._west_cmd = west_cmd - self._flash_args = flash_args - self._openocd_serial = openocd_serial - self._autodetected_openocd_serial = None - self._subprocess_env = SubprocessEnv(subprocess_env) - self._debug_rpc_session = debug_rpc_session - self._nrfjprog_snr = nrfjprog_snr - self._serial_timeouts = serial_timeouts - - def _get_nrf_device_args(self): - nrfjprog_args = ["nrfjprog", "--ids"] - nrfjprog_ids = subprocess.check_output(nrfjprog_args, encoding="utf-8") - if not nrfjprog_ids.strip("\n"): - raise BoardAutodetectFailed( - f'No attached boards recognized by {" ".join(nrfjprog_args)}' - ) - - boards = nrfjprog_ids.split("\n")[:-1] - if len(boards) > 1: - if self._nrfjprog_snr is None: - raise BoardError( - "Multiple boards connected; specify one with nrfjprog_snr=: " - f'{", ".join(boards)}' - ) - - if str(self._nrfjprog_snr) not in boards: - raise BoardError( - f"nrfjprog_snr ({self._nrfjprog_snr}) not found in {nrfjprog_args}: {boards}" - ) - - return ["--snr", str(self._nrfjprog_snr)] - - if not boards: - return [] - - return ["--snr", boards[0]] - - # kwargs passed to usb.core.find to find attached boards for the openocd flash runner. 
- BOARD_USB_FIND_KW = { - "nucleo_l4r5zi": {"idVendor": 0x0483, "idProduct": 0x374B}, - "nucleo_f746zg": {"idVendor": 0x0483, "idProduct": 0x374B}, - "stm32f746g_disco": {"idVendor": 0x0483, "idProduct": 0x374B}, - } - - def openocd_serial(self, cmake_entries): - """Find the serial port to use for a board with OpenOCD flash strategy.""" - if self._openocd_serial is not None: - return self._openocd_serial - - if self._autodetected_openocd_serial is None: - import usb # pylint: disable=import-outside-toplevel - - find_kw = self.BOARD_USB_FIND_KW[cmake_entries["BOARD"]] - boards = usb.core.find(find_all=True, **find_kw) - serials = [] - for b in boards: - serials.append(b.serial_number) - - if len(serials) == 0: - raise BoardAutodetectFailed(f"No attached USB devices matching: {find_kw!r}") - serials.sort() - - self._autodetected_openocd_serial = serials[0] - _LOG.debug("zephyr openocd driver: autodetected serial %s", serials[0]) - - return self._autodetected_openocd_serial - - def _get_openocd_device_args(self, cmake_entries): - return ["--serial", self.openocd_serial(cmake_entries)] - - @classmethod - def _get_flash_runner(cls, cmake_entries): - flash_runner = cmake_entries.get("ZEPHYR_BOARD_FLASH_RUNNER") - if flash_runner is not None: - return flash_runner - - with open(cmake_entries["ZEPHYR_RUNNERS_YAML"]) as f: - doc = yaml.load(f, Loader=yaml.FullLoader) - return doc["flash-runner"] - - def _get_device_args(self, cmake_entries): - flash_runner = self._get_flash_runner(cmake_entries) - - if flash_runner == "nrfjprog": - return self._get_nrf_device_args() - if flash_runner == "openocd": - return self._get_openocd_device_args(cmake_entries) - - raise BoardError( - f"Don't know how to find serial terminal for board {cmake_entries['BOARD']} with flash " - f"runner {flash_runner}" - ) - - def _zephyr_transport(self, micro_binary): - qemu_debugger = None - if self._debug_rpc_session: - qemu_debugger = debugger.RpcDebugger( - self._debug_rpc_session, - debugger.DebuggerFactory( - QemuGdbDebugger, - (micro_binary.abspath(micro_binary.debug_files[0]),), - {}, - ), - ) - - return ZephyrQemuTransport( - micro_binary.base_dir, startup_timeout_sec=30.0, qemu_debugger=qemu_debugger - ) - - def flash(self, micro_binary): - if self._qemu: - return self._zephyr_transport(micro_binary) - - cmake_cache_path = micro_binary.abspath(micro_binary.labelled_files["cmake_cache"][0]) - cmake_entries = read_cmake_cache(cmake_cache_path) - - build_dir = os.path.dirname(cmake_cache_path) - - # The nRF5340DK requires an additional `nrfjprog --recover` before each flash cycle. - # This is because readback protection is enabled by default when this device is flashed. - # Otherwise, flashing may fail with an error such as the following: - # ERROR: The operation attempted is unavailable due to readback protection in - # ERROR: your device. Please use --recover to unlock the device. 
- if ( - self._board.startswith("nrf5340dk") - and self._get_flash_runner(cmake_entries) == "nrfjprog" - ): - recover_args = ["nrfjprog", "--recover"] - recover_args.extend(self._get_nrf_device_args()) - self._subprocess_env.run(recover_args, cwd=build_dir) - - west_args = ( - self._west_cmd - + ["flash", "--build-dir", build_dir, "--skip-rebuild"] - + self._get_device_args(cmake_entries) - ) - if self._flash_args is not None: - west_args.extend(self._flash_args) - self._subprocess_env.run(west_args, cwd=build_dir) - - return self.transport(micro_binary) - - def _find_nrf_serial_port(self, cmake_entries): - com_ports = subprocess.check_output( - ["nrfjprog", "--com"] + self._get_device_args(cmake_entries), encoding="utf-8" - ) - ports_by_vcom = {} - for line in com_ports.split("\n")[:-1]: - parts = line.split() - ports_by_vcom[parts[2]] = parts[1] - - return {"port_path": ports_by_vcom["VCOM2"]} - - def _find_openocd_serial_port(self, cmake_entries): - return {"grep": self.openocd_serial(cmake_entries)} - - def _find_serial_port(self, micro_binary): - cmake_entries = read_cmake_cache( - micro_binary.abspath(micro_binary.labelled_files["cmake_cache"][0]) - ) - flash_runner = self._get_flash_runner(cmake_entries) - - if flash_runner == "nrfjprog": - return self._find_nrf_serial_port(cmake_entries) - - if flash_runner == "openocd": - return self._find_openocd_serial_port(cmake_entries) - - raise FlashRunnerNotSupported( - f"Don't know how to deduce serial port for flash runner {flash_runner}" - ) - - def transport(self, micro_binary): - """Instantiate the transport for use with non-QEMU Zephyr.""" - dt_inst = self._dtlib.DT( - micro_binary.abspath(micro_binary.labelled_files["device_tree"][0]) - ) - uart_baud = ( - dt_inst.get_node("/chosen") - .props["zephyr,console"] - .to_path() - .props["current-speed"] - .to_num() - ) - _LOG.debug("zephyr transport: found UART baudrate from devicetree: %d", uart_baud) - - port_kwargs = self._find_serial_port(micro_binary) - serial_transport = serial.SerialTransport( - timeouts=self._serial_timeouts, baudrate=uart_baud, **port_kwargs - ) - if self._debug_rpc_session is None: - return serial_transport - - return debug.DebugWrapperTransport( - debugger.RpcDebugger( - self._debug_rpc_session, - debugger.DebuggerFactory( - ZephyrDebugger, - ( - " ".join(shlex.quote(x) for x in self._west_cmd), - os.path.dirname(micro_binary.abspath(micro_binary.label("cmake_cache")[0])), - micro_binary.abspath(micro_binary.debug_files[0]), - self._zephyr_base, - ), - {}, - ), - ), - serial_transport, - ) - - -class QemuGdbDebugger(debugger.GdbDebugger): - def __init__(self, elf_file): - super(QemuGdbDebugger, self).__init__() - self._elf_file = elf_file - - def popen_kwargs(self): - # expect self._elf file to follow the form .../zephyr/zephyr.elf - cmake_cache_path = pathlib.Path(self._elf_file).parent.parent / "CMakeCache.txt" - cmake_cache = read_cmake_cache(cmake_cache_path) - return { - "args": [ - cmake_cache["CMAKE_GDB"], - "-ex", - "target remote localhost:1234", - "-ex", - f"file {self._elf_file}", - ], - } - - -class QemuStartupFailureError(Exception): - """Raised when the qemu pipe is not present within startup_timeout_sec.""" - - -class QemuFdTransport(file_descriptor.FdTransport): - """An FdTransport subclass that escapes written data to accommodate the QEMU monitor. 
- - It's supposedly possible to disable the monitor, but Zephyr controls most of the command-line - arguments for QEMU and there are too many options which implictly enable the monitor, so this - approach seems more robust. - """ - - def write_monitor_quit(self): - file_descriptor.FdTransport.write(self, b"\x01x", 1.0) - - def close(self): - file_descriptor.FdTransport.close(self) - - def timeouts(self): - assert False, "should not get here" - - def write(self, data, timeout_sec): - """Write data, escaping for QEMU monitor.""" - to_write = bytearray() - escape_pos = [] - for i, b in enumerate(data): - if b == 0x01: - to_write.append(b) - escape_pos.append(i) - to_write.append(b) - - num_written = file_descriptor.FdTransport.write(self, to_write, timeout_sec) - num_written -= sum(1 if x < num_written else 0 for x in escape_pos) - return num_written - - -class ZephyrQemuMakeResult(enum.Enum): - QEMU_STARTED = "qemu_started" - MAKE_FAILED = "make_failed" - EOF = "eof" - - -class ZephyrQemuTransport(Transport): - """The user-facing Zephyr QEMU transport class.""" - - def __init__(self, base_dir, startup_timeout_sec=5.0, qemu_debugger=None, **kwargs): - self.base_dir = base_dir - self.startup_timeout_sec = startup_timeout_sec - self.kwargs = kwargs - self.proc = None - self.fd_transport = None - self.pipe_dir = None - self.qemu_debugger = qemu_debugger - self._queue = queue.Queue() - - def timeouts(self): - return TransportTimeouts( - session_start_retry_timeout_sec=2.0, - session_start_timeout_sec=self.startup_timeout_sec, - session_established_timeout_sec=5.0 if self.qemu_debugger is None else 0, - ) - - def open(self): - self.pipe_dir = tempfile.mkdtemp() - self.pipe = os.path.join(self.pipe_dir, "fifo") - self.write_pipe = os.path.join(self.pipe_dir, "fifo.in") - self.read_pipe = os.path.join(self.pipe_dir, "fifo.out") - - os.mkfifo(self.write_pipe) - os.mkfifo(self.read_pipe) - if self.qemu_debugger is not None: - if "env" in self.kwargs: - self.kwargs["env"] = copy.copy(self.kwargs["env"]) - else: - self.kwargs["env"] = os.environ.copy() - - self.kwargs["env"]["TVM_QEMU_DEBUG"] = "1" - - self.proc = subprocess.Popen( - ["make", "run", f"QEMU_PIPE={self.pipe}"], - cwd=self.base_dir, - **self.kwargs, - stdout=subprocess.PIPE, - ) - try: - self._wait_for_qemu() - except Exception as error: - raise error - - if self.qemu_debugger is not None: - self.qemu_debugger.start() - - # NOTE: although each pipe is unidirectional, open both as RDWR to work around a select - # limitation on linux. Without this, non-blocking I/O can't use timeouts because named - # FIFO are always considered ready to read when no one has opened them for writing. 
- self.fd_transport = wakeup.WakeupTransport( - QemuFdTransport( - os.open(self.read_pipe, os.O_RDWR | os.O_NONBLOCK), - os.open(self.write_pipe, os.O_RDWR | os.O_NONBLOCK), - self.timeouts(), - ), - b"\xfe\xff\xfd\x03\0\0\0\0\0\x02" b"fw", - ) - self.fd_transport.open() - - def close(self): - if self.qemu_debugger is not None: - self.qemu_debugger.stop() - - if self.fd_transport is not None: - self.fd_transport.child_transport.write_monitor_quit() - self.proc.wait() - self.fd_transport.close() - self.fd_transport = None - - if self.proc is not None: - self.proc = None - - if self.pipe_dir is not None: - shutil.rmtree(self.pipe_dir) - self.pipe_dir = None - - def read(self, n, timeout_sec): - if self.fd_transport is None: - raise TransportClosedError() - return self.fd_transport.read(n, timeout_sec) - - def write(self, data, timeout_sec): - if self.fd_transport is None: - raise TransportClosedError() - return self.fd_transport.write(data, timeout_sec) - - def _qemu_check_stdout(self): - for line in self.proc.stdout: - line = str(line) - _LOG.debug(line) - if "[QEMU] CPU" in line: - self._queue.put(ZephyrQemuMakeResult.QEMU_STARTED) - else: - line = re.sub("[^a-zA-Z0-9 \n]", "", line) - pattern = r"recipe for target (\w*) failed" - if re.search(pattern, line, re.IGNORECASE): - self._queue.put(ZephyrQemuMakeResult.MAKE_FAILED) - self._queue.put(ZephyrQemuMakeResult.EOF) - - def _wait_for_qemu(self): - threading.Thread(target=self._qemu_check_stdout, daemon=True).start() - while True: - try: - item = self._queue.get(timeout=120) - except Exception: - raise TimeoutError("QEMU setup timeout.") - - if item == ZephyrQemuMakeResult.QEMU_STARTED: - break - - if item in [ZephyrQemuMakeResult.MAKE_FAILED, ZephyrQemuMakeResult.EOF]: - raise RuntimeError("QEMU setup failed.") - - raise ValueError(f"{item} not expected.") - - -class ZephyrDebugger(debugger.GdbDebugger): - """A Zephyr debugger implementation.""" - - def __init__(self, west_cmd, build_dir, elf_path, zephyr_base): - super(ZephyrDebugger, self).__init__() - self._west_cmd = shlex.split(west_cmd) - self._build_dir = build_dir - self._elf_path = elf_path - self._zephyr_base = zephyr_base - - def popen_kwargs(self): - env = dict(os.environ) - env["ZEPHYR_BASE"] = self._zephyr_base - - args = dict( - args=self._west_cmd - + [ - "debug", - "--skip-rebuild", - "--build-dir", - self._build_dir, - "--elf-file", - self._elf_path, - ], - env=env, - ) - return args diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py index 915bee08175c..8086b1ed6554 100644 --- a/python/tvm/micro/interface_api.py +++ b/python/tvm/micro/interface_api.py @@ -41,8 +41,12 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path): List of module output names to be placed in generated structs output_path : str Path to the output folder to generate the header into - """ + Returns + ------- + str : + Name of the generated file. 
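+
+        For example (illustrative), module_name="default" would return
+        os.path.join(output_path, "tvmgen_default.h").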
+ """ mangled_name = mangle_module_name(module_name) metadata_header = os.path.join(output_path, f"{mangled_name}.h") with open(metadata_header, "w") as header_file: @@ -77,3 +81,5 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path): ) header_file.write(f"#endif // {mangled_name.upper()}_H_\n") + + return metadata_header diff --git a/python/tvm/micro/micro_binary.py b/python/tvm/micro/micro_binary.py deleted file mode 100644 index 74b760b67650..000000000000 --- a/python/tvm/micro/micro_binary.py +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an Artifact implementation for representing compiled micro TVM binaries.""" - -from . import artifact - - -class MicroBinary(artifact.Artifact): - """An Artifact that describes a compiled binary.""" - - ARTIFACT_TYPE = "micro_binary" - - @classmethod - def from_unarchived(cls, base_dir, labelled_files, metadata, immobile): - binary_file = labelled_files["binary_file"][0] - del labelled_files["binary_file"] - - debug_files = None - if "debug_files" in labelled_files: - debug_files = labelled_files["debug_files"] - del labelled_files["debug_files"] - - return cls( - base_dir, - binary_file, - debug_files=debug_files, - labelled_files=labelled_files, - metadata=metadata, - immobile=immobile, - ) - - def __init__( - self, - base_dir, - binary_file, - debug_files=None, - labelled_files=None, - metadata=None, - immobile=False, - ): - labelled_files = {} if labelled_files is None else dict(labelled_files) - metadata = {} if metadata is None else dict(metadata) - labelled_files["binary_file"] = [binary_file] - if debug_files is not None: - labelled_files["debug_files"] = debug_files - - super(MicroBinary, self).__init__(base_dir, labelled_files, metadata, immobile=immobile) - - self.binary_file = binary_file - self.debug_files = debug_files diff --git a/python/tvm/micro/micro_library.py b/python/tvm/micro/micro_library.py deleted file mode 100644 index 74687ede1235..000000000000 --- a/python/tvm/micro/micro_library.py +++ /dev/null @@ -1,93 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an Artifact subclass that describes a compiled static library.""" - -from tvm.contrib import utils -from . import artifact -from . import compiler - - -class MicroLibrary(artifact.Artifact): - """An Artifact that describes a compiled static library.""" - - ARTIFACT_TYPE = "micro_library" - - @classmethod - def from_unarchived(cls, base_dir, labelled_files, metadata, immobile): - library_files = labelled_files["library_files"] - del labelled_files["library_files"] - - debug_files = None - if "debug_files" in labelled_files: - debug_files = labelled_files["debug_files"] - del labelled_files["debug_files"] - - return cls( - base_dir, - library_files, - debug_files=debug_files, - labelled_files=labelled_files, - metadata=metadata, - immobile=immobile, - ) - - def __init__( - self, - base_dir, - library_files, - debug_files=None, - labelled_files=None, - metadata=None, - immobile=False, - ): - labelled_files = {} if labelled_files is None else dict(labelled_files) - metadata = {} if metadata is None else dict(metadata) - labelled_files["library_files"] = library_files - if debug_files is not None: - labelled_files["debug_files"] = debug_files - - super(MicroLibrary, self).__init__(base_dir, labelled_files, metadata, immobile=immobile) - - self.library_files = library_files - self.debug_file = debug_files - - -def create_micro_library(output, objects, options=None): - """Create a MicroLibrary using the default compiler options. - - Parameters - ---------- - output : str - Path to the output file, expected to end in .tar. - objects : List[str] - Paths to the source files to include in the library. - options : Optional[List[str]] - If given, additional command-line flags for the compiler. - """ - temp_dir = utils.tempdir() - comp = compiler.DefaultCompiler() - output = temp_dir.relpath("micro-library.o") - comp.library(output, objects, options=options) - - with open(output, "rb") as output_f: - elf_data = output_f.read() - - # TODO(areusch): Define a mechanism to determine compiler and linker flags for each lib - # enabled by the target str, and embed here. - micro_lib = MicroLibrary("", elf_data, {"target": comp.target.str()}) - micro_lib.save(output) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 5e682c72ed73..ed44a3336a52 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -170,9 +170,14 @@ def _build_function_memory_map(function_metadata): target_local_entries[func_name] = list() for func_name, finfo in function_metadata.items(): - if func_name == MAIN_FUNC_NAME_STR: + # Skip a few unsupported cases: + # 1. The main function metadata is exported elsewhere. + # 2. BYOC operator implementations do not currently export useful FunctionInfo. 
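+        # (BYOC kernels are recognizable here because they lower no TIR PrimFuncs, so
+        # finfo.tir_primfuncs is empty for them.)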
+ if func_name == MAIN_FUNC_NAME_STR or not finfo.tir_primfuncs: continue - assert len(finfo.constant_sizes.items()) == num_targets + assert ( + len(finfo.constant_sizes.items()) == num_targets + ), f"{func_name}: found {finfo.constant_sizes!r} vs {num_targets}" assert len(finfo.io_sizes.items()) == num_targets target = finfo.workspace_sizes.items()[i][0] workspace_size = finfo.workspace_sizes.items()[i][1] diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py new file mode 100644 index 000000000000..8d1408c679fb --- /dev/null +++ b/python/tvm/micro/project.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines glue wrappers around the Project API which mate to TVM interfaces.""" + +import pathlib +import typing + +from .. import __version__ +from ..contrib import utils +from .build import get_standalone_crt_dir +from .model_library_format import ExportableModule, export_model_library_format +from .project_api import client +from .transport import Transport, TransportTimeouts + + +class ProjectTransport(Transport): + """A Transport implementation that uses the Project API client.""" + + def __init__(self, api_client, options): + self._api_client = api_client + self._options = options + self._timeouts = None + + def timeouts(self): + assert self._timeouts is not None, "Transport not yet opened" + return self._timeouts + + def open(self): + reply = self._api_client.open_transport(self._options) + self._timeouts = TransportTimeouts(**reply["timeouts"]) + + def close(self): + if not self._api_client.is_shutdown: + self._api_client.close_transport() + self._api_client.shutdown() + + def write(self, data, timeout_sec): + self._api_client.write_transport(data, timeout_sec) + + def read(self, n, timeout_sec): + return self._api_client.read_transport(n, timeout_sec)["data"] + + +class TemplateProjectError(Exception): + """Raised when the Project API server given to GeneratedProject reports is_template=True.""" + + +class GeneratedProject: + """Defines a glue interface to interact with a generated project through the API server.""" + + @classmethod + def from_directory(cls, project_dir: typing.Union[pathlib.Path, str], options: dict): + return cls(client.instantiate_from_dir(project_dir), options) + + def __init__(self, api_client, options): + self._api_client = api_client + self._options = options + self._info = self._api_client.server_info_query(__version__) + if self._info["is_template"]: + raise TemplateProjectError() + + def build(self): + self._api_client.build(self._options) + + def flash(self): + self._api_client.flash(self._options) + + def transport(self): + return ProjectTransport(self._api_client, self._options) + + +class NotATemplateProjectError(Exception): + """Raised when the API server given to 
TemplateProject reports is_template=false."""
+
+
+class TemplateProject:
+    """Defines a glue interface to interact with a template project through the API Server."""
+
+    @classmethod
+    def from_directory(cls, template_project_dir, options):
+        return cls(client.instantiate_from_dir(template_project_dir), options)
+
+    def __init__(self, api_client, options):
+        self._api_client = api_client
+        self._options = options
+        self._info = self._api_client.server_info_query(__version__)
+        if not self._info["is_template"]:
+            raise NotATemplateProjectError()
+
+    def generate_project(self, graph_executor_factory, project_dir):
+        """Generate a project from the given graph_executor_factory."""
+        model_library_dir = utils.tempdir()
+        model_library_format_path = model_library_dir.relpath("model.tar")
+        export_model_library_format(graph_executor_factory, model_library_format_path)
+
+        self._api_client.generate_project(
+            model_library_format_path=model_library_format_path,
+            standalone_crt_dir=get_standalone_crt_dir(),
+            project_dir=project_dir,
+            options=self._options,
+        )
+
+        return GeneratedProject.from_directory(project_dir, self._options)
+
+
+def generate_project(
+    template_project_dir: typing.Union[pathlib.Path, str],
+    module: ExportableModule,
+    generated_project_dir: typing.Union[pathlib.Path, str],
+    options: dict = None,
+):
+    """Generate a project for an embedded platform that contains the given model.
+
+    Parameters
+    ----------
+    template_project_dir : pathlib.Path or str
+        Path to a template project containing a microTVM Project API server.
+
+    module : ExportableModule
+        A runtime.Module exportable as Model Library Format. The value returned from tvm.relay.build
+        or tvm.build.
+
+    generated_project_dir : pathlib.Path or str
+        Path to a directory to be created and filled with the built project.
+
+    options : dict
+        If given, Project API options given to the microTVM API server found in both
+        template_project_dir and generated_project_dir.
+
+    Returns
+    -------
+    GeneratedProject :
+        A class that wraps the generated project and which can be used to further interact with it.
+    """
+    template = TemplateProject.from_directory(str(template_project_dir), options)
+    return template.generate_project(module, str(generated_project_dir))
diff --git a/python/tvm/micro/project_api/client.py b/python/tvm/micro/project_api/client.py
new file mode 100644
index 000000000000..f650ad946d87
--- /dev/null
+++ b/python/tvm/micro/project_api/client.py
@@ -0,0 +1,235 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import base64
+import io
+import json
+import logging
+import os
+import pathlib
+import subprocess
+import sys
+import typing
+
+from .
import server + +_LOG = logging.getLogger(__name__) + + +class ProjectAPIErrorBase(Exception): + """Base class for all Project API errors.""" + + +class ConnectionShutdownError(ProjectAPIErrorBase): + """Raised when a request is made but the connection has been closed.""" + + +class MalformedReplyError(ProjectAPIErrorBase): + """Raised when the server responds with an invalid reply.""" + + +class MismatchedIdError(ProjectAPIErrorBase): + """Raised when the reply ID does not match the request.""" + + +class ProjectAPIServerNotFoundError(ProjectAPIErrorBase): + """Raised when the Project API server can't be found in the repo.""" + + +class UnsupportedProtocolVersionError(ProjectAPIErrorBase): + """Raised when the protocol version returned by the API server is unsupported.""" + + +class RPCError(ProjectAPIErrorBase): + def __init__(self, request, error): + self.request = request + self.error = error + + def __str__(self): + return f"Calling project API method {self.request['method']}:" "\n" f"{self.error}" + + +class ProjectAPIClient: + """A client for the Project API.""" + + def __init__( + self, + read_file: typing.BinaryIO, + write_file: typing.BinaryIO, + testonly_did_write_request: typing.Optional[typing.Callable] = None, + ): + self.read_file = io.TextIOWrapper(read_file, encoding="UTF-8", errors="strict") + self.write_file = io.TextIOWrapper( + write_file, encoding="UTF-8", errors="strict", write_through=True + ) + self.testonly_did_write_request = testonly_did_write_request + self.next_request_id = 1 + + @property + def is_shutdown(self): + return self.read_file is None + + def shutdown(self): + if self.is_shutdown: + return + + self.read_file.close() + self.write_file.close() + + def _request_reply(self, method, params): + if self.is_shutdown: + raise ConnectionShutdownError("connection already closed") + + request = { + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": self.next_request_id, + } + self.next_request_id += 1 + + request_str = json.dumps(request) + self.write_file.write(request_str) + _LOG.debug("send -> %s", request_str) + self.write_file.write("\n") + if self.testonly_did_write_request: + self.testonly_did_write_request() # Allow test to assert on server processing. 
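+
+        # Replies arrive as newline-delimited JSON-RPC 2.0 objects mirroring the request above,
+        # e.g. (illustrative): {"jsonrpc": "2.0", "id": 1, "result": {...}}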
+        reply_line = self.read_file.readline()
+        _LOG.debug("recv <- %s", reply_line)
+        if not reply_line:
+            self.shutdown()
+            raise ConnectionShutdownError("got EOF reading reply from API server")
+
+        reply = json.loads(reply_line)
+
+        if reply.get("jsonrpc") != "2.0":
+            raise MalformedReplyError(
+                f"Server reply should include 'jsonrpc': '2.0'; "
+                f"saw jsonrpc={reply.get('jsonrpc')!r}"
+            )
+
+        if reply["id"] != request["id"]:
+            raise MismatchedIdError(
+                f"Reply id ({reply['id']}) does not equal request id ({request['id']})"
+            )
+
+        if "error" in reply:
+            raise server.JSONRPCError.from_json(f"calling method {method}", reply["error"])
+        elif "result" not in reply:
+            raise MalformedReplyError(f"Expected 'result' key in server reply, got {reply!r}")
+
+        return reply["result"]
+
+    def server_info_query(self, tvm_version: str):
+        reply = self._request_reply("server_info_query", {"tvm_version": tvm_version})
+        if reply["protocol_version"] != server.ProjectAPIServer._PROTOCOL_VERSION:
+            raise UnsupportedProtocolVersionError(
+                f'microTVM API Server supports protocol version {reply["protocol_version"]}; '
+                f"want {server.ProjectAPIServer._PROTOCOL_VERSION}"
+            )
+
+        return reply
+
+    def generate_project(
+        self,
+        model_library_format_path: str,
+        standalone_crt_dir: str,
+        project_dir: str,
+        options: dict = None,
+    ):
+        return self._request_reply(
+            "generate_project",
+            {
+                "model_library_format_path": model_library_format_path,
+                "standalone_crt_dir": standalone_crt_dir,
+                "project_dir": project_dir,
+                "options": (options if options is not None else {}),
+            },
+        )
+
+    def build(self, options: dict = None):
+        return self._request_reply("build", {"options": (options if options is not None else {})})
+
+    def flash(self, options: dict = None):
+        return self._request_reply("flash", {"options": (options if options is not None else {})})
+
+    def open_transport(self, options: dict = None):
+        return self._request_reply(
+            "open_transport", {"options": (options if options is not None else {})}
+        )
+
+    def close_transport(self):
+        return self._request_reply("close_transport", {})
+
+    def read_transport(self, n, timeout_sec):
+        reply = self._request_reply("read_transport", {"n": n, "timeout_sec": timeout_sec})
+        reply["data"] = base64.b85decode(reply["data"])
+        return reply
+
+    def write_transport(self, data, timeout_sec):
+        return self._request_reply(
+            "write_transport",
+            {"data": str(base64.b85encode(data), "utf-8"), "timeout_sec": timeout_sec},
+        )
+
+
+# NOTE: windows support untested
+SERVER_LAUNCH_SCRIPT_FILENAME = (
+    f"launch_microtvm_api_server.{'sh' if sys.platform != 'win32' else 'bat'}"
+)
+
+
+SERVER_PYTHON_FILENAME = "microtvm_api_server.py"
+
+
+def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bool = False):
+    """Launch server located in project_dir, and instantiate a Project API Client connected to it."""
+    args = None
+
+    project_dir = pathlib.Path(project_dir)
+
+    python_script = project_dir / SERVER_PYTHON_FILENAME
+    if python_script.is_file():
+        args = [sys.executable, str(python_script)]
+
+    launch_script = project_dir / SERVER_LAUNCH_SCRIPT_FILENAME
+    if launch_script.is_file():
+        args = [str(launch_script)]
+
+    if args is None:
+        raise ProjectAPIServerNotFoundError(
+            f"No Project API server found in project directory: {project_dir}"
+            "\n"
+            f"Tried: {SERVER_LAUNCH_SCRIPT_FILENAME}, {SERVER_PYTHON_FILENAME}"
+        )
+
+    api_server_read_fd, tvm_write_fd = os.pipe()
+    tvm_read_fd, api_server_write_fd = os.pipe()
+
+    args.extend(["--read-fd",
str(api_server_read_fd), "--write-fd", str(api_server_write_fd)]) + if debug: + args.append("--debug") + + api_server_proc = subprocess.Popen( + args, bufsize=0, pass_fds=(api_server_read_fd, api_server_write_fd), cwd=project_dir + ) + os.close(api_server_read_fd) + os.close(api_server_write_fd) + + return ProjectAPIClient( + os.fdopen(tvm_read_fd, "rb", buffering=0), os.fdopen(tvm_write_fd, "wb", buffering=0) + ) diff --git a/python/tvm/micro/project_api/server.py b/python/tvm/micro/project_api/server.py new file mode 100644 index 000000000000..144f0cb6dee1 --- /dev/null +++ b/python/tvm/micro/project_api/server.py @@ -0,0 +1,776 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines a basic Project API server template. + +This file is meant to be imported or copied into Project API servers, so it should not have any +imports or dependencies outside of things strictly required to run the API server. +""" + +import abc +import argparse +import base64 +import collections +import enum +import io +import json +import logging +import os +import pathlib +import re +import select +import sys +import textwrap +import time +import traceback +import typing + + +_LOG = logging.getLogger(__name__) + + +_ProjectOption = collections.namedtuple("ProjectOption", ("name", "choices", "help")) + + +class ProjectOption(_ProjectOption): + def __new__(cls, name, **kw): + """Override __new__ to force all options except name to be specified as kwargs.""" + assert "name" not in kw + kw["name"] = name + kw.setdefault("choices", None) + return super().__new__(cls, **kw) + + +ServerInfo = collections.namedtuple( + "ServerInfo", ("platform_name", "is_template", "model_library_format_path", "project_options") +) + + +# Timeouts supported by the underlying C++ MicroSession. +# +# session_start_retry_timeout_sec : float +# Number of seconds to wait for the device to send a kSessionStartReply after sending the +# initial session start message. After this time elapses another +# kSessionTerminated-kSessionStartInit train is sent. 0 disables this. +# session_start_timeout_sec : float +# Total number of seconds to wait for the session to be established. After this time, the +# client gives up trying to establish a session and raises an exception. +# session_established_timeout_sec : float +# Number of seconds to wait for a reply message after a session has been established. 0 +# disables this. +TransportTimeouts = collections.namedtuple( + "TransportTimeouts", + [ + "session_start_retry_timeout_sec", + "session_start_timeout_sec", + "session_established_timeout_sec", + ], +) + + +class ErrorCode(enum.IntEnum): + """Enumerates error codes which can be returned. Includes JSON-RPC standard and custom codes.""" + + # Custom (in reserved error code space). 
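+    # (JSON-RPC 2.0 reserves codes -32000 to -32099 for implementation-defined server errors.)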
+    SERVER_ERROR = -32000  # A generic error was raised while processing the request.
+
+    # JSON-RPC standard
+    PARSE_ERROR = -32700
+    INVALID_REQUEST = -32600
+    METHOD_NOT_FOUND = -32601
+    INVALID_PARAMS = -32602
+    INTERNAL_ERROR = -32603
+
+
+class JSONRPCError(Exception):
+    """An error class with properties that meet the JSON-RPC error spec."""
+
+    def __init__(self, code, message, data, client_context=None):
+        self.code = code
+        self.message = message
+        self.data = data
+        self.client_context = client_context
+
+    def to_json(self):
+        return {
+            "code": self.code,
+            "message": self.message,
+            "data": self.data,
+        }
+
+    def __str__(self):
+        data_str = ""
+        if self.data:
+            if isinstance(self.data, dict) and self.data.get("traceback"):
+                data_str = f'\n{self.data["traceback"]}'
+            else:
+                data_str = f"\n{self.data!r}"
+        return f"JSON-RPC error # {self.code}: {self.message}" + data_str
+
+    @classmethod
+    def from_json(cls, client_context, json_error):
+        # Subclasses of ServerError capture exceptions that occur in the Handler, and thus return a
+        # traceback. The encoding in `json_error` is also slightly different to allow the specific subclass
+        # to be identified.
+        found_server_error = False
+        try:
+            if ErrorCode(json_error["code"]) == ErrorCode.SERVER_ERROR:
+                found_server_error = True
+        except ValueError:
+            # The error code is not a known ErrorCode; fall through and return a generic
+            # JSONRPCError below.
+            pass
+
+        if found_server_error:
+            return ServerError.from_json(client_context, json_error)
+
+        return cls(
+            json_error["code"],
+            json_error["message"],
+            json_error.get("data", None),
+            client_context=client_context,
+        )
+
+
+class ServerError(JSONRPCError):
+    @classmethod
+    def from_exception(cls, exc, **kw):
+        to_return = cls(**kw)
+        to_return.set_traceback(traceback.TracebackException.from_exception(exc).format())
+        return to_return
+
+    def __init__(self, message=None, data=None, client_context=None):
+        if self.__class__ == ServerError:
+            assert message is not None, "Plain ServerError must have message="
+        else:
+            assert (
+                message is None
+            ), f"ServerError subclasses must not supply message=; got {message!r}"
+            message = self.__class__.__name__
+
+        super(ServerError, self).__init__(ErrorCode.SERVER_ERROR, message, data)
+        self.client_context = client_context
+
+    def __str__(self):
+        context_str = f"{self.client_context}: " if self.client_context is not None else ""
+        super_str = super(ServerError, self).__str__()
+        return context_str + super_str
+
+    def set_traceback(self, traceback):
+        if self.data is None:
+            self.data = {}
+
+        if "traceback" not in self.data:
+            # NOTE: TVM's FFI layer reorders Python stack traces several times and strips
+            # intermediary lines that start with "Traceback". This logic adds a comment to the first
+            # stack frame to explicitly identify the first stack frame line that occurs on the server.
+            traceback_list = list(traceback)
+
+            # The traceback list contains one entry per stack frame, and each entry contains 1-2 lines:
+            #    File "path/to/file", line 123, in <method>:
+            #        <copy of the source line>
+            #
+            # We want to place a comment on the first line of the outermost frame to indicate this is the
+            # server-side stack frame.
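+            # After the rewrite below, the first frame reads like (illustrative):
+            #    File "microtvm_api_server.py", line 42, in build:  # <--- Outermost server-side stack frame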
+            first_frame_list = traceback_list[1].split("\n")
+            self.data["traceback"] = (
+                traceback_list[0]
+                + f"{first_frame_list[0]}  # <--- Outermost server-side stack frame\n"
+                + "\n".join(first_frame_list[1:])
+                + "".join(traceback_list[2:])
+            )
+
+    @classmethod
+    def from_json(cls, client_context, json_error):
+        assert json_error["code"] == ErrorCode.SERVER_ERROR
+
+        for sub_cls in cls.__subclasses__():
+            if sub_cls.__name__ == json_error["message"]:
+                return sub_cls(
+                    data=json_error.get("data"),
+                    client_context=client_context,
+                )
+
+        return cls(
+            json_error["message"], data=json_error.get("data"), client_context=client_context
+        )
+
+
+class TransportClosedError(ServerError):
+    """Raised when a transport can no longer be used due to underlying I/O problems."""
+
+
+class IoTimeoutError(ServerError):
+    """Raised when the I/O operation could not be completed before the timeout.
+
+    Specifically:
+     - when no data could be read before the timeout
+     - when some of the write data could be written before the timeout
+
+    Note the asymmetric behavior of read() vs write(), since in one case the total length of the
+    data to transfer is known.
+    """
+
+
+class UnsupportedTVMVersionError(ServerError):
+    """Raised when the version of TVM supplied to server_info_query is unsupported."""
+
+
+class ProjectAPIHandler(metaclass=abc.ABCMeta):
+    """The interface class for all Project API implementations.
+
+    Extend this class in your microtvm_api_server.py and implement each function defined here.
+    """
+
+    @abc.abstractmethod
+    def server_info_query(self, tvm_version: str) -> ServerInfo:
+        """Initial request issued by TVM to retrieve metadata about this API server and project.
+
+        Should this API server not support the version of TVM given in tvm_version, it should
+        raise UnsupportedTVMVersionError.
+
+        Parameters
+        ----------
+        tvm_version : str
+            The value of tvm.__version__.
+
+        Returns
+        -------
+        ServerInfo :
+            A ServerInfo namedtuple containing the metadata needed by TVM.
+
+        Raises
+        ------
+        UnsupportedTVMVersionError :
+           When tvm_version indicates a known-unsupported version of TVM.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def generate_project(
+        self,
+        model_library_format_path: pathlib.Path,
+        standalone_crt_dir: pathlib.Path,
+        project_dir: pathlib.Path,
+        options: dict,
+    ):
+        """Generate a project from the given artifacts, copying ourselves to that project.
+
+        Parameters
+        ----------
+        model_library_format_path : pathlib.Path
+            Path to the Model Library Format tar archive.
+        standalone_crt_dir : pathlib.Path
+            Path to the root directory of the "standalone_crt" TVM build artifact. This contains the
+            TVM C runtime.
+        project_dir : pathlib.Path
+            Path to a nonexistent directory which should be created and filled with the generated
+            project.
+        options : dict
+            Dict mapping option name to ProjectOption.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def build(self, options: dict):
+        """Build the project, enabling the flash() call to be made.
+
+        Parameters
+        ----------
+        options : Dict[str, ProjectOption]
+            ProjectOption which may influence the build, keyed by option name.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def flash(self, options: dict):
+        """Program the project onto the device.
+
+        Parameters
+        ----------
+        options : Dict[str, ProjectOption]
+            ProjectOption which may influence the programming process, keyed by option name.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def open_transport(self, options: dict) -> TransportTimeouts:
+        """Open resources needed for the transport layer.
+
+        This function might e.g. open files or serial ports needed in write_transport or read_transport.
+
+        Calling this function enables the write_transport and read_transport calls. If the
+        transport is already open, this method is a no-op.
+
+        Parameters
+        ----------
+        options : Dict[str, ProjectOption]
+            ProjectOption which may influence the programming process, keyed by option name.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def close_transport(self):
+        """Close resources needed to operate the transport layer.
+
+        This function might e.g. close files or serial ports needed in write_transport or read_transport.
+
+        Calling this function disables the write_transport and read_transport calls. If the
+        transport is not open, this method is a no-op.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def read_transport(self, n: int, timeout_sec: typing.Union[float, type(None)]) -> bytes:
+        """Read data from the transport.
+
+        Parameters
+        ----------
+        n : int
+            The exact number of bytes to read from the transport.
+        timeout_sec : Union[float, None]
+            Number of seconds to wait for at least one byte to be received before timing out. If
+            timeout_sec is 0, read should attempt to service the request in a non-blocking fashion.
+            If timeout_sec is None, read should block until all `n` bytes of data can be returned.
+
+        Returns
+        -------
+        bytes :
+            Data read from the channel. Should be exactly `n` bytes long.
+
+        Raises
+        ------
+        TransportClosedError :
+            When the transport layer determines that the transport can no longer send or receive
+            data due to an underlying I/O problem (i.e. file descriptor closed, cable removed, etc).
+
+        IoTimeoutError :
+            When `timeout_sec` elapses without receiving any data.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def write_transport(self, data: bytes, timeout_sec: float):
+        """Write data to the transport.
+
+        This function should either write all bytes in `data` or raise an exception.
+
+        Parameters
+        ----------
+        data : bytes
+            The data to write over the channel.
+        timeout_sec : Union[float, None]
+            Number of seconds to wait for all bytes to be written before timing out. If timeout_sec
+            is 0, write should attempt to service the request in a non-blocking fashion. If
+            timeout_sec is None, write should block until it has written all data.
+
+        Raises
+        ------
+        TransportClosedError :
+            When the transport layer determines that the transport can no longer send or receive
+            data due to an underlying I/O problem (i.e. file descriptor closed, cable removed, etc).
+
+        IoTimeoutError :
+            When `timeout_sec` elapses before all of `data` could be written.
+        """
+        raise NotImplementedError()
+
+
+class ProjectAPIServer:
+    """Base class for Project API Servers.
+
+    This API server implements communication using JSON-RPC 2.0: https://www.jsonrpc.org/specification
+
+    Suggested use of this class is to import this module or copy this file into Project Generator
+    implementations, then instantiate it and call serve_forever().
+
+    This RPC server is single-threaded, blocking, and one-request-at-a-time. Don't get anxious.
+    """
+
+    _PROTOCOL_VERSION = 1
+
+    def __init__(
+        self, read_file: typing.BinaryIO, write_file: typing.BinaryIO, handler: ProjectAPIHandler
+    ):
+        """Initialize a new ProjectAPIServer.
+
+        Parameters
+        ----------
+        read_file : BinaryIO
+            A file-like object used to read binary data from the client.
+        write_file : BinaryIO
+            A file-like object used to write binary data to the client.
+        handler : ProjectAPIHandler
+            A class which extends the abstract class ProjectAPIHandler and implements the server RPC
+            functions.
+        """
+        self._read_file = io.TextIOWrapper(read_file, encoding="UTF-8", errors="strict")
+        self._write_file = io.TextIOWrapper(
+            write_file, encoding="UTF-8", errors="strict", write_through=True
+        )
+        self._handler = handler
+
+    def serve_forever(self):
+        """Serve requests until no more are available."""
+        has_more = True
+        while has_more:
+            has_more = self.serve_one_request()
+
+    def serve_one_request(self):
+        """Read, process, and reply to a single request from read_file.
+
+        When errors occur reading the request line or loading the request into JSON, they are
+        propagated to the caller (the stream is then likely corrupted and no further requests
+        should be served). When errors occur past this point, they are caught and sent back to the
+        client.
+
+        Returns
+        -------
+        bool :
+            True when more data could be read from read_file, False otherwise.
+        """
+        try:
+            line = self._read_file.readline()
+            _LOG.debug("read request <- %s", line)
+            if not line:
+                return False
+
+            request = json.loads(line)
+
+        except EOFError:
+            _LOG.error("EOF")
+            return False
+
+        except Exception as exc:
+            _LOG.error("Caught error reading request", exc_info=1)
+            return False
+
+        did_validate = False
+        try:
+            self._validate_request(request)
+            did_validate = True
+            self._dispatch_request(request)
+        except JSONRPCError as exc:
+            if isinstance(exc, ServerError):
+                exc.set_traceback(traceback.TracebackException.from_exception(exc).format())
+            request_id = None if not did_validate else request.get("id")
+            self._reply_error(request_id, exc)
+            return did_validate
+        except Exception as exc:
+            message = "validating request"
+            if did_validate:
+                message = f"calling method {request['method']}"
+
+            exc = ServerError.from_exception(exc, message=message)
+            request_id = None if not isinstance(request, dict) else request.get("id")
+            self._reply_error(request_id, exc)
+            return did_validate
+
+        return True
+
+    VALID_METHOD_RE = re.compile("^[a-zA-Z0-9_]+$")
+
+    def _validate_request(self, request):
+        if type(request) is not dict:
+            raise JSONRPCError(
+                ErrorCode.INVALID_REQUEST, f"request: want dict; got {request!r}", None
+            )
+
+        jsonrpc = request.get("jsonrpc")
+        if jsonrpc != "2.0":
+            raise JSONRPCError(
+                ErrorCode.INVALID_REQUEST, f'request["jsonrpc"]: want "2.0"; got {jsonrpc!r}', None
+            )
+
+        method = request.get("method")
+        if type(method) != str:
+            raise JSONRPCError(
+                ErrorCode.INVALID_REQUEST, f'request["method"]: want str; got {method!r}', None
+            )
+
+        if not self.VALID_METHOD_RE.match(method):
+            raise JSONRPCError(
+                ErrorCode.INVALID_REQUEST,
+                f'request["method"]: should match regex {self.VALID_METHOD_RE.pattern}; got {method!r}',
+                None,
+            )
+
+        params = request.get("params")
+        if type(params) != dict:
+            raise JSONRPCError(
+                ErrorCode.INVALID_REQUEST, f'request["params"]: want dict; got {type(params)}', None
+            )
+
+        request_id = request.get("id")
+        if type(request_id) not in (str, int, type(None)):
+            raise JSONRPCError(
+                ErrorCode.INVALID_REQUEST,
+                f'request["id"]: want str, number, null; got {request_id!r}',
+                None,
+            )
+
+    def _dispatch_request(self, request):
+        method = request["method"]
+
+        interface_method = getattr(ProjectAPIHandler, method, None)
+        if interface_method is None:
+            raise JSONRPCError(
+                ErrorCode.METHOD_NOT_FOUND, f'{request["method"]}: no such method', None
+            )
+
+        has_preprocessing = True
+        dispatch_method = getattr(self, f"_dispatch_{method}", None)
+        if
dispatch_method is None:
+            dispatch_method = getattr(self._handler, method)
+            has_preprocessing = False
+
+        request_params = request["params"]
+        params = {}
+
+        for var_name, var_type in typing.get_type_hints(interface_method).items():
+            if var_name == "self" or var_name == "return":
+                continue
+
+            # NOTE: types can only be JSON-compatible types, so var_type is expected to be of type 'type'.
+            if var_name not in request_params:
+                raise JSONRPCError(
+                    ErrorCode.INVALID_PARAMS,
+                    f'method {request["method"]}: parameter {var_name} not given',
+                    None,
+                )
+
+            param = request_params[var_name]
+            if not has_preprocessing and not isinstance(param, var_type):
+                raise JSONRPCError(
+                    ErrorCode.INVALID_PARAMS,
+                    f'method {request["method"]}: parameter {var_name}: want {var_type!r}, got {type(param)!r}',
+                    None,
+                )
+
+            params[var_name] = param
+
+        extra_params = [p for p in request["params"] if p not in params]
+        if extra_params:
+            raise JSONRPCError(
+                ErrorCode.INVALID_PARAMS,
+                f'{request["method"]}: extra parameters: {", ".join(extra_params)}',
+                None,
+            )
+
+        return_value = dispatch_method(**params)
+        self._write_reply(request["id"], result=return_value)
+
+    def _write_reply(self, request_id, result=None, error=None):
+        reply_dict = {
+            "jsonrpc": "2.0",
+            "id": request_id,
+        }
+
+        if error is not None:
+            assert (
+                result is None
+            ), f"Want either result= or error=, got result={result!r} and error={error!r}"
+            reply_dict["error"] = error
+        else:
+            reply_dict["result"] = result
+
+        reply_str = json.dumps(reply_dict)
+        _LOG.debug("write reply -> %r", reply_dict)
+        self._write_file.write(reply_str)
+        self._write_file.write("\n")
+
+    def _reply_error(self, request_id, exception):
+        self._write_reply(request_id, error=exception.to_json())
+
+    def _dispatch_generate_project(
+        self, model_library_format_path, standalone_crt_dir, project_dir, options
+    ):
+        return self._handler.generate_project(
+            pathlib.Path(model_library_format_path),
+            pathlib.Path(standalone_crt_dir),
+            pathlib.Path(project_dir),
+            options,
+        )
+
+    def _dispatch_server_info_query(self, tvm_version):
+        query_reply = self._handler.server_info_query(tvm_version)
+        to_return = query_reply._asdict()
+        if to_return["model_library_format_path"] is not None:
+            to_return["model_library_format_path"] = str(to_return["model_library_format_path"])
+        to_return.setdefault("protocol_version", self._PROTOCOL_VERSION)
+        to_return["project_options"] = [o._asdict() for o in query_reply.project_options]
+        return to_return
+
+    def _dispatch_open_transport(self, options):
+        reply = self._handler.open_transport(options)
+        return {"timeouts": reply._asdict()}
+
+    def _dispatch_read_transport(self, n, timeout_sec):
+        reply_data = self._handler.read_transport(n, timeout_sec)
+        return {"data": str(base64.b85encode(reply_data), "utf-8")}
+
+    def _dispatch_write_transport(self, data, timeout_sec):
+        self._handler.write_transport(base64.b85decode(data), timeout_sec)
+
+
+def _await_nonblocking_ready(rlist, wlist, timeout_sec=None, end_time=None):
+    if end_time is None:
+        return True
+
+    if timeout_sec is None:
+        timeout_sec = max(0, end_time - time.monotonic())
+    rlist, wlist, xlist = select.select(rlist, wlist, rlist + wlist, timeout_sec)
+    if not rlist and not wlist and not xlist:
+        raise IoTimeoutError()
+
+    return True
+
+
+def read_with_timeout(fd, n, timeout_sec):
+    """Read data from a file descriptor, with timeout.
+
+    This function is intended as a helper function for implementations of ProjectAPIHandler
+    read_transport. Tested on Linux and OS X.
Not tested on Windows.
+
+    Parameters
+    ----------
+    fd : int
+        File descriptor to read from. Must be opened in non-blocking mode (e.g. with O_NONBLOCK)
+        if timeout_sec is not None.
+
+    n : int
+        Maximum number of bytes to read.
+
+    timeout_sec : float or None
+        If not None, maximum number of seconds to wait before raising IoTimeoutError.
+
+    Returns
+    -------
+    bytes :
+        If at least one byte was received before timeout_sec, returns a bytes object with length
+        in [1, n]. If timeout_sec is None, returns the equivalent of os.read(fd, n).
+
+    Raises
+    ------
+    IoTimeoutError :
+        When timeout_sec is not None and that number of seconds elapses before any data is read.
+    """
+    end_time = None if timeout_sec is None else time.monotonic() + timeout_sec
+
+    while True:
+        _await_nonblocking_ready([fd], [], end_time=end_time)
+        try:
+            to_return = os.read(fd, n)
+            break
+        except BlockingIOError:
+            pass
+
+    # When EOF is reached, close the file.
+    if not to_return:
+        os.close(fd)
+        raise TransportClosedError()
+
+    return to_return
+
+
+def write_with_timeout(fd, data, timeout_sec):
+    """Write data to a file descriptor, with timeout.
+
+    This function is intended as a helper function for implementations of ProjectAPIHandler
+    write_transport. Tested on Linux and OS X. Not tested on Windows.
+
+    Parameters
+    ----------
+    fd : int
+        File descriptor to write to. Must be opened in non-blocking mode (e.g. with O_NONBLOCK)
+        if timeout_sec is not None.
+
+    data : bytes
+        Data to write.
+
+    timeout_sec : float or None
+        If not None, maximum number of seconds to wait before raising IoTimeoutError.
+
+    Returns
+    -------
+    int :
+        The number of bytes written to the file descriptor, if any bytes were written. A value
+        in [1, len(data)]. If timeout_sec is None, returns the equivalent of os.write(fd, data).
+
+    Raises
+    ------
+    IoTimeoutError :
+        When timeout_sec is not None and that number of seconds elapses before any data could be
+        written.
+    """
+    end_time = None if timeout_sec is None else time.monotonic() + timeout_sec
+
+    num_written = 0
+    while data:
+        try:
+            _await_nonblocking_ready([], [fd], end_time=end_time)
+        except IoTimeoutError as exc:
+            if num_written:
+                return num_written
+
+            raise exc
+
+        num_written_this_cycle = os.write(fd, data)
+
+        if not num_written_this_cycle:
+            os.close(fd)
+            raise TransportClosedError()
+
+        data = data[num_written_this_cycle:]
+        num_written += num_written_this_cycle
+
+    return num_written
+
+
+def main(handler: ProjectAPIHandler, argv: typing.List[str] = None):
+    """Start a Project API server.
+
+    Parameters
+    ----------
+    argv : list[str]
+        Command-line parameters to this program. If not given, sys.argv is used.
+    handler : ProjectAPIHandler
+        Handler class that implements the API server RPC calls.
+    """
+    if argv is None:
+        argv = sys.argv[1:]
+
+    parser = argparse.ArgumentParser(description="Generic TVM Project API server entry point")
+    parser.add_argument(
+        "--read-fd",
+        type=int,
+        required=True,
+        help="Numeric file descriptor where RPC requests should be read.",
+    )
+    parser.add_argument(
+        "--write-fd",
+        type=int,
+        required=True,
+        help="Numeric file descriptor where RPC replies should be written.",
+    )
+    parser.add_argument(
+        "--debug", action="store_true", help="When given, configure logging at DEBUG level."
+ ) + args = parser.parse_args() + + logging.basicConfig(level="DEBUG" if args.debug else "INFO", stream=sys.stderr) + + read_file = os.fdopen(args.read_fd, "rb", buffering=0) + write_file = os.fdopen(args.write_fd, "wb", buffering=0) + + server = ProjectAPIServer(read_file, write_file, handler) + server.serve_forever() diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index 78bf03379939..d4ad5b84fb76 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -60,8 +60,6 @@ class Session: def __init__( self, - binary=None, - flasher=None, transport_context_manager=None, session_name="micro-rpc", timeout_override=None, @@ -70,12 +68,6 @@ def __init__( Parameters ---------- - binary : MicroBinary - If given, `flasher` must also be given. During session initialization, this binary will - be flashed to the device before the transport is created. - flasher : Flasher - If given, `binary` must also be given. Used to flash `binary` during session - initialization. transport_context_manager : ContextManager[transport.Transport] If given, `flasher` and `binary` should not be given. On entry, this context manager should establish a tarnsport between this TVM instance and the device. @@ -85,8 +77,6 @@ def __init__( If given, TransportTimeouts that govern the way Receive() behaves. If not given, this is determined by calling has_flow_control() on the transport. """ - self.binary = binary - self.flasher = flasher self.transport_context_manager = transport_context_manager self.session_name = session_name self.timeout_override = timeout_override @@ -106,12 +96,11 @@ def _wrap_transport_read(self, n, timeout_microsec): return bytes([]) def _wrap_transport_write(self, data, timeout_microsec): - try: - return self.transport.write( - data, float(timeout_microsec) / 1e6 if timeout_microsec is not None else None - ) - except IoTimeoutError: - return 0 + self.transport.write( + data, float(timeout_microsec) / 1e6 if timeout_microsec is not None else None + ) + + return len(data) # TODO(areusch): delete def __enter__(self): """Initialize this session and establish an RPC session with the on-device RPC server. @@ -121,9 +110,6 @@ def __enter__(self): Session : Returns self. """ - if self.flasher is not None: - self.transport_context_manager = self.flasher.flash(self.binary) - self.transport = TransportLogger( self.session_name, self.transport_context_manager, level=logging.DEBUG ).__enter__() diff --git a/python/tvm/micro/transport/base.py b/python/tvm/micro/transport.py similarity index 84% rename from python/tvm/micro/transport/base.py rename to python/tvm/micro/transport.py index fdc7e9b2afce..8e95ff7ea77a 100644 --- a/python/tvm/micro/transport/base.py +++ b/python/tvm/micro/transport.py @@ -18,50 +18,18 @@ """Defines abstractions and implementations of the RPC transport used with micro TVM.""" import abc -import collections import logging import string import typing -_LOG = logging.getLogger(__name__) - - -class TransportClosedError(Exception): - """Raised when a transport can no longer be used due to underlying I/O problems.""" +from .project_api.server import IoTimeoutError, TransportTimeouts +from .project_api.server import TransportClosedError -class IoTimeoutError(Exception): - """Raised when the I/O operation could not be completed before the timeout. 
+_ = TransportClosedError # work around pylint unused-import error - Specifically: - - when no data could be read before the timeout - - when some of the write data could be written before the timeout - Note the asymmetric behavior of read() vs write(), since in one case the total length of the - data to transfer is known. - """ - - -# Timeouts supported by the underlying C++ MicroSession. -# -# session_start_retry_timeout_sec : float -# Number of seconds to wait for the device to send a kSessionStartReply after sending the -# initial session start message. After this time elapses another -# kSessionTerminated-kSessionStartInit train is sent. 0 disables this. -# session_start_timeout_sec : float -# Total number of seconds to wait for the session to be established. After this time, the -# client gives up trying to establish a session and raises an exception. -# session_established_timeout_sec : float -# Number of seconds to wait for a reply message after a session has been established. 0 -# disables this. -TransportTimeouts = collections.namedtuple( - "TransportTimeouts", - [ - "session_start_retry_timeout_sec", - "session_start_timeout_sec", - "session_established_timeout_sec", - ], -) +_LOG = logging.getLogger(__name__) def debug_transport_timeouts(session_start_retry_timeout_sec=0): @@ -263,7 +231,7 @@ def read(self, n, timeout_sec): def write(self, data, timeout_sec): timeout_str = f"{timeout_sec:5.2f}s" if timeout_sec is not None else " None " try: - bytes_written = self.child.write(data, timeout_sec) + self.child.write(data, timeout_sec) except IoTimeoutError: self.logger.log( self.level, @@ -286,14 +254,14 @@ def write(self, data, timeout_sec): ) raise err - hex_lines = self._to_hex(data[:bytes_written]) + hex_lines = self._to_hex(data) if len(hex_lines) > 1: self.logger.log( self.level, "%s: write {%s} <- [%3d B]:\n%s", self.name, timeout_str, - bytes_written, + len(data), "\n".join(hex_lines), ) else: @@ -302,11 +270,9 @@ def write(self, data, timeout_sec): "%s: write {%s} <- [%3d B]: %s", self.name, timeout_str, - bytes_written, + len(data), hex_lines[0], ) - return bytes_written - TransportContextManager = typing.ContextManager[Transport] diff --git a/python/tvm/micro/transport/__init__.py b/python/tvm/micro/transport/__init__.py deleted file mode 100644 index dffe9ae32792..000000000000 --- a/python/tvm/micro/transport/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -"""Defines abstractions and implementations related to the microTVM RPC transport layer.""" - -from .base import IoTimeoutError -from .base import Transport -from .base import TransportClosedError -from .base import TransportLogger -from .base import TransportTimeouts -from .base import debug_transport_timeouts -from .debug import DebugWrapperTransport -from .subprocess import SubprocessTransport diff --git a/python/tvm/micro/transport/debug.py b/python/tvm/micro/transport/debug.py deleted file mode 100644 index 71e12c7ed391..000000000000 --- a/python/tvm/micro/transport/debug.py +++ /dev/null @@ -1,64 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines a wrapper Transport class that launches a debugger before opening.""" - -from .base import Transport, TransportTimeouts - - -class DebugWrapperTransport(Transport): - """A Transport wrapper class that launches a debugger before opening the transport. - - This is primiarly useful when debugging the other end of a SubprocessTransport. It allows you - to pipe data through the GDB process to drive the subprocess with a debugger attached. - """ - - def __init__(self, debugger, transport, disable_session_start_retry=False): - self.debugger = debugger - self.transport = transport - self.disable_session_start_retry = disable_session_start_retry - - def timeouts(self): - child_timeouts = self.transport.timeouts() - return TransportTimeouts( - session_start_retry_timeout_sec=( - 0 - if self.disable_session_start_retry - else child_timeouts.session_start_retry_timeout_sec - ), - session_start_timeout_sec=0, - session_established_timeout_sec=0, - ) - - def open(self): - self.debugger.start() - - try: - self.transport.open() - except Exception: - self.debugger.stop() - raise - - def write(self, data, timeout_sec): - return self.transport.write(data, timeout_sec) - - def read(self, n, timeout_sec): - return self.transport.read(n, timeout_sec) - - def close(self): - self.transport.close() - self.debugger.stop() diff --git a/python/tvm/micro/transport/file_descriptor.py b/python/tvm/micro/transport/file_descriptor.py deleted file mode 100644 index 58c4026f6704..000000000000 --- a/python/tvm/micro/transport/file_descriptor.py +++ /dev/null @@ -1,119 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an implementation of Transport that uses file descriptors.""" - -import fcntl -import os -import select -import time -from . import base - - -class FdConfigurationError(Exception): - """Raised when specified file descriptors can't be placed in non-blocking mode.""" - - -class FdTransport(base.Transport): - """A Transport implementation that implements timeouts using non-blocking I/O.""" - - @classmethod - def _validate_configure_fd(cls, file_descriptor): - file_descriptor = ( - file_descriptor if isinstance(file_descriptor, int) else file_descriptor.fileno() - ) - flag = fcntl.fcntl(file_descriptor, fcntl.F_GETFL) - if flag & os.O_NONBLOCK != 0: - return file_descriptor - - fcntl.fcntl(file_descriptor, fcntl.F_SETFL, os.O_NONBLOCK | flag) - new_flag = fcntl.fcntl(file_descriptor, fcntl.F_GETFL) - if (new_flag & os.O_NONBLOCK) == 0: - raise FdConfigurationError( - f"Cannot set file descriptor {file_descriptor} to non-blocking" - ) - return file_descriptor - - def __init__(self, read_fd, write_fd, timeouts): - self.read_fd = self._validate_configure_fd(read_fd) - self.write_fd = self._validate_configure_fd(write_fd) - self._timeouts = timeouts - - def timeouts(self): - return self._timeouts - - def open(self): - pass - - def close(self): - if self.read_fd is not None: - os.close(self.read_fd) - self.read_fd = None - - if self.write_fd is not None: - os.close(self.write_fd) - self.write_fd = None - - def _await_ready(self, rlist, wlist, timeout_sec=None, end_time=None): - if end_time is None: - return True - - if timeout_sec is None: - timeout_sec = max(0, end_time - time.monotonic()) - rlist, wlist, xlist = select.select(rlist, wlist, rlist + wlist, timeout_sec) - if not rlist and not wlist and not xlist: - raise base.IoTimeoutError() - - return True - - def read(self, n, timeout_sec): - if self.read_fd is None: - raise base.TransportClosedError() - - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - - while True: - self._await_ready([self.read_fd], [], end_time=end_time) - try: - to_return = os.read(self.read_fd, n) - break - except BlockingIOError: - pass - - if not to_return: - self.close() - raise base.TransportClosedError() - - return to_return - - def write(self, data, timeout_sec): - if self.write_fd is None: - raise base.TransportClosedError() - - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - - data_len = len(data) - while data: - self._await_ready(end_time, [], [self.write_fd]) - num_written = os.write(self.write_fd, data) - if not num_written: - self.close() - raise base.TransportClosedError() - - data = data[num_written:] - - return data_len diff --git a/python/tvm/micro/transport/serial.py b/python/tvm/micro/transport/serial.py deleted file mode 100644 index dc107d68abc2..000000000000 --- a/python/tvm/micro/transport/serial.py +++ /dev/null @@ -1,135 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines a Transport implementation using pyserial.""" - -import atexit -import time -import serial -import serial.tools.list_ports -from .base import IoTimeoutError, Transport, TransportTimeouts - - -_DEFAULT_SERIAL_TIMEOUTS = TransportTimeouts( - session_start_retry_timeout_sec=5, - session_start_timeout_sec=10.0, - session_established_timeout_sec=30.0, -) - - -class SerialTransport(Transport): - """A Transport implementation using pySerial.""" - - _OPEN_PORTS = [] - - @classmethod - def close_atexit(cls): - """Close all serial ports before exit. - - Some USB-UART kernel drivers are particularly sensitive to being left open (i.e. require - unplugging and replugging of attached hardware or reboot of machine); try very hard to - close all serial ports at exit. - """ - for port in cls._OPEN_PORTS: - try: - port.close() - except Exception: # pylint: disable=broad-except - _LOG.warn("exception closing port", exc_info=True) - - cls._OPEN_PORTS = [] - - def __init__(self, grep=None, port_path=None, timeouts=None, **kw): - self._port_path = port_path - self._grep = grep - self._timeouts = timeouts if timeouts is not None else _DEFAULT_SERIAL_TIMEOUTS - self._kw = kw - if self._port_path is None and self._grep is None: - raise SerialPortNotFoundError("Must specify one of grep= or port_path=") - - def timeouts(self): - return self._timeouts - - def open(self): - if self._port_path is not None: - port_path = self._port_path - else: - ports = list(serial.tools.list_ports.grep(self._grep)) - if len(ports) != 1: - raise SerialPortNotFoundError( - f"grep expression should find 1 serial port; found {ports!r}" - ) - - port_path = ports[0].device - - self._port = serial.Serial(port_path, timeout=0.1, exclusive=True, **self._kw) - self._port.cancel_read() - self._port.reset_input_buffer() - self._port.reset_output_buffer() - self._OPEN_PORTS.append(self._port) - - def close(self): - if self._port is None: - return - - self._port.close() - self._OPEN_PORTS.remove(self._port) - self._port = None - - def read(self, n, timeout_sec): - if timeout_sec is None: - self._port.timeout = None - in_waiting = self._port.in_waiting - if in_waiting > 0: - return self._port.read(min(n, in_waiting)) - return self._port.read(1) - - end_time = time.monotonic() + timeout_sec - to_return = bytearray() - while True: - timeout_remaining = end_time - time.monotonic() - if timeout_sec != 0 and timeout_remaining < 0: - break - - # Read until *something* can be returned. If nothing is sent within 5 chars' time, stop. - # 5 is an arbitrary number. 
- self._port.timeout = 1 / self._port.baudrate * 5 - try: - data = self._port.read(n if timeout_sec != 0 else 1) - if not data and to_return: - break - - to_return.extend(data) - except serial.SerialTimeoutException: - if to_return: - break - - if not to_return: - raise IoTimeoutError() - - return to_return - - def write(self, data, timeout_sec): - self._port.write_timeout = timeout_sec - try: - to_return = self._port.write(data) - self._port.flush() - return to_return - except serial.SerialTimeoutException: - raise IoTimeoutError() - - -atexit.register(SerialTransport.close_atexit) diff --git a/python/tvm/micro/transport/subprocess.py b/python/tvm/micro/transport/subprocess.py deleted file mode 100644 index 4de1fa1266d3..000000000000 --- a/python/tvm/micro/transport/subprocess.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an implementation of Transport that uses subprocesses.""" - -import subprocess -from . import base -from . import file_descriptor - - -class SubprocessFdTransport(file_descriptor.FdTransport): - def timeouts(self): - raise NotImplementedError() - - -class SubprocessTransport(base.Transport): - """A Transport implementation that uses a subprocess's stdin/stdout as the channel.""" - - def __init__(self, args, max_startup_latency_sec=5.0, max_latency_sec=5.0, **kwargs): - self.max_startup_latency_sec = max_startup_latency_sec - self.max_latency_sec = max_latency_sec - self.args = args - self.kwargs = kwargs - self.popen = None - self.child_transport = None - - def timeouts(self): - return base.TransportTimeouts( - session_start_retry_timeout_sec=0, - session_start_timeout_sec=self.max_startup_latency_sec, - session_established_timeout_sec=self.max_latency_sec, - ) - - def open(self): - self.kwargs["stdout"] = subprocess.PIPE - self.kwargs["stdin"] = subprocess.PIPE - self.kwargs["bufsize"] = 0 - self.popen = subprocess.Popen(self.args, **self.kwargs) - self.child_transport = SubprocessFdTransport( - self.popen.stdout, self.popen.stdin, self.timeouts() - ) - - def write(self, data, timeout_sec): - return self.child_transport.write(data, timeout_sec) - - def read(self, n, timeout_sec): - return self.child_transport.read(n, timeout_sec) - - def close(self): - if self.child_transport is not None: - self.child_transport.close() - - self.popen.terminate() diff --git a/python/tvm/micro/transport/wakeup.py b/python/tvm/micro/transport/wakeup.py deleted file mode 100644 index 418f8bdbb27a..000000000000 --- a/python/tvm/micro/transport/wakeup.py +++ /dev/null @@ -1,79 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines an implementation of Transport that uses subprocesses.""" - -import logging -import time -from . import base - - -_LOG = logging.getLogger(__name__) - - -class WakeupTransport(base.Transport): - """A Transport implementation that waits for a "wakeup sequence" from the remote end.""" - - def __init__(self, child_transport, wakeup_sequence): - self.child_transport = child_transport - self.wakeup_sequence = bytes(wakeup_sequence) - self.wakeup_sequence_buffer = bytearray() - self.line_start_index = 0 - self.found_wakeup_sequence = False - - def open(self): - return self.child_transport.open() - - def close(self): - return self.child_transport.close() - - def timeouts(self): - return self.child_transport.timeouts() - - def _await_wakeup(self, end_time): - def _time_remaining(): - if end_time is None: - return None - return max(0, end_time - time.monotonic()) - - if not self.found_wakeup_sequence: - while self.wakeup_sequence not in self.wakeup_sequence_buffer: - x = self.child_transport.read(1, _time_remaining()) - self.wakeup_sequence_buffer.extend(x) - if x[0] in (b"\n", b"\xff"): - _LOG.debug("%s", self.wakeup_sequence_buffer[self.line_start_index : -1]) - self.line_start_index = len(self.wakeup_sequence_buffer) - - _LOG.info("remote side woke up!") - self.found_wakeup_sequence = True - time.sleep(0.2) - - return _time_remaining() - - def read(self, n, timeout_sec): - if not self.found_wakeup_sequence: - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - timeout_sec = self._await_wakeup(end_time) - - return self.child_transport.read(n, timeout_sec) - - def write(self, data, timeout_sec): - if not self.found_wakeup_sequence: - end_time = None if timeout_sec is None else time.monotonic() + timeout_sec - timeout_sec = self._await_wakeup(end_time) - - return self.child_transport.write(data, timeout_sec) diff --git a/python/tvm/relay/testing/byoc.py b/python/tvm/relay/testing/byoc.py new file mode 100644 index 000000000000..619c9b99ca1d --- /dev/null +++ b/python/tvm/relay/testing/byoc.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
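+# NOTE: consolidated here from inline copies in the test suites (e.g.
+# tests/micro/zephyr/test_zephyr.py, below) so multiple tests can share one annotator.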
+ +"""Defines test utilties useful for testing BYOC flows.""" + +from tvm import relay +from tvm.relay.expr_functor import ExprMutator +from tvm.relay.op.annotation import compiler_begin, compiler_end + + +class CcompilerAnnotator(ExprMutator): + """ + This is used to create external functions for ccompiler. + A simple annotator that creates the following program: + | + -- begin -- + | + add + | + subtract + | + multiply + | + -- end -- + | + """ + + def __init__(self): + super(CcompilerAnnotator, self).__init__() + self.in_compiler = 0 + + def visit_call(self, call): + if call.op.name == "add": # Annotate begin at args + if self.in_compiler == 1: + lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") + rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") + op = relay.add(lhs, rhs) + self.in_compiler = 2 + return op + elif call.op.name == "subtract": + if self.in_compiler == 1: + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = compiler_begin(lhs, "ccompiler") + if isinstance(rhs, relay.expr.Var): + rhs = compiler_begin(rhs, "ccompiler") + return relay.subtract(lhs, rhs) + elif call.op.name == "multiply": # Annotate end at output + self.in_compiler = 1 + lhs = super().visit(call.args[0]) + rhs = super().visit(call.args[1]) + if isinstance(lhs, relay.expr.Var): + lhs = compiler_begin(lhs, "ccompiler") + if isinstance(rhs, relay.expr.Var): + rhs = compiler_begin(rhs, "ccompiler") + op = relay.multiply(lhs, rhs) + if self.in_compiler == 2: + op = compiler_end(op, "ccompiler") + self.in_compiler = 0 + return op + return super().visit_call(call) diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h index 907559421e5d..7949aea6f171 100644 --- a/src/runtime/crt/crt_config-template.h +++ b/src/runtime/crt/crt_config-template.h @@ -24,6 +24,12 @@ #ifndef TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_ #define TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_ +/*! Log level of the CRT runtime */ +#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG + +/*! Support low-level debugging in MISRA-C runtime */ +#define TVM_CRT_DEBUG 0 + /*! Maximum supported dimension in NDArray */ #define TVM_CRT_MAX_NDIM 6 @@ -31,7 +37,7 @@ #define TVM_CRT_MAX_ARGS 10 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 250 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 @@ -48,9 +54,6 @@ /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 -/*! \brief DLDataType for the return value from strlen */ -#define TVM_CRT_STRLEN_DLTYPE 10 - /*! \brief Enable checks to enforce the stack allocator with a FIFO ordering. 
Off by default */ // #define TVM_CRT_STACK_ALLOCATOR_ENABLE_FIFO_CHECK diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index 7b7690b66528..950f3e4ef215 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -265,8 +265,8 @@ int TVMGraphExecutorGraphAttr_Load(TVMGraphExecutorGraphAttr* attr, JSONReader* break; } DLDevice dev = {kDLCPU, 0}; - tvm_crt_error_t err = - TVMPlatformMemoryAllocate(TVM_CRT_STRLEN_DLTYPE * num_items, dev, (void**)&attr->dltype); + tvm_crt_error_t err = TVMPlatformMemoryAllocate(TVM_CRT_MAX_STRLEN_DLTYPE * num_items, dev, + (void**)&attr->dltype); if (err != kTvmErrorNoError) { fprintf(stderr, "memory allocate error: %08x", err); return -1; @@ -278,8 +278,8 @@ int TVMGraphExecutorGraphAttr_Load(TVMGraphExecutorGraphAttr* attr, JSONReader* status = -1; return status; } - status = reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_STRLEN_DLTYPE, - TVM_CRT_STRLEN_DLTYPE); + status = reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_MAX_STRLEN_DLTYPE, + TVM_CRT_MAX_STRLEN_DLTYPE); if (status != 0) { fprintf(stderr, "error reading dltype array item"); break; @@ -792,14 +792,14 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl // read names char* names = NULL; DLDevice dev = {kDLCPU, 0}; - tvm_crt_error_t err = - TVMPlatformMemoryAllocate(TVM_CRT_STRLEN_NAME * executor->nodes_count, dev, (void**)&names); + tvm_crt_error_t err = TVMPlatformMemoryAllocate( + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count, dev, (void**)&names); if (err != kTvmErrorNoError) { fprintf(stderr, "memory allocate error: %08x", err); status = -1; return status; } - memset(names, 0, TVM_CRT_STRLEN_NAME * executor->nodes_count); + memset(names, 0, TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count); uint64_t names_count; int idx; memcpy(&names_count, bptr, sizeof(names_count)); @@ -808,11 +808,11 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl uint64_t name_length; memcpy(&name_length, bptr, sizeof(name_length)); bptr += sizeof(name_length); - if (name_length >= TVM_CRT_STRLEN_NAME) { + if (name_length >= TVM_CRT_MAX_STRLEN_FUNCTION_NAME) { fprintf(stderr, "Error: function name longer than expected.\n"); status = -1; } - memcpy(names + TVM_CRT_STRLEN_NAME * idx, bptr, name_length); + memcpy(names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, bptr, name_length); bptr += name_length; } @@ -827,9 +827,10 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl } for (idx = 0; idx < size; idx++) { - int32_t in_idx = TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_STRLEN_NAME * idx); + int32_t in_idx = + TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); CHECK_GT(in_idx, 0, "Found param for non-existent input: %s\n", - names + TVM_CRT_STRLEN_NAME * idx); + names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); uint32_t eid = TVMGraphExecutor_GetEntryId(executor, executor->input_nodes[in_idx], 0); if (!(eid < executor->data_entry_count)) { fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", eid, @@ -855,7 +856,7 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl #if TVM_CRT_DEBUG TVMNDArray* entry = &(executor->data_entry[eid]); printf("loading: param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n", - names + TVM_CRT_STRLEN_NAME * idx, in_idx, eid, 
entry->dl_tensor.ndim,
+           names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, in_idx, eid, entry->dl_tensor.ndim,
            ((float*)entry->dl_tensor.data)[0]);  // NOLINT(*)
 #endif  // TVM_CRT_DEBUG
   }
@@ -937,7 +938,7 @@ int TVMGraphExecutor_SetupStorage(TVMGraphExecutor* executor) {
       return -1;
     }
     for (idx = 0; idx < attrs->dltype_count; idx++) {
-      vtype[idx] = String2DLDataType(attrs->dltype + idx * TVM_CRT_STRLEN_DLTYPE);
+      vtype[idx] = String2DLDataType(attrs->dltype + idx * TVM_CRT_MAX_STRLEN_DLTYPE);
     }
 
     // Size and device type of each storage pool entry.
diff --git a/src/runtime/crt/host/Makefile b/src/runtime/crt/host/Makefile
new file mode 100644
index 000000000000..efed3c438699
--- /dev/null
+++ b/src/runtime/crt/host/Makefile
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+INCLUDES ?= -isystem crt/include -Icrt_config
+CFLAGS ?= -Werror -Wall
+CXXFLAGS ?= -Werror -Wall -std=c++11
+LDFLAGS ?= -Werror -Wall
+
+# Codegen produces spurious lines like: int32_t arg2_code = ((int32_t*)arg_type_ids)[(2)];
+MODEL_CFLAGS ?= -Wno-error=unused-variable
+
+AR ?= ${PREFIX}ar
+CC ?= ${PREFIX}gcc
+CXX ?= ${PREFIX}g++
+RANLIB ?= ${PREFIX}ranlib
+
+QUIET ?= @
+
+PWD = $(shell pwd)
+BUILD_DIR = build
+CRT_LIB_NAMES = microtvm_rpc_server microtvm_rpc_common graph_executor graph_executor_module common memory
+CRT_LIBS = $(patsubst %, $(BUILD_DIR)/crt/lib%.a, $(CRT_LIB_NAMES))
+
+CRT_INCLUDES = $(wildcard crt/include/**)
+
+$(BUILD_DIR)/crt/lib%.a: $(wildcard crt/src/runtime/%/*.c)
+	${QUIET}cd crt && $(MAKE) \
+		BUILD_DIR=../$(BUILD_DIR)/crt \
+		CRT_CONFIG=$(PWD)/crt_config/crt_config.h \
+		EXTRA_CFLAGS="$(CFLAGS)" \
+		EXTRA_CXXFLAGS="$(CXXFLAGS)" \
+		EXTRA_LDFLAGS="$(EXTRA_LDFLAGS)" \
+		$(patsubst $(BUILD_DIR)/crt/lib%.a,%,$@)
+
+crt: $(CRT_LIBS)
+.PHONY: crt
+
+# Compile codegen files
+$(BUILD_DIR)/model/codegen/host/%.o: model/codegen/host/%.c
+	${QUIET}mkdir -p $(dir $@)
+	${QUIET}$(CC) $(INCLUDES) $(CFLAGS) $(MODEL_CFLAGS) -c -o "$@" "$<"
+
+MODEL_LIBS = \
+	$(patsubst model/codegen/host/src/%.c, $(BUILD_DIR)/model/codegen/host/src/%.o, $(wildcard model/codegen/host/src/*.c)) \
+	$(wildcard model/codegen/host/lib/*.o)
+
+# Compile src/ files
+build/%.o: src/%.cc
+	${QUIET}mkdir -p $(dir $@)
+	${QUIET}$(CXX) $(INCLUDES) $(CXXFLAGS) -c -o "$@" "$<"
+
+SRCS = $(wildcard src/*.cc)
+OBJS = $(patsubst src/%.cc,build/%.o,$(SRCS))
+
+build/main: ${OBJS} ${MODEL_LIBS} ${CRT_LIBS}
+	${QUIET}mkdir -p $(dir $@)
+	${QUIET}$(CXX) $(LDFLAGS) -o "$@" $^
+
+all: build/main
+.PHONY: all
+
+.DEFAULT_GOAL = all
diff --git a/src/runtime/crt/host/microtvm_api_server.py b/src/runtime/crt/host/microtvm_api_server.py
new file mode 100644
index 000000000000..5f9019817e82
--- /dev/null
+++ b/src/runtime/crt/host/microtvm_api_server.py
@@ -0,0 +1,200 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import fcntl
+import os
+import os.path
+import pathlib
+import select
+import shutil
+import subprocess
+import tarfile
+import time
+from tvm.micro.project_api import server
+
+
+PROJECT_DIR = pathlib.Path(os.path.dirname(__file__) or os.getcwd())
+
+
+MODEL_LIBRARY_FORMAT_RELPATH = "model.tar"
+
+
+IS_TEMPLATE = not os.path.exists(os.path.join(PROJECT_DIR, MODEL_LIBRARY_FORMAT_RELPATH))
+
+
+class Handler(server.ProjectAPIHandler):
+    def __init__(self):
+        super(Handler, self).__init__()
+        self._proc = None
+
+    def server_info_query(self, tvm_version):
+        return server.ServerInfo(
+            platform_name="host",
+            is_template=IS_TEMPLATE,
+            model_library_format_path=""
+            if IS_TEMPLATE
+            else PROJECT_DIR / MODEL_LIBRARY_FORMAT_RELPATH,
+            project_options=[server.ProjectOption("verbose", help="Run make with verbose output")],
+        )
+
+    # These files and directories will be recursively copied into generated projects from the CRT.
+    CRT_COPY_ITEMS = ("include", "Makefile", "src")
+
+    # The build target given to make.
+    BUILD_TARGET = "build/main"
+
+    def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
+        # Make project directory.
+        project_dir.mkdir(parents=True)
+
+        # Copy ourselves to the generated project. TVM may perform further build steps on the
+        # generated project by launching the copy.
+        shutil.copy2(__file__, project_dir / os.path.basename(__file__))
+
+        # Place the Model Library Format tarball in the special location, which this script uses
+        # to decide whether it's being invoked in a template or generated project.
+        project_model_library_format_path = project_dir / MODEL_LIBRARY_FORMAT_RELPATH
+        shutil.copy2(model_library_format_path, project_model_library_format_path)
+
+        # Extract the Model Library Format tarball into model/.
+        extract_path = project_dir / project_model_library_format_path.stem
+        with tarfile.TarFile(project_model_library_format_path) as tf:
+            os.makedirs(extract_path)
+            tf.extractall(path=extract_path)
+
+        # Populate CRT.
+        crt_path = project_dir / "crt"
+        os.mkdir(crt_path)
+        for item in self.CRT_COPY_ITEMS:
+            src_path = standalone_crt_dir / item
+            dst_path = crt_path / item
+            if os.path.isdir(src_path):
+                shutil.copytree(src_path, dst_path)
+            else:
+                shutil.copy2(src_path, dst_path)
+
+        # Populate Makefile.
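+        # (The copied Makefile drives the whole build: it compiles the CRT runtime libraries
+        # and the generated model code under model/codegen/host/; build() below just runs make.)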
+        shutil.copy2(pathlib.Path(__file__).parent / "Makefile", project_dir / "Makefile")
+
+        # Populate crt_config.h.
+        crt_config_dir = project_dir / "crt_config"
+        crt_config_dir.mkdir()
+        shutil.copy2(
+            os.path.join(os.path.dirname(__file__), "..", "crt_config-template.h"),
+            os.path.join(crt_config_dir, "crt_config.h"),
+        )
+
+        # Populate src/
+        src_dir = os.path.join(project_dir, "src")
+        os.mkdir(src_dir)
+        shutil.copy2(
+            os.path.join(os.path.dirname(__file__), "main.cc"), os.path.join(src_dir, "main.cc")
+        )
+
+    def build(self, options):
+        args = ["make"]
+        if options.get("verbose"):
+            args.append("QUIET=")
+
+        args.append(self.BUILD_TARGET)
+
+        subprocess.check_call(args, cwd=PROJECT_DIR)
+
+    def flash(self, options):
+        pass  # Flashing does nothing on host.
+
+    def _set_nonblock(self, fd):
+        flag = fcntl.fcntl(fd, fcntl.F_GETFL)
+        fcntl.fcntl(fd, fcntl.F_SETFL, flag | os.O_NONBLOCK)
+        new_flag = fcntl.fcntl(fd, fcntl.F_GETFL)
+        assert (new_flag & os.O_NONBLOCK) != 0, f"Cannot set file descriptor {fd} to non-blocking"
+
+    def open_transport(self, options):
+        self._proc = subprocess.Popen(
+            [self.BUILD_TARGET], stdin=subprocess.PIPE, stdout=subprocess.PIPE, bufsize=0
+        )
+        self._set_nonblock(self._proc.stdin.fileno())
+        self._set_nonblock(self._proc.stdout.fileno())
+        return server.TransportTimeouts(
+            session_start_retry_timeout_sec=0,
+            session_start_timeout_sec=0,
+            session_established_timeout_sec=0,
+        )
+
+    def close_transport(self):
+        if self._proc is not None:
+            proc = self._proc
+            self._proc = None
+            proc.terminate()
+            proc.wait()
+
+    def _await_ready(self, rlist, wlist, timeout_sec=None, end_time=None):
+        if timeout_sec is None and end_time is not None:
+            timeout_sec = max(0, end_time - time.monotonic())
+
+        rlist, wlist, xlist = select.select(rlist, wlist, rlist + wlist, timeout_sec)
+        if not rlist and not wlist and not xlist:
+            raise server.IoTimeoutError()
+
+        return True
+
+    def read_transport(self, n, timeout_sec):
+        if self._proc is None:
+            raise server.TransportClosedError()
+
+        fd = self._proc.stdout.fileno()
+        end_time = None if timeout_sec is None else time.monotonic() + timeout_sec
+
+        try:
+            self._await_ready([fd], [], end_time=end_time)
+            to_return = os.read(fd, n)
+        except BrokenPipeError:
+            to_return = 0
+
+        if not to_return:
+            self.close_transport()
+            raise server.TransportClosedError()
+
+        return to_return
+
+    def write_transport(self, data, timeout_sec):
+        if self._proc is None:
+            raise server.TransportClosedError()
+
+        fd = self._proc.stdin.fileno()
+        end_time = None if timeout_sec is None else time.monotonic() + timeout_sec
+
+        data_len = len(data)
+        while data:
+            self._await_ready([], [fd], end_time=end_time)
+            try:
+                num_written = os.write(fd, data)
+            except BrokenPipeError:
+                num_written = 0
+
+            if not num_written:
+                self.close_transport()
+                raise server.TransportClosedError()
+
+            data = data[num_written:]
+
+
+if __name__ == "__main__":
+    server.main(Handler())
diff --git a/src/runtime/crt/microtvm_rpc_common/framing.cc b/src/runtime/crt/microtvm_rpc_common/framing.cc
index f89c6e5688c0..47e4a33a718c 100644
--- a/src/runtime/crt/microtvm_rpc_common/framing.cc
+++ b/src/runtime/crt/microtvm_rpc_common/framing.cc
@@ -66,6 +66,26 @@ void Unframer::Reset() {
   num_buffer_bytes_valid_ = 0;
 }
 
+size_t Unframer::BytesNeeded() {
+  size_t bytes_needed = 0;
+  switch (state_) {
+    case State::kFindPacketStart:
+      return 1;
+    case State::kFindPacketLength:
+      bytes_needed = PacketFieldSizeBytes::kPayloadLength;
+      break;
+    case State::kFindPacketCrc:
+      return
num_payload_bytes_remaining_; + case State::kFindCrcEnd: + bytes_needed = PacketFieldSizeBytes::kCrc; + break; + default: + CHECK(false); + } + + return bytes_needed > num_buffer_bytes_valid_ ? bytes_needed - num_buffer_bytes_valid_ : 0; +} + tvm_crt_error_t Unframer::Write(const uint8_t* data, size_t data_size_bytes, size_t* bytes_consumed) { tvm_crt_error_t return_code = kTvmErrorNoError; diff --git a/src/runtime/crt/host/crt_config.h b/src/runtime/micro/crt_config.h similarity index 90% rename from src/runtime/crt/host/crt_config.h rename to src/runtime/micro/crt_config.h index b81a74eb4ae6..c3e8fea1ba08 100644 --- a/src/runtime/crt/host/crt_config.h +++ b/src/runtime/micro/crt_config.h @@ -21,8 +21,8 @@ * \file tvm/runtime/crt/host/crt_config.h * \brief CRT configuration for the host-linked CRT. */ -#ifndef TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ -#define TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ +#ifndef TVM_RUNTIME_MICRO_CRT_CONFIG_H_ +#define TVM_RUNTIME_MICRO_CRT_CONFIG_H_ /*! Log level of the CRT runtime */ #define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG @@ -35,9 +35,9 @@ /*! Maximum supported arguments in generated functions */ #define TVM_CRT_MAX_ARGS 10 /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ -#define TVM_CRT_STRLEN_DLTYPE 10 +#define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_STRLEN_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 @@ -53,4 +53,4 @@ // #define TVM_CRT_FRAMER_ENABLE_LOGS -#endif // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ +#endif // TVM_RUNTIME_MICRO_CRT_CONFIG_H_ diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index cd916d46971d..2dcd928b24f8 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -37,10 +37,10 @@ #include #include "../../support/str_escape.h" -#include "../crt/host/crt_config.h" #include "../rpc/rpc_channel.h" #include "../rpc/rpc_endpoint.h" #include "../rpc/rpc_session.h" +#include "crt_config.h" namespace tvm { namespace runtime { @@ -56,10 +56,12 @@ class CallbackWriteStream : public WriteStream { bytes.data = (const char*)data; bytes.size = data_size_bytes; if (write_timeout_ == ::std::chrono::microseconds::zero()) { - return static_cast(fsend_(bytes, nullptr)); + fsend_(bytes, nullptr); } else { - return static_cast(fsend_(bytes, write_timeout_.count())); + fsend_(bytes, write_timeout_.count()); } + + return static_cast(data_size_bytes); } void PacketDone(bool is_valid) override {} @@ -143,15 +145,16 @@ class MicroTransportChannel : public RPCChannel { } ::std::string chunk; + size_t bytes_needed = unframer_.BytesNeeded(); + CHECK_GT(bytes_needed, 0) << "unframer unexpectedly needs no data"; if (timeout != nullptr) { ::std::chrono::microseconds iter_timeout{ ::std::max(::std::chrono::microseconds{0}, ::std::chrono::duration_cast<::std::chrono::microseconds>( end_time - ::std::chrono::steady_clock::now()))}; - chunk = - frecv_(size_t(kReceiveBufferSizeBytes), iter_timeout.count()).operator std::string(); + chunk = frecv_(bytes_needed, iter_timeout.count()).operator std::string(); } else { - chunk = frecv_(size_t(kReceiveBufferSizeBytes), nullptr).operator std::string(); + chunk = frecv_(bytes_needed, nullptr).operator std::string(); } pending_chunk_ = chunk; if (pending_chunk_.size() == 0) { diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 
f31630c2a705..01447ac6183f 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -80,6 +80,8 @@ "idl", # opencl file "cl", + # zephyr config file + "conf", } # List of file names allowed @@ -132,33 +134,12 @@ "tests/micro/zephyr/testdata/mnist-8.onnx", "tests/micro/zephyr/testdata/ic_sample_fp32_8.npy", # microTVM Zephyr runtime - "apps/microtvm/zephyr/qemu-hack/qemu-system-i386", - "apps/microtvm/zephyr/qemu-hack/qemu-system-arm", - "apps/microtvm/zephyr/qemu-hack/qemu-system-riscv32", - "apps/microtvm/zephyr/qemu-hack/qemu-system-riscv64", - "apps/microtvm/zephyr/qemu-hack/qemu-system-xilinx-aarch64", - "apps/microtvm/zephyr/host_driven/prj.conf", - "apps/microtvm/zephyr/host_driven/boards/qemu_x86.conf", - "apps/microtvm/zephyr/host_driven/boards/qemu_riscv32.conf", - "apps/microtvm/zephyr/host_driven/boards/qemu_riscv64.conf", - "apps/microtvm/zephyr/host_driven/boards/nrf5340dk_nrf5340_cpuapp.conf", - "apps/microtvm/zephyr/host_driven/boards/nucleo_f746zg.conf", - "apps/microtvm/zephyr/host_driven/boards/stm32f746g_disco.conf", - "apps/microtvm/zephyr/host_driven/boards/mps2_an521.conf", - "apps/microtvm/zephyr/host_driven/boards/nucleo_l4r5zi.conf", - "apps/microtvm/zephyr/host_driven/boards/qemu_cortex_r5.conf", - "apps/microtvm/zephyr/host_driven/qemu-hack", - "apps/microtvm/zephyr/aot_demo/prj.conf", - "apps/microtvm/zephyr/aot_demo/boards/qemu_x86.conf", - "apps/microtvm/zephyr/aot_demo/boards/qemu_riscv32.conf", - "apps/microtvm/zephyr/aot_demo/boards/qemu_riscv64.conf", - "apps/microtvm/zephyr/aot_demo/boards/nrf5340dk_nrf5340_cpuapp.conf", - "apps/microtvm/zephyr/aot_demo/boards/nucleo_f746zg.conf", - "apps/microtvm/zephyr/aot_demo/boards/stm32f746g_disco.conf", - "apps/microtvm/zephyr/aot_demo/boards/mps2_an521.conf", - "apps/microtvm/zephyr/aot_demo/boards/nucleo_l4r5zi.conf", - "apps/microtvm/zephyr/aot_demo/boards/qemu_cortex_r5.conf", - "apps/microtvm/zephyr/aot_demo/qemu-hack", + "apps/microtvm/zephyr/template_project/CMakeLists.txt.template", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-xilinx-aarch64", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-i386", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32", + "apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv64", # microTVM Virtual Machines "apps/microtvm/reference-vm/zephyr/Vagrantfile", "apps/microtvm/reference-vm/zephyr/base-box/Vagrantfile.packer-template", diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py index 0b50ecd12ec5..2b30401a90e9 100644 --- a/tests/micro/zephyr/conftest.py +++ b/tests/micro/zephyr/conftest.py @@ -14,8 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import datetime +import os +import pathlib + import pytest +import tvm.contrib.utils import tvm.target.target # The models that should pass this configuration. 
Maps a short, identifying platform string to @@ -24,7 +29,7 @@ "host": ("host", "qemu_x86"), "host_riscv32": ("host", "qemu_riscv32"), "host_riscv64": ("host", "qemu_riscv64"), - "mps2_an521": ("mps2_an521", "mps2_an521-qemu"), + "mps2_an521": ("mps2_an521", "mps2_an521"), "nrf5340dk": ("nrf5340dk", "nrf5340dk_nrf5340_cpuapp"), "stm32f746xx_disco": ("stm32f746xx", "stm32f746g_disco"), "stm32f746xx_nucleo": ("stm32f746xx", "nucleo_f746zg"), @@ -77,3 +82,25 @@ def skip_build(request): @pytest.fixture def tvm_debug(request): return request.config.getoption("--tvm-debug") + + +@pytest.fixture +def temp_dir(platform): + _, zephyr_board = PLATFORMS[platform] + parent_dir = pathlib.Path(os.path.dirname(__file__)) + filename = os.path.splitext(os.path.basename(__file__))[0] + board_workspace = ( + parent_dir + / f"workspace_{filename}_{zephyr_board}" + / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + ) + board_workspace_base = str(board_workspace) + number = 1 + while board_workspace.exists(): + board_workspace = pathlib.Path(board_workspace_base + f"-{number}") + number += 1 + + if not os.path.exists(board_workspace.parent): + os.makedirs(board_workspace.parent) + + return tvm.contrib.utils.tempdir(board_workspace) diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 18587acd46ae..b3dbcbadd886 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -21,6 +21,7 @@ import glob import logging import os +import pathlib import subprocess import sys import logging @@ -35,8 +36,8 @@ import tvm.micro import tvm.testing import tvm.relay as relay +from tvm.relay.testing import byoc -from tvm.micro.contrib import zephyr from tvm.contrib import utils from tvm.relay.expr_functor import ExprMutator from tvm.relay.op.annotation import compiler_begin, compiler_end @@ -48,85 +49,60 @@ PLATFORMS = conftest.PLATFORMS -def _make_sess_from_op(model, zephyr_board, west_cmd, op_name, sched, arg_bufs, build_config): +def _make_sess_from_op( + temp_dir, model, zephyr_board, west_cmd, op_name, sched, arg_bufs, build_config +): target = tvm.target.target.micro(model) target = tvm.target.Target(target=target, host=target) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): mod = tvm.build(sched, arg_bufs, target=target, name=op_name) - return _make_session(model, target, zephyr_board, west_cmd, mod, build_config) - - -def _make_session(model, target, zephyr_board, west_cmd, mod, build_config): - parent_dir = os.path.dirname(__file__) - filename = os.path.splitext(os.path.basename(__file__))[0] - prev_build = f"{os.path.join(parent_dir, 'archive')}_{filename}_{zephyr_board}_last_build.micro" - workspace_root = os.path.join( - f"{os.path.join(parent_dir, 'workspace')}_{filename}_{zephyr_board}", - datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), - ) - workspace_parent = os.path.dirname(workspace_root) - if not os.path.exists(workspace_parent): - os.makedirs(workspace_parent) - workspace = tvm.micro.Workspace(debug=True, root=workspace_root) - - test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - tvm_source_dir = os.path.join(test_dir, "..", "..", "..") - runtime_path = os.path.join(tvm_source_dir, "apps", "microtvm", "zephyr", "host_driven") - compiler = zephyr.ZephyrCompiler( - project_dir=runtime_path, - board=zephyr_board, - zephyr_toolchain_variant="zephyr", - west_cmd=west_cmd, + return _make_session(temp_dir, model, target, zephyr_board, west_cmd, mod, build_config) + + 
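+# Path to the Zephyr template project (apps/microtvm/zephyr/template_project), which
+# tvm.micro.generate_project() copies and specializes for each test session below.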
+TEMPLATE_PROJECT_DIR = ( + pathlib.Path(__file__).parent + / ".." + / ".." + / ".." + / "apps" + / "microtvm" + / "zephyr" + / "template_project" +).resolve() + + +def _make_session(temp_dir, model, target, zephyr_board, west_cmd, mod, build_config): + project = tvm.micro.generate_project( + str(TEMPLATE_PROJECT_DIR), + mod, + temp_dir / "project", + { + "project_type": "host_driven", + "west_cmd": west_cmd, + "verbose": bool(build_config.get("debug")), + "zephyr_board": zephyr_board, + }, ) - - opts = tvm.micro.default_options(os.path.join(runtime_path, "crt")) - # TODO(weberlo) verify this is necessary - opts["bin_opts"]["ccflags"] = ["-std=gnu++14"] - opts["lib_opts"]["ccflags"] = ["-std=gnu++14"] - - flasher_kw = {} - if build_config["debug"]: - flasher_kw["debug_rpc_session"] = tvm.rpc.connect("127.0.0.1", 9090) - - session_kw = { - "flasher": compiler.flasher(**flasher_kw), - } - - if not build_config["skip_build"]: - session_kw["binary"] = tvm.micro.build_static_runtime( - # the x86 compiler *expects* you to give the exact same dictionary for both - # lib_opts and bin_opts. so the library compiler is mutating lib_opts and - # the binary compiler is expecting those mutations to be in bin_opts. - # TODO(weberlo) fix this very bizarre behavior - workspace, - compiler, - mod, - opts, - ) - if os.path.exists(prev_build): - os.unlink(prev_build) - session_kw["binary"].archive(prev_build, metadata_only=True) - else: - unarchive_dir = utils.tempdir() - session_kw["binary"] = tvm.micro.MicroBinary.unarchive( - prev_build, unarchive_dir.relpath("binary") - ) - - return tvm.micro.Session(**session_kw) + if not build_config.get("skip_build"): + project.build() + project.flash() + return tvm.micro.Session(project.transport()) -def _make_add_sess(model, zephyr_board, west_cmd, build_config): - A = tvm.te.placeholder((2,), dtype="int8") - B = tvm.te.placeholder((1,), dtype="int8") +def _make_add_sess(temp_dir, model, zephyr_board, west_cmd, build_config, dtype="int8"): + A = tvm.te.placeholder((2,), dtype=dtype) + B = tvm.te.placeholder((1,), dtype=dtype) C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") sched = tvm.te.create_schedule(C.op) - return _make_sess_from_op(model, zephyr_board, west_cmd, "add", sched, [A, B, C], build_config) + return _make_sess_from_op( + temp_dir, model, zephyr_board, west_cmd, "add", sched, [A, B, C], build_config + ) # The same test code can be executed on both the QEMU simulation and on real hardware. @tvm.testing.requires_micro -def test_compile_runtime(platform, west_cmd, skip_build, tvm_debug): +def test_add_uint(temp_dir, platform, west_cmd, skip_build, tvm_debug): """Test compiling the on-device runtime.""" model, zephyr_board = PLATFORMS[platform] @@ -145,12 +121,51 @@ def test_basic_add(sess): system_lib.get_function("add")(A_data, B_data, C_data) assert (C_data.numpy() == np.array([6, 7])).all() - with _make_add_sess(model, zephyr_board, west_cmd, build_config) as sess: + with _make_add_sess(temp_dir, model, zephyr_board, west_cmd, build_config) as sess: test_basic_add(sess) +def has_fpu(zephyr_board): + sys.path.insert(0, str(TEMPLATE_PROJECT_DIR)) + try: + import microtvm_api_server + finally: + sys.path.pop(0) + + return microtvm_api_server.Handler._has_fpu(zephyr_board) + + +# The same test code can be executed on both the QEMU simulation and on real hardware. 
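+# (The `platform` fixture selects a (model, zephyr_board) pair via PLATFORMS in conftest.py.)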
@tvm.testing.requires_micro -def test_platform_timer(platform, west_cmd, skip_build, tvm_debug): +def test_add_float(temp_dir, platform, west_cmd, skip_build, tvm_debug): + """Test compiling the on-device runtime.""" + model, zephyr_board = PLATFORMS[platform] + if not has_fpu(zephyr_board): + pytest.skip(f"FPU not enabled for {platform}") + + build_config = {"skip_build": skip_build, "debug": tvm_debug} + + # NOTE: run test in a nested function so cPython will delete arrays before closing the session. + def test_basic_add(sess): + A_data = tvm.nd.array(np.array([2.5, 3.5], dtype="float32"), device=sess.device) + assert (A_data.numpy() == np.array([2.5, 3.5])).all() + B_data = tvm.nd.array(np.array([4.5], dtype="float32"), device=sess.device) + assert (B_data.numpy() == np.array([4.5])).all() + C_data = tvm.nd.array(np.array([0, 0], dtype="float32"), device=sess.device) + assert (C_data.numpy() == np.array([0, 0])).all() + + system_lib = sess.get_system_lib() + system_lib.get_function("add")(A_data, B_data, C_data) + assert (C_data.numpy() == np.array([7, 8])).all() + + with _make_add_sess( + temp_dir, model, zephyr_board, west_cmd, build_config, dtype="float32" + ) as sess: + test_basic_add(sess) + + +@tvm.testing.requires_micro +def test_platform_timer(temp_dir, platform, west_cmd, skip_build, tvm_debug): """Test compiling the on-device runtime.""" model, zephyr_board = PLATFORMS[platform] @@ -174,12 +189,12 @@ def test_basic_add(sess): assert result.mean > 0 assert len(result.results) == 3 - with _make_add_sess(model, zephyr_board, west_cmd, build_config) as sess: + with _make_add_sess(temp_dir, model, zephyr_board, west_cmd, build_config) as sess: test_basic_add(sess) @tvm.testing.requires_micro -def test_relay(platform, west_cmd, skip_build, tvm_debug): +def test_relay(temp_dir, platform, west_cmd, skip_build, tvm_debug): """Testing a simple relay graph""" model, zephyr_board = PLATFORMS[platform] build_config = {"skip_build": skip_build, "debug": tvm_debug} @@ -194,13 +209,15 @@ def test_relay(platform, west_cmd, skip_build, tvm_debug): target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - graph, mod, params = tvm.relay.build(func, target=target) + mod = tvm.relay.build(func, target=target) - with _make_session(model, target, zephyr_board, west_cmd, mod, build_config) as session: + with _make_session( + temp_dir, model, target, zephyr_board, west_cmd, mod, build_config + ) as session: graph_mod = tvm.micro.create_local_graph_executor( - graph, session.get_system_lib(), session.device + mod.get_graph_json(), session.get_system_lib(), session.device ) - graph_mod.set_input(**params) + graph_mod.set_input(**mod.get_params()) x_in = np.random.randint(10, size=shape[0], dtype=dtype) graph_mod.run(x=x_in) result = graph_mod.get_output(0).numpy() @@ -209,7 +226,7 @@ def test_relay(platform, west_cmd, skip_build, tvm_debug): @tvm.testing.requires_micro -def test_onnx(platform, west_cmd, skip_build, tvm_debug): +def test_onnx(temp_dir, platform, west_cmd, skip_build, tvm_debug): """Testing a simple ONNX model.""" model, zephyr_board = PLATFORMS[platform] build_config = {"skip_build": skip_build, "debug": tvm_debug} @@ -239,7 +256,9 @@ def test_onnx(platform, west_cmd, skip_build, tvm_debug): lowered = relay.build(relay_mod, target, params=params) graph = lowered.get_graph_json() - with _make_session(model, target, zephyr_board, west_cmd, lowered.lib, build_config) as session: + with _make_session( + temp_dir, model, 
target, zephyr_board, west_cmd, lowered, build_config + ) as session: graph_mod = tvm.micro.create_local_graph_executor( graph, session.get_system_lib(), session.device ) @@ -257,77 +276,25 @@ def test_onnx(platform, west_cmd, skip_build, tvm_debug): assert np.argmax(result) == 9 -class CcompilerAnnotator(ExprMutator): - """ - This is used to create external functions for ccompiler. - A simple annotator that creates the following program: - | - -- begin -- - | - add - | - subtract - | - multiply - | - -- end -- - | - """ - - def __init__(self): - super(CcompilerAnnotator, self).__init__() - self.in_compiler = 0 - - def visit_call(self, call): - if call.op.name == "add": # Annotate begin at args - if self.in_compiler == 1: - lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") - rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") - op = relay.add(lhs, rhs) - self.in_compiler = 2 - return op - elif call.op.name == "subtract": - if self.in_compiler == 1: - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - return relay.subtract(lhs, rhs) - elif call.op.name == "multiply": # Annotate end at output - self.in_compiler = 1 - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - op = relay.multiply(lhs, rhs) - if self.in_compiler == 2: - op = compiler_end(op, "ccompiler") - self.in_compiler = 0 - return op - return super().visit_call(call) - - def check_result( - relay_mod, model, zephyr_board, west_cmd, map_inputs, out_shape, result, build_config + temp_dir, relay_mod, model, zephyr_board, west_cmd, map_inputs, out_shape, result, build_config ): """Helper function to verify results""" TOL = 1e-5 target = tvm.target.target.micro(model) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - graph, mod, params = tvm.relay.build(relay_mod, target=target) + mod = tvm.relay.build(relay_mod, target=target) - with _make_session(model, target, zephyr_board, west_cmd, mod, build_config) as session: + with _make_session( + temp_dir, model, target, zephyr_board, west_cmd, mod, build_config + ) as session: rt_mod = tvm.micro.create_local_graph_executor( - graph, session.get_system_lib(), session.device + mod.get_graph_json(), session.get_system_lib(), session.device ) - rt_mod.set_input(**params) + rt_mod.set_input(**mod.get_params()) for name, data in map_inputs.items(): rt_mod.set_input(name, data) - rt_mod.set_input(**params) + rt_mod.set_input(**mod.get_params()) rt_mod.run() out_shapes = out_shape if isinstance(out_shape, list) else [out_shape] @@ -340,7 +307,7 @@ def check_result( @tvm.testing.requires_micro -def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): +def test_byoc_microtvm(temp_dir, platform, west_cmd, skip_build, tvm_debug): """This is a simple test case to check BYOC capabilities of microTVM""" model, zephyr_board = PLATFORMS[platform] build_config = {"skip_build": skip_build, "debug": tvm_debug} @@ -370,7 +337,7 @@ def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): r = relay.concatenate((q0, q1, q2), axis=0) f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) mod = tvm.IRModule() - ann = CcompilerAnnotator() + ann = byoc.CcompilerAnnotator() 
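+    # The annotator marks the add/subtract/multiply chain with compiler_begin/compiler_end
+    # so that PartitionGraph() below can offload it to the external "ccompiler" codegen.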
mod["main"] = ann.visit(f) mod = tvm.relay.transform.PartitionGraph()(mod) mod = tvm.relay.transform.InferType()(mod) @@ -383,6 +350,7 @@ def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} map_inputs["x"] = x_data check_result( + temp_dir=temp_dir, relay_mod=mod, map_inputs=map_inputs, out_shape=(30, 10), @@ -401,11 +369,13 @@ def test_byoc_microtvm(platform, west_cmd, skip_build, tvm_debug): ) -def _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config): +def _make_add_sess_with_shape(temp_dir, model, zephyr_board, west_cmd, shape, build_config): A = tvm.te.placeholder(shape, dtype="int8") C = tvm.te.compute(A.shape, lambda i: A[i] + A[i], name="C") sched = tvm.te.create_schedule(C.op) - return _make_sess_from_op(model, zephyr_board, west_cmd, "add", sched, [A, C], build_config) + return _make_sess_from_op( + temp_dir, model, zephyr_board, west_cmd, "add", sched, [A, C], build_config + ) @pytest.mark.parametrize( @@ -417,7 +387,7 @@ def _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config ], ) @tvm.testing.requires_micro -def test_rpc_large_array(platform, west_cmd, skip_build, tvm_debug, shape): +def test_rpc_large_array(temp_dir, platform, west_cmd, skip_build, tvm_debug, shape): """Test large RPC array transfer.""" model, zephyr_board = PLATFORMS[platform] build_config = {"skip_build": skip_build, "debug": tvm_debug} @@ -431,7 +401,9 @@ def test_tensors(sess): C_data = tvm.nd.array(np.zeros(shape, dtype="int8"), device=sess.device) assert (C_data.asnumpy() == np.zeros(shape)).all() - with _make_add_sess_with_shape(model, zephyr_board, west_cmd, shape, build_config) as sess: + with _make_add_sess_with_shape( + temp_dir, model, zephyr_board, west_cmd, shape, build_config + ) as sess: test_tensors(sess) diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index f136cee96199..1602a4185462 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -17,11 +17,14 @@ import datetime from hashlib import new +import io import logging import os import sys import logging import pathlib +import tarfile +import tempfile import pytest import numpy as np @@ -29,10 +32,10 @@ import tvm import tvm.rpc import tvm.micro +from tvm.micro.project_api import server import tvm.testing import tvm.relay as relay -from tvm.micro.contrib import zephyr from tvm.contrib import utils from tvm.contrib.download import download_testdata from tvm.micro.interface_api import generate_c_interface_header @@ -44,89 +47,70 @@ PLATFORMS = conftest.PLATFORMS -def _build_session_kw(model, target, zephyr_board, west_cmd, mod, runtime_path, build_config): - parent_dir = os.path.dirname(__file__) - filename = os.path.splitext(os.path.basename(__file__))[0] - prev_build = f"{os.path.join(parent_dir, 'archive')}_{filename}_{zephyr_board}_last_build.micro" - workspace_root = os.path.join( - f"{os.path.join(parent_dir, 'workspace')}_{filename}_{zephyr_board}", - datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"), - ) - workspace_parent = os.path.dirname(workspace_root) - if not os.path.exists(workspace_parent): - os.makedirs(workspace_parent) - workspace = tvm.micro.Workspace(debug=True, root=workspace_root) - - compiler = zephyr.ZephyrCompiler( - project_dir=runtime_path, - board=zephyr_board, - zephyr_toolchain_variant="zephyr", - west_cmd=west_cmd, - env_vars={"ZEPHYR_RUNTIME": "ZEPHYR-AOT"}, +def _build_project( + temp_dir, model, 
target, zephyr_board, west_cmd, mod, build_config, extra_files_tar=None +): + template_project_dir = ( + pathlib.Path(__file__).parent + / ".." + / ".." + / ".." + / "apps" + / "microtvm" + / "zephyr" + / "template_project" + ).resolve() + project_dir = temp_dir / "project" + project = tvm.micro.generate_project( + str(template_project_dir), + mod, + project_dir, + { + "extra_files_tar": extra_files_tar, + "project_type": "aot_demo", + "west_cmd": west_cmd, + "verbose": 0, + "zephyr_board": zephyr_board, + }, ) + project.build() + return project, project_dir - opts = tvm.micro.default_options(os.path.join(runtime_path, "crt")) - opts["bin_opts"]["include_dirs"].append(os.path.join(runtime_path, "include")) - opts["lib_opts"]["include_dirs"].append(os.path.join(runtime_path, "include")) - - flasher_kw = {} - if build_config["debug"]: - flasher_kw["debug_rpc_session"] = tvm.rpc.connect("127.0.0.1", 9090) - - session_kw = { - "flasher": compiler.flasher(**flasher_kw), - } - - if not build_config["skip_build"]: - session_kw["binary"] = tvm.micro.build_static_runtime( - workspace, - compiler, - mod, - opts, - executor="aot", - extra_libs=[tvm.micro.get_standalone_crt_lib("memory")], - ) - if os.path.exists(prev_build): - os.unlink(prev_build) - session_kw["binary"].archive(prev_build, metadata_only=True) - else: - unarchive_dir = utils.tempdir() - session_kw["binary"] = tvm.micro.MicroBinary.unarchive( - prev_build, unarchive_dir.relpath("binary") - ) - return session_kw - - -def _create_header_file(tensor_name, npy_data, output_path): +def _create_header_file(tensor_name, npy_data, output_path, tar_file): """ This method generates a header file containing the data contained in the numpy array provided. It is used to capture the tensor data (for both inputs and expected outputs). 
""" - file_path = pathlib.Path(f"{output_path}/" + tensor_name).resolve() - # create header file - raw_path = file_path.with_suffix(".h").resolve() - with open(raw_path, "w") as header_file: - header_file.write("#include \n") - header_file.write("#include \n") - header_file.write("#include \n") - header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n") - - if npy_data.dtype == "int8": - header_file.write(f"int8_t {tensor_name}[] =") - elif npy_data.dtype == "int32": - header_file.write(f"int32_t {tensor_name}[] = ") - elif npy_data.dtype == "uint8": - header_file.write(f"uint8_t {tensor_name}[] = ") - elif npy_data.dtype == "float32": - header_file.write(f"float {tensor_name}[] = ") - else: - raise ValueError("Data type not expected.") - - header_file.write("{") - for i in np.ndindex(npy_data.shape): - header_file.write(f"{npy_data[i]}, ") - header_file.write("};\n\n") + header_file = io.StringIO() + header_file.write("#include \n") + header_file.write("#include \n") + header_file.write("#include \n") + header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n") + + if npy_data.dtype == "int8": + header_file.write(f"int8_t {tensor_name}[] =") + elif npy_data.dtype == "int32": + header_file.write(f"int32_t {tensor_name}[] = ") + elif npy_data.dtype == "uint8": + header_file.write(f"uint8_t {tensor_name}[] = ") + elif npy_data.dtype == "float32": + header_file.write(f"float {tensor_name}[] = ") + else: + raise ValueError("Data type not expected.") + + header_file.write("{") + for i in np.ndindex(npy_data.shape): + header_file.write(f"{npy_data[i]}, ") + header_file.write("};\n\n") + + header_file_bytes = bytes(header_file.getvalue(), "utf-8") + raw_path = pathlib.Path(output_path) / f"{tensor_name}.h" + ti = tarfile.TarInfo(name=str(raw_path)) + ti.size = len(header_file_bytes) + ti.mode = 0o644 + ti.type = tarfile.REGTYPE + tar_file.addfile(ti, io.BytesIO(header_file_bytes)) def _read_line(fd): @@ -155,19 +139,20 @@ def _get_message(fd, expr: str): @tvm.testing.requires_micro -def test_tflite(platform, west_cmd, skip_build, tvm_debug): +def test_tflite(temp_dir, platform, west_cmd, skip_build, tvm_debug): + """Testing a TFLite model.""" + if platform not in ["host", "mps2_an521", "nrf5340dk", "stm32l4r5zi_nucleo", "zynq_mp_r5"]: pytest.skip(msg="Model does not fit.") - """Testing a TFLite model.""" model, zephyr_board = PLATFORMS[platform] input_shape = (1, 32, 32, 3) output_shape = (1, 10) build_config = {"skip_build": skip_build, "debug": tvm_debug} - this_dir = os.path.dirname(__file__) - tvm_source_dir = os.path.join(this_dir, "..", "..", "..") - runtime_path = os.path.join(tvm_source_dir, "apps", "microtvm", "zephyr", "aot_demo") + this_dir = pathlib.Path(os.path.dirname(__file__)) + tvm_source_dir = this_dir / ".." / ".." / ".." 
+ runtime_path = tvm_source_dir / "apps" / "microtvm" / "zephyr" / "aot_demo" model_url = "https://github.com/eembc/ulpmark-ml/raw/fc1499c7cc83681a02820d5ddf5d97fe75d4f663/base_models/ic01/ic01_fp32.tflite" model_path = download_testdata(model_url, "ic01_fp32.tflite", module="model") @@ -199,21 +184,40 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): sample_url, "testdata_image_classification_fp32_8.npy", module="data" ) sample = np.load(sample_path) - model_files_path = os.path.join(runtime_path, "include") - generate_c_interface_header(lowered.libmod_name, ["input_1"], ["output"], model_files_path) - _create_header_file((f"input_data"), sample, model_files_path) - _create_header_file( - "output_data", np.zeros(shape=output_shape, dtype="float32"), model_files_path - ) - session_kw = _build_session_kw( - model, target, zephyr_board, west_cmd, lowered.lib, runtime_path, build_config - ) - transport = session_kw["flasher"].flash(session_kw["binary"]) - transport.open() - transport.write(b"start\n", timeout_sec=5) + with tempfile.NamedTemporaryFile() as tar_temp_file: + with tarfile.open(tar_temp_file.name, "w:gz") as tf: + with tempfile.TemporaryDirectory() as tar_temp_dir: + model_files_path = os.path.join(tar_temp_dir, "include") + os.mkdir(model_files_path) + header_path = generate_c_interface_header( + lowered.libmod_name, ["input_1"], ["output"], model_files_path + ) + tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir)) + + _create_header_file("input_data", sample, "include", tf) + _create_header_file( + "output_data", np.zeros(shape=output_shape, dtype="float32"), "include", tf + ) + + project, _ = _build_project( + temp_dir, + model, + target, + zephyr_board, + west_cmd, + lowered, + build_config, + extra_files_tar=tar_temp_file.name, + ) + + project.flash() + with project.transport() as transport: + _get_message(transport, "#wakeup") + transport.write(b"start\n", timeout_sec=5) + + result_line = _get_message(transport, "#result") - result_line = _get_message(transport, "#result") result_line = result_line.strip("\n") result_line = result_line.split(":") result = int(result_line[1]) @@ -223,20 +227,16 @@ def test_tflite(platform, west_cmd, skip_build, tvm_debug): @tvm.testing.requires_micro -def test_qemu_make_fail(platform, west_cmd, skip_build, tvm_debug): +def test_qemu_make_fail(temp_dir, platform, west_cmd, skip_build, tvm_debug): + """Testing QEMU make fail.""" if platform not in ["host", "mps2_an521"]: pytest.skip(msg="Only for QEMU targets.") - """Testing QEMU make fail.""" model, zephyr_board = PLATFORMS[platform] build_config = {"skip_build": skip_build, "debug": tvm_debug} shape = (10,) dtype = "float32" - this_dir = pathlib.Path(__file__).parent - tvm_source_dir = this_dir / ".." / ".." / ".." - runtime_path = tvm_source_dir / "apps" / "microtvm" / "zephyr" / "aot_demo" - # Construct Relay program. 
x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) xx = relay.multiply(x, x) @@ -248,22 +248,39 @@ def test_qemu_make_fail(platform, west_cmd, skip_build, tvm_debug): lowered = relay.build(func, target) # Generate input/output header files - model_files_path = os.path.join(runtime_path, "include") - _create_header_file((f"input_data"), np.zeros(shape=shape, dtype=dtype), model_files_path) - _create_header_file("output_data", np.zeros(shape=shape, dtype=dtype), model_files_path) + with tempfile.NamedTemporaryFile() as tar_temp_file: + with tarfile.open(tar_temp_file.name, "w:gz") as tf: + with tempfile.TemporaryDirectory() as tar_temp_dir: + model_files_path = os.path.join(tar_temp_dir, "include") + os.mkdir(model_files_path) + header_path = generate_c_interface_header( + lowered.libmod_name, ["input_1"], ["output"], model_files_path + ) + tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir)) + _create_header_file("input_data", np.zeros(shape=shape, dtype=dtype), "include", tf) + _create_header_file("output_data", np.zeros(shape=shape, dtype=dtype), "include", tf) + + project, project_dir = _build_project( + temp_dir, + model, + target, + zephyr_board, + west_cmd, + lowered, + build_config, + extra_files_tar=tar_temp_file.name, + ) - session_kw = _build_session_kw( - model, target, zephyr_board, west_cmd, lowered.lib, runtime_path, build_config + file_path = ( + pathlib.Path(project_dir) / "build" / "zephyr" / "CMakeFiles" / "run.dir" / "build.make" ) - - file_path = os.path.join(session_kw["binary"].base_dir, "zephyr/CMakeFiles/run.dir/build.make") - assert os.path.isfile(file_path), f"[{file_path}] does not exist." + assert file_path.is_file(), f"[{file_path}] does not exist." # Remove a file to create make failure. os.remove(file_path) - transport = session_kw["flasher"].flash(session_kw["binary"]) - with pytest.raises(RuntimeError) as excinfo: - transport.open() + project.flash() + with pytest.raises(server.JSONRPCError) as excinfo: + project.transport().open() assert "QEMU setup failed" in str(excinfo.value) diff --git a/tests/python/relay/aot/aot_test.mk b/tests/python/relay/aot/aot_test.mk index 81e31762611f..0c47a32ea1f8 100644 --- a/tests/python/relay/aot/aot_test.mk +++ b/tests/python/relay/aot/aot_test.mk @@ -16,26 +16,20 @@ # under the License. 
# Setup build environment # -AOT_ROOT ?= $(TVM_ROOT)/src/runtime/crt/aot +AOT_ROOT ?= $(CRT_ROOT)/aot ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0 DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core -PKG_COMPILE_OPTS = -g +PKG_COMPILE_OPTS = -g CC = gcc AR = ar RANLIB = ranlib CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB) - PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ - -I$(TVM_ROOT)/src/runtime/crt/include \ - -I$(TVM_ROOT)/src/runtime/crt/host \ - -I$(TVM_ROOT)/include \ - -I$(DMLC_CORE)/include \ - -I$(TVM_ROOT)/3rdparty/dlpack/include \ - -I$(AOT_ROOT)\ - -I$(build_dir) \ - -I$(CODEGEN_ROOT)/host/include + -I$(build_dir)/../include \ + -I$(CODEGEN_ROOT)/host/include \ + -isystem$(STANDALONE_CRT_DIR)/include $(ifeq VERBOSE,1) QUIET ?= @@ -43,12 +37,10 @@ $(else) QUIET ?= @ $(endif) -CRT_SRCS = $(shell find $(CRT_ROOT)) - aot_test_runner: $(build_dir)/aot_test_runner source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c) -lib_objs =$(source_libs:.c=.o) +lib_objs =$(source_libs:.c=.o) $(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/aot_executor.o $(source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o $(QUIET)mkdir -p $(@D) @@ -58,15 +50,15 @@ $(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) -$(build_dir)/aot_executor.o: $(TVM_ROOT)/src/runtime/crt/aot_executor/aot_executor.c +$(build_dir)/aot_executor.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/aot_executor/aot_executor.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) -$(build_dir)/stack_allocator.o: $(TVM_ROOT)/src/runtime/crt/memory/stack_allocator.c +$(build_dir)/stack_allocator.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/memory/stack_allocator.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) -$(build_dir)/crt_backend_api.o: $(TVM_ROOT)/src/runtime/crt/common/crt_backend_api.c +$(build_dir)/crt_backend_api.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/common/crt_backend_api.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 900eb67e2b48..d4d16346f8c2 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -15,12 +15,15 @@ # specific language governing permissions and limitations # under the License. 
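aot_test.mk now locates the CRT sources through STANDALONE_CRT_DIR rather than reaching into TVM_ROOT. A sketch of the resulting invocation (paths are placeholders; `compile_and_run` below assembles the real command the same way):

import subprocess

import tvm.micro

# Point the Makefile at the standalone CRT shipped with the TVM build.
make_cmd = (
    "make -f aot_test.mk build_dir=./build"
    " TVM_ROOT=/path/to/tvm"
    " CODEGEN_ROOT=./codegen"
    f" STANDALONE_CRT_DIR={tvm.micro.get_standalone_crt_dir()}"
)
subprocess.run(make_cmd, shell=True, check=True)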
-import os +import datetime import itertools +import json +import logging +import os import pathlib +import shutil import subprocess import tarfile -import json import pytest import numpy as np @@ -33,6 +36,9 @@ from tvm.micro import export_model_library_format +_LOG = logging.getLogger(__name__) + + def mangle_name(mod_name, name): mod_name = mangle_module_name(mod_name) return mod_name + "_" + name @@ -98,24 +104,35 @@ def parametrize_aot_options(test): )(test) -def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout): +def subprocess_log_output(cmd, cwd, logfile): """ This method runs a process and logs the output to both a log file and stdout """ - with subprocess.Popen( + _LOG.info("Execute (%s): %s", cwd, cmd) + cmd_base = cmd[0] if isinstance(cmd, (list, tuple)) else cmd.split(" ", 1)[0] + proc = subprocess.Popen( cmd, cwd=cwd, shell=True, bufsize=0, stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) as proc, open(logfile, "a") as f: + ) + with open(logfile, "ab") as f: + f.write( + bytes( + "\n" + + "-" * 80 + + f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}: Execute ({cwd}): {cmd}\n" + + "-" * 80, + "utf-8", + ) + ) while True: data = proc.stdout.readline() - result = proc.poll() + _LOG.debug("%s: %s", cmd_base, str(data, "utf-8", "replace").rstrip("\n")) + f.write(data) + # process is done if there is no data and the result is valid - if data == b"" and result is not None: - return int(result) - if data: - text = data.decode("ascii", errors="backslashreplace") - f.write(text) - if stdout: - print(text, end="") + if not data: # EOF + break + + return proc.wait() def emit_main_prologue(main_file, workspace_bytes): @@ -135,12 +152,12 @@ def emit_main_prologue(main_file, workspace_bytes): return StackMemoryManager_Free(&app_workspace,ptr); } -void TVMPlatformAbort(tvm_crt_error_t code) { } +void TVMPlatformAbort(tvm_crt_error_t code) { } void TVMLogf(const char* msg, ...) { } TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {} -int main(){\n +int main(){\n """ ) @@ -405,17 +422,31 @@ def compile_and_run( else: workspace_bytes = 16384 * 1024 + include_path = os.path.join(base_path, "include") + os.mkdir(include_path) + crt_root = tvm.micro.get_standalone_crt_dir() + shutil.copy2( + os.path.join(crt_root, "template", "crt_config-template.h"), + os.path.join(include_path, "crt_config.h"), + ) + for key in inputs: - create_header_file(f'{mangle_name(mod_name, "input_data")}_{key}', inputs[key], build_path) + create_header_file( + f'{mangle_name(mod_name, "input_data")}_{key}', + inputs[key], + os.path.join(base_path, "include"), + ) for i in range(len(output_list)): create_header_file( - (f'{mangle_name(mod_name,"output_data")}{i}'), + f'{mangle_name(mod_name,"output_data")}{i}', np.zeros(output_list[i].shape, output_list[i].dtype), - build_path, + os.path.join(base_path, "include"), ) create_header_file( - (f'{mangle_name(mod_name, "expected_output_data")}{i}'), output_list[i], build_path + f'{mangle_name(mod_name, "expected_output_data")}{i}', + output_list[i], + os.path.join(base_path, "include"), ) create_main( @@ -436,15 +467,16 @@ def compile_and_run( + build_path + f" TVM_ROOT={file_dir}/../../../.." 
+ f" CODEGEN_ROOT={codegen_path}" + + f" STANDALONE_CRT_DIR={tvm.micro.get_standalone_crt_dir()}" ) compile_log_path = os.path.join(build_path, "test_compile.log") - ret = subprocess_with_stdout_and_log(make_cmd, ".", compile_log_path, False) + ret = subprocess_log_output(make_cmd, ".", compile_log_path) assert ret == 0 # Verify that runs fine run_log_path = os.path.join(build_path, "test_run.log") - ret = subprocess_with_stdout_and_log("./aot_test_runner", build_path, run_log_path, False) + ret = subprocess_log_output("./aot_test_runner", build_path, run_log_path) assert ret == 0 @@ -470,6 +502,15 @@ def compile_and_run_multiple_models( base_path = os.path.join(tmp_dir, "test") build_path = os.path.join(base_path, "build") os.makedirs(build_path, exist_ok=True) + + include_path = os.path.join(base_path, "include") + os.mkdir(include_path) + crt_root = tvm.micro.get_standalone_crt_dir() + shutil.copy2( + os.path.join(crt_root, "template", "crt_config-template.h"), + os.path.join(include_path, "crt_config.h"), + ) + for mod_name, mod in mod_map.items(): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): @@ -518,15 +559,16 @@ def compile_and_run_multiple_models( + build_path + f" TVM_ROOT={file_dir}/../../../.." + f" CODEGEN_ROOT={codegen_path}" + + f" STANDALONE_CRT_DIR={tvm.micro.get_standalone_crt_dir()}" ) compile_log_path = os.path.join(build_path, "test_compile.log") - ret = subprocess_with_stdout_and_log(make_cmd, ".", compile_log_path, False) + ret = subprocess_log_output(make_cmd, ".", compile_log_path) assert ret == 0 # Verify that runs fine run_log_path = os.path.join(build_path, "test_run.log") - ret = subprocess_with_stdout_and_log("./aot_test_runner", build_path, run_log_path, False) + ret = subprocess_log_output("./aot_test_runner", build_path, run_log_path) assert ret == 0 diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 26eca2688436..3a972861f8ef 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -16,13 +16,30 @@ # under the License. from collections import OrderedDict +import os +import io +import pathlib +import shutil +import struct +import subprocess +import sys +import tempfile +import tarfile import numpy as np import pytest import tvm from tvm import relay -from tvm.relay import testing, transform +from tvm.relay import transform +from tvm.relay.op.contrib import get_pattern_table +from tvm.contrib import utils +from tvm.relay.backend import compile_engine +from tvm.contrib import utils +from tvm.contrib import graph_executor +from tvm.micro import export_model_library_format +from tvm.relay import testing +from tvm.relay.testing import byoc from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.expr_functor import ExprMutator from aot_test_utils import ( @@ -360,60 +377,6 @@ def test_mobilenet(use_calculated_workspaces, workspace_byte_alignment): ) -class CcompilerAnnotator(ExprMutator): - """ - This is used to create external functions for ccompiler. 
- A simple annotator that creates the following program: - | - -- begin -- - | - add - | - subtract - | - multiply - | - -- end -- - | - """ - - def __init__(self): - super(CcompilerAnnotator, self).__init__() - self.in_compiler = 0 - - def visit_call(self, call): - if call.op.name == "add": # Annotate begin at args - if self.in_compiler == 1: - lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") - rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") - op = relay.add(lhs, rhs) - self.in_compiler = 2 - return op - elif call.op.name == "subtract": - if self.in_compiler == 1: - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - return relay.subtract(lhs, rhs) - elif call.op.name == "multiply": # Annotate end at output - self.in_compiler = 1 - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - op = relay.multiply(lhs, rhs) - if self.in_compiler == 2: - op = compiler_end(op, "ccompiler") - self.in_compiler = 0 - return op - return super().visit_call(call) - - @pytest.mark.parametrize("use_calculated_workspaces", [True, False]) def test_byoc_microtvm(use_calculated_workspaces): """This is a simple test case to check BYOC capabilities of AOT""" @@ -446,7 +409,7 @@ def test_byoc_microtvm(use_calculated_workspaces): r = relay.concatenate((q0, q1, q2), axis=0) f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) mod = tvm.IRModule() - ann = CcompilerAnnotator() + ann = byoc.CcompilerAnnotator() mod["main"] = ann.visit(f) mod = tvm.relay.transform.PartitionGraph("mod_name")(mod) @@ -619,4 +582,4 @@ def test_transpose(interface_api, use_unpacked_api, use_calculated_workspaces): if __name__ == "__main__": - pytest.main([__file__]) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 29d420def184..5467589f956b 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -27,6 +27,7 @@ from tvm import relay from tvm import runtime from tvm.relay import transform +from tvm.relay.testing import byoc from tvm.contrib import utils from tvm.relay.backend import compile_engine from tvm.relay.expr_functor import ExprMutator @@ -63,59 +64,6 @@ def visit_call(self, call): return Annotator().visit(func) -class CcompilerAnnotator(ExprMutator): - """ - A simple annotator that creates the following program: - | - -- begin -- - | - add - | - subtract - | - multiply - | - -- end -- - | - """ - - def __init__(self): - super(CcompilerAnnotator, self).__init__() - self.in_compiler = 0 - - def visit_call(self, call): - if call.op.name == "add": # Annotate begin at args - if self.in_compiler == 1: - lhs = compiler_begin(super().visit(call.args[0]), "ccompiler") - rhs = compiler_begin(super().visit(call.args[1]), "ccompiler") - op = relay.add(lhs, rhs) - self.in_compiler = 2 - return op - elif call.op.name == "subtract": - if self.in_compiler == 1: - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - 
return relay.subtract(lhs, rhs) - elif call.op.name == "multiply": # Annotate end at output - self.in_compiler = 1 - lhs = super().visit(call.args[0]) - rhs = super().visit(call.args[1]) - if isinstance(lhs, relay.expr.Var): - lhs = compiler_begin(lhs, "ccompiler") - if isinstance(rhs, relay.expr.Var): - rhs = compiler_begin(rhs, "ccompiler") - op = relay.multiply(lhs, rhs) - if self.in_compiler == 2: - op = compiler_end(op, "ccompiler") - self.in_compiler = 0 - return op - return super().visit_call(call) - - class WholeGraphAnnotator(ExprMutator): """ An annotator that creates a compiler for an entire graph. @@ -261,7 +209,7 @@ def test_multi_node_compiler(): r = relay.concatenate((q0, q1, q2), axis=0) f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) mod = tvm.IRModule() - ann = CcompilerAnnotator() + ann = byoc.CcompilerAnnotator() mod["main"] = ann.visit(f) mod = transform.PartitionGraph()(mod) mod = transform.InferType()(mod) diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index b4de303680f5..3f0624e20d78 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -19,7 +19,9 @@ import copy import glob import os +import pathlib import pytest +import shutil pytest.importorskip("pty") import sys @@ -43,46 +45,36 @@ TARGET = tvm.target.target.micro("host") -def _make_sess_from_op(workspace, op_name, sched, arg_bufs): +def _make_sess_from_op(temp_dir, op_name, sched, arg_bufs): with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): mod = tvm.build(sched, arg_bufs, Target(TARGET, TARGET), name=op_name) - return _make_session(workspace, mod) + return _make_session(temp_dir, mod) -def _make_session(workspace, mod): - compiler = tvm.micro.DefaultCompiler(target=TARGET) - opts = tvm.micro.default_options( - os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") +def _make_session(temp_dir, mod): + template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") + project = tvm.micro.generate_project( + template_project_dir, mod, temp_dir / "project", {"verbose": 1} ) - micro_binary = tvm.micro.build_static_runtime( - workspace, - compiler, - mod, - opts, - extra_libs=[tvm.micro.get_standalone_crt_lib("memory")], - ) - - flasher_kw = { - "debug": DEBUG, - } - flasher = compiler.flasher(**flasher_kw) - return tvm.micro.Session(binary=micro_binary, flasher=flasher) + project.build() + project.flash() + return tvm.micro.Session(project.transport()) -def _make_add_sess(workspace): +def _make_add_sess(temp_dir): A = tvm.te.placeholder((2,), dtype="int8") B = tvm.te.placeholder((1,), dtype="int8") C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C") sched = tvm.te.create_schedule(C.op) - return _make_sess_from_op(workspace, "add", sched, [A, B, C]) + return _make_sess_from_op(temp_dir, "add", sched, [A, B, C]) -def _make_ident_sess(workspace): +def _make_ident_sess(temp_dir): A = tvm.te.placeholder((2,), dtype="int8") B = tvm.te.compute(A.shape, lambda i: A[i], name="B") sched = tvm.te.create_schedule(B.op) - return _make_sess_from_op(workspace, "ident", sched, [A, B]) + return _make_sess_from_op(temp_dir, "ident", sched, [A, B]) @pytest.mark.skip(reason="We don't currently use uTVM") @@ -91,9 +83,9 @@ def test_compile_runtime(): """Test compiling the on-device runtime.""" import tvm.micro - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() - with _make_add_sess(workspace) as sess: + with _make_add_sess(temp_dir) as sess: 
A_data = tvm.nd.array(np.array([2, 3], dtype="int8"), device=sess.device) assert (A_data.numpy() == np.array([2, 3])).all() B_data = tvm.nd.array(np.array([4], dtype="int8"), device=sess.device) @@ -131,9 +123,9 @@ def test_reset(): import tvm.micro from tvm.micro import transport - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() - with _make_add_sess(workspace) as sess: + with _make_add_sess(temp_dir) as sess: try: sess._rpc.get_function("tvm.testing.reset_server")() assert False, "expected to raise SessionTerminatedError; did not raise" @@ -145,9 +137,11 @@ def test_reset(): @tvm.testing.requires_micro def test_graph_executor(): """Test use of the graph executor with microTVM.""" - import tvm.micro - workspace = tvm.micro.Workspace(debug=True) + ws_root = pathlib.Path(os.path.dirname(__file__) + "/micro-workspace") + if ws_root.exists(): + shutil.rmtree(ws_root) + temp_dir = tvm.contrib.utils.tempdir(ws_root.resolve()) relay_mod = tvm.parser.fromtext( """ #[version = "0.0.5"] @@ -160,7 +154,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): factory = tvm.relay.build(relay_mod, target=TARGET) - with _make_session(workspace, factory.get_lib()) as sess: + with _make_session(temp_dir, factory) as sess: graph_mod = tvm.micro.create_local_graph_executor( factory.get_graph_json(), sess.get_system_lib(), sess.device ) @@ -181,9 +175,9 @@ def test_std_math_functions(): """Verify that standard math functions can be used.""" import tvm.micro - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() - with _make_add_sess(workspace) as sess: + with _make_add_sess(temp_dir) as sess: A_data = tvm.nd.array(np.array([2, 3], dtype="int8"), device=sess.device) assert (A_data.numpy() == np.array([2, 3])).all() B_data = tvm.nd.array(np.array([4], dtype="int8"), device=sess.device) @@ -194,12 +188,12 @@ def test_std_math_functions(): system_lib = sess.get_system_lib() system_lib.get_function("add")(A_data, B_data, C_data) - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() A = tvm.te.placeholder((2,), dtype="float32", name="A") B = tvm.te.compute(A.shape, lambda i: tvm.te.exp(A[i]), name="B") s = tvm.te.create_schedule(B.op) - with _make_sess_from_op(workspace, "myexpf", s, [A, B]) as sess: + with _make_sess_from_op(temp_dir, "myexpf", s, [A, B]) as sess: A_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), device=sess.device) B_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), device=sess.device) lib = sess.get_system_lib() @@ -214,12 +208,12 @@ def test_platform_timer(): """Verify the platform timer can be used to time remote functions.""" import tvm.micro - workspace = tvm.micro.Workspace() + temp_dir = tvm.contrib.utils.tempdir() A = tvm.te.placeholder((2,), dtype="float32", name="A") B = tvm.te.compute(A.shape, lambda i: tvm.te.exp(A[i]), name="B") s = tvm.te.create_schedule(B.op) - with _make_sess_from_op(workspace, "myexpf", s, [A, B]) as sess: + with _make_sess_from_op(temp_dir, "myexpf", s, [A, B]) as sess: A_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), device=sess.device) B_data = tvm.nd.array(np.array([2.0, 3.0], dtype="float32"), device=sess.device) lib = sess.get_system_lib() @@ -232,5 +226,4 @@ def test_platform_timer(): if __name__ == "__main__": - test_graph_executor() -# sys.exit(pytest.main([__file__] + sys.argv[1:])) + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git 
a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index e51566322e60..160cc11e64e3 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -349,7 +349,7 @@ def _run_unlinked(lib_mod): @pytest.mark.skip(reason="We don't currently use uTVM") @tvm.testing.requires_micro def test_crt_link_params(): - import tvm.micro + from tvm import micro for dtype in LINKABLE_DTYPES: mod, param_init = _make_mod_and_params(dtype) @@ -357,34 +357,21 @@ def test_crt_link_params(): main_func = mod["main"] target = "c --system-lib --runtime=c --link-params" with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - graph_json, lib, params = tvm.relay.build(mod, target, params=param_init) - assert set(params.keys()) == {"p0", "p1"} # NOTE: op folded + factory = tvm.relay.build(mod, target, params=param_init) + assert set(factory.get_params().keys()) == {"p0", "p1"} # NOTE: op folded - workspace = tvm.micro.Workspace() - compiler = tvm.micro.DefaultCompiler(target=target) - opts = tvm.micro.default_options( - os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") + temp_dir = tvm.contrib.utils.tempdir() + template_project_dir = os.path.join( + tvm.micro.get_standalone_crt_dir(), "template", "host" ) - opts["bin_opts"]["ldflags"].append("-DTVM_HOST_USE_GRAPH_EXECUTOR_MODULE") - - micro_binary = tvm.micro.build_static_runtime( - workspace, - compiler, - lib, - compiler_options=opts, - extra_libs=[ - tvm.micro.get_standalone_crt_lib(m) - for m in ("memory", "graph_executor_module", "graph_executor") - ], + project = tvm.micro.generate_project( + template_project_dir, factory, temp_dir / "project", {"verbose": 1} ) - - flasher_kw = { - "debug": False, - } - flasher = compiler.flasher(**flasher_kw) - with tvm.micro.Session(binary=micro_binary, flasher=flasher) as sess: + project.build() + project.flash() + with tvm.micro.Session(project.transport()) as sess: graph_rt = tvm.micro.session.create_local_graph_executor( - graph_json, sess.get_system_lib(), sess.device + factory.get_graph_json(), sess.get_system_lib(), sess.device ) # NOTE: not setting params here. diff --git a/tests/python/unittest/test_micro_artifact.py b/tests/python/unittest/test_micro_artifact.py deleted file mode 100644 index fc180200720d..000000000000 --- a/tests/python/unittest/test_micro_artifact.py +++ /dev/null @@ -1,149 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
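For context on the test_crt_link_params flow above: once the project is built and flashed, the session and graph executor pair up as in this sketch (`project` and `factory` are the objects from that test, assumed already constructed):

import tvm.micro

with tvm.micro.Session(project.transport()) as sess:
    graph_rt = tvm.micro.session.create_local_graph_executor(
        factory.get_graph_json(), sess.get_system_lib(), sess.device
    )
    # With --link-params, the parameters are compiled into the library itself,
    # so no set_input() of parameters is needed before run().
    graph_rt.run()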
- -"""Unit tests for the artifact module.""" - -import pytest -import json -import os -import shutil -import tvm - -from tvm.contrib import utils - -pytest.importorskip("tvm.micro") -from tvm.micro import artifact - -FILE_LIST = ["label1", "label2", "label12", "unlabelled"] - - -TEST_METADATA = {"foo": "bar"} - - -TEST_LABELS = {"label1": ["label1", "label12"], "label2": ["label2", "label12"]} - - -def build_artifact(artifact_path, immobile=False): - os.mkdir(artifact_path) - - for f in FILE_LIST: - with open(os.path.join(artifact_path, f), "w") as lib_f: - lib_f.write(f"{f}\n") - - sub_dir = os.path.join(artifact_path, "sub_dir") - os.mkdir(sub_dir) - os.symlink("label1", os.path.join(artifact_path, "rel_symlink")) - os.symlink("label2", os.path.join(artifact_path, "abs_symlink"), "label2") - os.symlink( - os.path.join(artifact_path, "sub_dir"), os.path.join(artifact_path, "abs_dir_symlink") - ) - - from tvm.micro import artifact - - art = artifact.Artifact(artifact_path, TEST_LABELS, TEST_METADATA, immobile=immobile) - - return art - - -@tvm.testing.requires_micro -def test_basic_functionality(): - temp_dir = utils.tempdir() - artifact_path = temp_dir.relpath("foo") - art = build_artifact(artifact_path) - - assert art.abspath("bar") == os.path.join(artifact_path, "bar") - - for label, paths in TEST_LABELS.items(): - assert art.label(label) == paths - assert art.label_abspath(label) == [os.path.join(artifact_path, p) for p in paths] - - -@tvm.testing.requires_micro -def test_archive(): - from tvm.micro import artifact - - temp_dir = utils.tempdir() - art = build_artifact(temp_dir.relpath("foo")) - - # Create archive - archive_path = art.archive(temp_dir.temp_dir) - assert archive_path == temp_dir.relpath("foo.tar") - - # Inspect created archive - unpack_dir = temp_dir.relpath("unpack") - os.mkdir(unpack_dir) - shutil.unpack_archive(archive_path, unpack_dir) - - for path in FILE_LIST: - with open(os.path.join(unpack_dir, "foo", path)) as f: - assert f.read() == f"{path}\n" - - with open(os.path.join(unpack_dir, "foo", "metadata.json")) as metadata_f: - metadata = json.load(metadata_f) - - assert metadata["version"] == 2 - assert metadata["labelled_files"] == TEST_LABELS - assert metadata["metadata"] == TEST_METADATA - - # Unarchive and verify basic functionality - unarchive_base_dir = temp_dir.relpath("unarchive") - unarch = artifact.Artifact.unarchive(archive_path, unarchive_base_dir) - - assert unarch.metadata == TEST_METADATA - assert unarch.labelled_files == TEST_LABELS - for f in FILE_LIST: - assert os.path.exists(os.path.join(unarchive_base_dir, f)) - - -@tvm.testing.requires_micro -def test_metadata_only(): - from tvm.micro import artifact - - temp_dir = utils.tempdir() - base_dir = temp_dir.relpath("foo") - art = build_artifact(base_dir) - - artifact_path = art.archive(temp_dir.relpath("foo.artifact"), metadata_only=True) - unarch_base_dir = temp_dir.relpath("bar") - unarch = artifact.Artifact.unarchive(artifact_path, unarch_base_dir) - assert unarch.base_dir == base_dir - - for p in unarch.label_abspath("label1") + unarch.label_abspath("label2"): - assert os.path.exists(p) - - os.unlink(art.abspath("label1")) - with open(art.abspath("label2"), "w+") as f: - f.write("changed line\n") - - try: - artifact.Artifact.unarchive(artifact_path, os.path.join(temp_dir.temp_dir, "bar2")) - assert False, "unarchive should raise error" - except artifact.ArchiveModifiedError as err: - assert str(err) == ( - "Files in metadata-only archive have been modified:\n" - " * label1: original file not 
found\n" - " * label2: sha256 mismatch: expected " - "6aa3c5668c8794c791400e19ecd7123949ded1616eafb0395acdd2d896354e83, got " - "ed87db21670a81819d65eccde87c5ae0243b2b61783bf77e9b27993be9a3eca0" - ) - - -if __name__ == "__main__": - test_basic_functionality() - test_archive() - test_metadata_only() - # TODO: tests for dir symlinks, symlinks out of bounds, loading malformed artifact tars. diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 5a32385632fc..92c1174e728c 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -27,6 +27,7 @@ import tvm import tvm.relay from tvm.relay.backend import executor_factory +from tvm.relay.testing import byoc import tvm.runtime.module import tvm.testing from tvm.contrib import utils @@ -345,5 +346,69 @@ def test_export_non_dso_exportable(): ) +@tvm.testing.requires_micro +def test_export_byoc_c_module(): + """Test BYOC flow when it produces DSO-exportable modules. + + NOTE the general BYOC flow is not fully supported by Model Library Format right now. + """ + x = tvm.relay.var("x", shape=(10, 10)) + w0 = tvm.relay.var("w0", shape=(10, 10)) + w1 = tvm.relay.var("w1", shape=(10, 10)) + w2 = tvm.relay.var("w2", shape=(10, 10)) + w3 = tvm.relay.var("w3", shape=(10, 10)) + w4 = tvm.relay.var("w4", shape=(10, 10)) + w5 = tvm.relay.var("w5", shape=(10, 10)) + w6 = tvm.relay.var("w6", shape=(10, 10)) + w7 = tvm.relay.var("w7", shape=(10, 10)) + + # C compiler + z0 = tvm.relay.add(x, w0) + p0 = tvm.relay.subtract(z0, w1) + q0 = tvm.relay.multiply(p0, w2) + + z1 = tvm.relay.add(x, w3) + p1 = tvm.relay.subtract(z1, w4) + q1 = tvm.relay.multiply(p1, w5) + + # Other parts on TVM + z2 = tvm.relay.add(x, w6) + q2 = tvm.relay.subtract(z2, w7) + + r = tvm.relay.concatenate((q0, q1, q2), axis=0) + f = tvm.relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) + mod = tvm.IRModule() + ann = byoc.CcompilerAnnotator() + mod["main"] = ann.visit(f) + mod = tvm.relay.transform.PartitionGraph("mod_name")(mod) + mod = tvm.relay.transform.InferType()(mod) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory = tvm.relay.build(mod, tvm.target.target.micro("host")) + + temp_dir = utils.tempdir() + mlf_tar_path = temp_dir.relpath("lib.tar") + + from tvm import micro + + micro.export_model_library_format(factory, mlf_tar_path) + + with tarfile.open(mlf_tar_path, "r:*") as tf: + tar_members = [ti.name for ti in tf.getmembers()] + print("tar members", tar_members) + assert "./metadata.json" in tar_members + with tf.extractfile("./metadata.json") as f: + metadata = json.load(f) + main_md = metadata["memory"]["functions"]["main"] + assert main_md == [ + { + "constants_size_bytes": 0, + "device": 1, + "io_size_bytes": 4800, + "workspace_size_bytes": 800, + } + ] + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_micro_project_api.py b/tests/python/unittest/test_micro_project_api.py new file mode 100644 index 000000000000..b5e2a57c122c --- /dev/null +++ b/tests/python/unittest/test_micro_project_api.py @@ -0,0 +1,424 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import collections
+import io
+import json
+import sys
+import unittest
+from unittest import mock
+
+import pytest
+
+import tvm
+from tvm.micro import project_api
+
+
+class BaseTestHandler(project_api.server.ProjectAPIHandler):
+
+    DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo(
+        platform_name="platform_name",
+        is_template=True,
+        model_library_format_path="./model-library-format-path.sh",
+        project_options=[
+            project_api.server.ProjectOption(name="foo", help="Option foo"),
+            project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"),
+        ],
+    )
+
+    def server_info_query(self, tvm_version):
+        return self.DEFAULT_TEST_SERVER_INFO
+
+    def generate_project(self, model_library_format_path, crt_path, project_path, options):
+        assert False, "generate_project is not implemented for this test"
+
+    def build(self, options):
+        assert False, "build is not implemented for this test"
+
+    def flash(self, options):
+        assert False, "flash is not implemented for this test"
+
+    def open_transport(self, options):
+        assert False, "open_transport is not implemented for this test"
+
+    def close_transport(self, options):
+        assert False, "close_transport is not implemented for this test"
+
+    def read_transport(self, n, timeout_sec):
+        assert False, "read_transport is not implemented for this test"
+
+    def write_transport(self, data, timeout_sec):
+        assert False, "write_transport is not implemented for this test"
+
+
+class Transport:
+    def readable(self):
+        return True
+
+    def writable(self):
+        return True
+
+    def seekable(self):
+        return False
+
+    closed = False
+
+    def __init__(self):
+        self.data = bytearray()
+        self.rpos = 0
+
+        self.items = []
+
+    def read(self, size=-1):
+        to_read = len(self.data) - self.rpos
+        if size != -1:
+            to_read = min(size, to_read)
+
+        rpos = self.rpos
+        self.rpos += to_read
+        return self.data[rpos : self.rpos]
+
+    def write(self, data):
+        self.data.extend(data)
+
+
+class ClientServerFixture:
+    def __init__(self, handler):
+        self.handler = handler
+        self.client_to_server = Transport()
+        self.server_to_client = Transport()
+
+        self.server = project_api.server.ProjectAPIServer(
+            self.client_to_server, self.server_to_client, handler
+        )
+        self.client = project_api.client.ProjectAPIClient(
+            self.server_to_client,
+            self.client_to_server,
+            testonly_did_write_request=self._process_server_request,
+        )
+
+        self.expect_failure = False
+
+    def _process_server_request(self):
+        assert self.server.serve_one_request() == (
+            not self.expect_failure
+        ), "Server failed to process request"
+
+
+def test_server_info_query():
+    fixture = ClientServerFixture(BaseTestHandler())
+
+    # Examine reply explicitly because these are the defaults for all derivative test cases.
+    reply = fixture.client.server_info_query(tvm.__version__)
+    assert reply["protocol_version"] == 1
+    assert reply["platform_name"] == "platform_name"
+    assert reply["is_template"] == True
+    assert reply["model_library_format_path"] == "./model-library-format-path.sh"
+    assert reply["project_options"] == [
+        {"name": "foo", "choices": None, "help": "Option foo"},
+        {"name": "bar", "choices": ["qux"], "help": "Option bar"},
+    ]
+
+
+def test_server_info_query_wrong_tvm_version():
+    def server_info_query(tvm_version):
+        raise project_api.server.UnsupportedTVMVersionError()
+
+    with mock.patch.object(BaseTestHandler, "server_info_query", side_effect=server_info_query):
+        fixture = ClientServerFixture(BaseTestHandler())
+        with pytest.raises(project_api.server.UnsupportedTVMVersionError) as exc_info:
+            fixture.client.server_info_query(tvm.__version__)
+
+        assert "UnsupportedTVMVersionError" in str(exc_info.value)
+
+
+def test_server_info_query_wrong_protocol_version():
+    ServerInfoProtocol = collections.namedtuple(
+        "ServerInfoProtocol", list(project_api.server.ServerInfo._fields) + ["protocol_version"]
+    )
+
+    def server_info_query(tvm_version):
+        return ServerInfoProtocol(
+            protocol_version=0, **BaseTestHandler.DEFAULT_TEST_SERVER_INFO._asdict()
+        )
+
+    with mock.patch.object(BaseTestHandler, "server_info_query", side_effect=server_info_query):
+        fixture = ClientServerFixture(BaseTestHandler())
+        with pytest.raises(project_api.client.UnsupportedProtocolVersionError) as exc_info:
+            fixture.client.server_info_query(tvm.__version__)
+
+        assert "microTVM API Server supports protocol version 0; want 1" in str(exc_info.value)
+
+
+def test_base_test_handler():
+    """Every BaseTestHandler method should surface its AssertionError as a ServerError."""
+    fixture = ClientServerFixture(BaseTestHandler())
+
+    # BaseTestHandler implements each method as `assert False, "<method> is not implemented
+    # for this test"`; the server converts that AssertionError into an error reply, which the
+    # client raises as ServerError. Call each method with the same arguments the other tests
+    # in this file use.
+    calls = [
+        lambda: fixture.client.build(options={}),
+        lambda: fixture.client.flash(options={}),
+        lambda: fixture.client.open_transport(options={}),
+        lambda: fixture.client.close_transport(),
+        lambda: fixture.client.read_transport(128, timeout_sec=5.0),
+        lambda: fixture.client.write_transport(b"foo", timeout_sec=5.0),
+    ]
+    for call in calls:
+        with pytest.raises(project_api.server.ServerError) as exc_info:
+            call()
+
+        assert "AssertionError" in str(exc_info.value)
+
+
+def test_build():
+    with mock.patch.object(BaseTestHandler, "build", return_value=None) as patch:
+        fixture = ClientServerFixture(BaseTestHandler())
+        fixture.client.build(options={"bar": "baz"})
+
+        fixture.handler.build.assert_called_once_with(options={"bar": "baz"})
+
+
+def test_flash():
+    with mock.patch.object(BaseTestHandler, "flash", return_value=None) as patch:
+        fixture = ClientServerFixture(BaseTestHandler())
+        fixture.client.flash(options={"bar": "baz"})
+        fixture.handler.flash.assert_called_once_with(options={"bar": "baz"})
+
+
+def test_open_transport():
+    timeouts = project_api.server.TransportTimeouts(
+        session_start_retry_timeout_sec=1.0,
+        session_start_timeout_sec=2.0,
+        session_established_timeout_sec=3.0,
+    )
+
+    with mock.patch.object(BaseTestHandler, "open_transport", return_value=timeouts) as patch:
+        fixture = ClientServerFixture(BaseTestHandler())
+        assert fixture.client.open_transport(options={"bar": "baz"}) == {
+            "timeouts": dict(timeouts._asdict())
+        }
+        fixture.handler.open_transport.assert_called_once_with({"bar": "baz"})
+
+
+def test_close_transport():
+    with mock.patch.object(BaseTestHandler, "close_transport", return_value=None) as patch:
+        fixture = ClientServerFixture(BaseTestHandler())
+        fixture.client.close_transport()
+        fixture.handler.close_transport.assert_called_once_with()
+
+
+def test_read_transport():
+    with mock.patch.object(BaseTestHandler, "read_transport",
return_value=b"foo\x1b") as patch: + fixture = ClientServerFixture(BaseTestHandler()) + assert fixture.client.read_transport(128, timeout_sec=5.0) == {"data": b"foo\x1b"} + + fixture.handler.read_transport.assert_called_with(128, 5.0) + + fixture.handler.read_transport.side_effect = project_api.server.IoTimeoutError + with pytest.raises(project_api.server.IoTimeoutError) as exc_info: + fixture.client.read_transport(256, timeout_sec=10.0) + + fixture.handler.read_transport.assert_called_with(256, 10.0) + + fixture.handler.read_transport.side_effect = project_api.server.TransportClosedError + with pytest.raises(project_api.server.TransportClosedError) as exc_info: + fixture.client.read_transport(512, timeout_sec=15.0) + + fixture.handler.read_transport.assert_called_with(512, 15.0) + + assert fixture.handler.read_transport.call_count == 3 + + +def test_write_transport(): + with mock.patch.object(BaseTestHandler, "write_transport", return_value=None) as patch: + fixture = ClientServerFixture(BaseTestHandler()) + assert fixture.client.write_transport(b"foo", timeout_sec=5.0) is None + fixture.handler.write_transport.assert_called_with(b"foo", 5.0) + + fixture.handler.write_transport.side_effect = project_api.server.IoTimeoutError + with pytest.raises(project_api.server.IoTimeoutError) as exc_info: + fixture.client.write_transport(b"bar", timeout_sec=10.0) + + fixture.handler.write_transport.assert_called_with(b"bar", 10.0) + + fixture.handler.write_transport.side_effect = project_api.server.TransportClosedError + with pytest.raises(project_api.server.TransportClosedError) as exc_info: + fixture.client.write_transport(b"baz", timeout_sec=15.0) + + fixture.handler.write_transport.assert_called_with(b"baz", 15.0) + + assert fixture.handler.write_transport.call_count == 3 + + +class ProjectAPITestError(Exception): + """An error raised in test.""" + + +def test_method_raises_error(): + with mock.patch.object( + BaseTestHandler, "close_transport", side_effect=ProjectAPITestError + ) as patch: + fixture = ClientServerFixture(BaseTestHandler()) + with pytest.raises(project_api.server.ServerError) as exc_info: + fixture.client.close_transport() + + fixture.handler.close_transport.assert_called_once_with() + assert "ProjectAPITestError" in str(exc_info.value) + + +def test_method_not_found(): + fixture = ClientServerFixture(BaseTestHandler()) + + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("invalid_method", {"bar": None}) + + assert exc_info.value.code == project_api.server.ErrorCode.METHOD_NOT_FOUND + + +def test_extra_param(): + fixture = ClientServerFixture(BaseTestHandler()) + + # test one with has_preprocssing and one without + assert hasattr(fixture.server, "_dispatch_build") == False + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("build", {"invalid_param_name": None, "options": {}}) + + assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS + assert "build: extra parameters: invalid_param_name" in str(exc_info.value) + + assert hasattr(fixture.server, "_dispatch_open_transport") == True + with pytest.raises(project_api.server.JSONRPCError) as exc_info: + fixture.client._request_reply("open_transport", {"invalid_param_name": None, "options": {}}) + + assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS + assert "open_transport: extra parameters: invalid_param_name" in str(exc_info.value) + + +def test_missing_param(): + fixture = 
ClientServerFixture(BaseTestHandler())
+
+    # test one with has_preprocessing and one without
+    assert hasattr(fixture.server, "_dispatch_build") == False
+    with pytest.raises(project_api.server.JSONRPCError) as exc_info:
+        fixture.client._request_reply("build", {})
+
+    assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS
+    assert "build: parameter options not given" in str(exc_info.value)
+
+    assert hasattr(fixture.server, "_dispatch_open_transport") == True
+    with pytest.raises(project_api.server.JSONRPCError) as exc_info:
+        fixture.client._request_reply("open_transport", {})
+
+    assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS
+    assert "open_transport: parameter options not given" in str(exc_info.value)
+
+
+def test_incorrect_param_type():
+    fixture = ClientServerFixture(BaseTestHandler())
+
+    # The error message given at the JSON-RPC server level doesn't make sense when preprocessing is
+    # used. Only test without preprocessing here.
+    assert hasattr(fixture.server, "_dispatch_build") == False
+    with pytest.raises(project_api.server.JSONRPCError) as exc_info:
+        fixture.client._request_reply("build", {"options": None})
+
+    assert exc_info.value.code == project_api.server.ErrorCode.INVALID_PARAMS
+    assert "build: parameter options: want <class 'dict'>, got <class 'NoneType'>" in str(
+        exc_info.value
+    )
+
+
+def test_invalid_request():
+    fixture = ClientServerFixture(BaseTestHandler())
+
+    # Invalid JSON does not get a reply.
+    fixture.client_to_server.write(b"foobar\n")
+    assert fixture.server.serve_one_request() == False
+    assert fixture.server_to_client.read() == b""
+
+    # EOF causes a clean return
+    assert fixture.server.serve_one_request() == False
+    assert fixture.server_to_client.read() == b""
+
+    def _request_reply(request):
+        fixture.client_to_server.write(request + b"\n")
+        assert fixture.server.serve_one_request() == False
+        return json.loads(fixture.server_to_client.read())
+
+    # Parseable JSON with the wrong schema gets a reply.
+    assert _request_reply(b"1") == {
+        "error": {
+            "code": project_api.server.ErrorCode.INVALID_REQUEST,
+            "data": None,
+            "message": "request: want dict; got 1",
+        },
+        "id": None,
+        "jsonrpc": "2.0",
+    }
+
+    # Incorrect JSON-RPC spec version.
+ assert _request_reply(b'{"jsonrpc": 1.0}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": 'request["jsonrpc"]: want "2.0"; got 1.0', + }, + "id": None, + "jsonrpc": "2.0", + } + + # Method not a str + assert _request_reply(b'{"jsonrpc": "2.0", "method": 123}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": 'request["method"]: want str; got 123', + }, + "id": None, + "jsonrpc": "2.0", + } + + # Method name has invalid characters + assert _request_reply(b'{"jsonrpc": "2.0", "method": "bar!"}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": "request[\"method\"]: should match regex ^[a-zA-Z0-9_]+$; got 'bar!'", + }, + "id": None, + "jsonrpc": "2.0", + } + + # params not a dict + assert _request_reply(b'{"jsonrpc": "2.0", "method": "bar", "params": 123}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": "request[\"params\"]: want dict; got ", + }, + "id": None, + "jsonrpc": "2.0", + } + + # id not valid + assert _request_reply(b'{"jsonrpc": "2.0", "method": "bar", "params": {}, "id": {}}') == { + "error": { + "code": project_api.server.ErrorCode.INVALID_REQUEST, + "data": None, + "message": 'request["id"]: want str, number, null; got {}', + }, + "id": None, + "jsonrpc": "2.0", + } + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_micro_transport.py b/tests/python/unittest/test_micro_transport.py index b0f99681af2e..a188e612763f 100644 --- a/tests/python/unittest/test_micro_transport.py +++ b/tests/python/unittest/test_micro_transport.py @@ -132,25 +132,25 @@ def test_transport_logger(self): transport.to_return = 3 transport_logger.write(b"data", 3.0) assert test_log.records[-1].getMessage() == ( - "foo: write { 3.00s} <- [ 3 B]: 64 61 74 " - " dat" + "foo: write { 3.00s} <- [ 4 B]: 64 61 74 61" + " data" ) # Normal log, multi-line data written. transport.to_return = 20 transport_logger.write(b"data" * 6, 3.0) assert test_log.records[-1].getMessage() == ( - "foo: write { 3.00s} <- [ 20 B]:\n" + "foo: write { 3.00s} <- [ 24 B]:\n" "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" - "0010 64 61 74 61 data" + "0010 64 61 74 61 64 61 74 61 datadata" ) # Lack of timeout prints. transport.to_return = 3 transport_logger.write(b"data", None) assert test_log.records[-1].getMessage() == ( - "foo: write { None } <- [ 3 B]: 64 61 74 " - " dat" + "foo: write { None } <- [ 4 B]: 64 61 74 61" + " data" ) # IoTimeoutError includes the timeout value. diff --git a/tutorials/micro/micro_tflite.py b/tutorials/micro/micro_tflite.py index 5e517bf062ef..5a39be08e108 100644 --- a/tutorials/micro/micro_tflite.py +++ b/tutorials/micro/micro_tflite.py @@ -208,52 +208,92 @@ with tvm.transform.PassContext( opt_level=3, config={"tir.disable_vectorize": True}, disabled_pass=["FuseOps", "AlterOpLayout"] ): - graph, c_mod, c_params = relay.build(mod, target=TARGET, params=params) + module = relay.build(mod, target=TARGET, params=params) -# Compiling for a host simulated device -# ------------------------------------- +# Inspecting the compilation output +# --------------------------------- # -# First, compile a static microTVM runtime for the targeted device. In this case, the host simulated -# device is used. 
-compiler = tvm.micro.DefaultCompiler(target=TARGET)
-opts = tvm.micro.default_options(
-    os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host")
+# The compilation process has produced some C code implementing the operators in this graph. We
+# can inspect it by printing the CSourceModule contents (for the purposes of this tutorial, let's
+# just print the first 10 lines):
+
+c_source_module = module.get_lib().imported_modules[0]
+assert c_source_module.type_key == "c", "tutorial is broken"
+
+c_source_code = c_source_module.get_source()
+first_few_lines = c_source_code.split("\n")[:10]
+assert any(
+    l.startswith("TVM_DLL int32_t tvmgen_default_") for l in first_few_lines
+), f"tutorial is broken: {first_few_lines!r}"
+print("\n".join(first_few_lines))
+
+
+# Compiling the generated code
+# ----------------------------
+#
+# Now we need to incorporate the generated C code into a project that allows us to run inference on the
+# device. The simplest way to do this is to integrate it yourself, using microTVM's standard output format
+# (:doc:`Model Library Format`). This is a tarball with a standard layout:
+
+# Get a temporary path where we can store the tarball (since this is running as a tutorial).
+import tempfile
+
+fd, model_library_format_tar_path = tempfile.mkstemp()
+os.close(fd)
+os.unlink(model_library_format_tar_path)
+tvm.micro.export_model_library_format(module, model_library_format_tar_path)
+
+import tarfile
+
+with tarfile.open(model_library_format_tar_path, "r:*") as tar_f:
+    print("\n".join(f" - {m.name}" for m in tar_f.getmembers()))
+
+# Cleanup for tutorial:
+os.unlink(model_library_format_tar_path)
+
+
+# TVM also provides a standard way for embedded platforms to automatically generate a standalone
+# project, compile and flash it to a target, and communicate with it using the standard TVM RPC
+# protocol. The Model Library Format serves as the model input to this process. When embedded
+# platforms provide such an integration, they can be used directly by TVM for both host-driven
+# inference and autotuning. This integration is provided by the
+# `microTVM Project API`_.
+#
+# Embedded platforms need to provide a Template Project containing a microTVM API Server (typically,
+# this lives in a file ``microtvm_api_server.py`` in the root directory). Let's use the example ``host``
+# project in this tutorial, which simulates the device using a POSIX subprocess and pipes:
+
+import subprocess
+import pathlib
+
+repo_root = pathlib.Path(
+    subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip()
 )
+template_project_path = repo_root / "src" / "runtime" / "crt" / "host"
+project_options = {}  # You can use options to provide platform-specific options through TVM.

 # Compiling for physical hardware (or an emulated board, like the mps2_an521)
 # ----------------------------------------------------------------------------
-# For physical hardware, comment out the previous section selecting TARGET and BOARD and use this
-# compiler definition instead of the one above.
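The template project described above centers on a `microtvm_api_server.py` implementing `ProjectAPIHandler`. A skeleton sketch follows; the platform name, option, and timeout values are illustrative assumptions, the method signatures mirror the handler exercised in test_micro_project_api.py, and the `server.main()` entry point is assumed to match the in-tree servers:

from tvm.micro.project_api import server


class Handler(server.ProjectAPIHandler):
    def server_info_query(self, tvm_version):
        return server.ServerInfo(
            platform_name="my_platform",
            is_template=True,
            model_library_format_path="model.tar",
            project_options=[server.ProjectOption(name="verbose", help="Print build output.")],
        )

    def generate_project(self, model_library_format_path, crt_path, project_path, options):
        pass  # copy template sources and Model Library Format contents into project_path

    def build(self, options):
        pass  # invoke the platform build system

    def flash(self, options):
        pass  # program the attached board (a no-op for simulated targets)

    def open_transport(self, options):
        return server.TransportTimeouts(
            session_start_retry_timeout_sec=2.0,
            session_start_timeout_sec=5.0,
            session_established_timeout_sec=30.0,
        )

    def close_transport(self, options):
        pass

    def read_transport(self, n, timeout_sec):
        raise server.IoTimeoutError()

    def write_transport(self, data, timeout_sec):
        raise server.IoTimeoutError()


if __name__ == "__main__":
    server.main(Handler())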
-#
-# import subprocess
-# from tvm.micro.contrib import zephyr
-#
-# repo_root = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding='utf-8').strip()
-# project_dir = os.path.join(repo_root, "apps", "microtvm", "zephyr", "host_driven")
-# compiler = zephyr.ZephyrCompiler(
-#     project_dir=project_dir,
-#     board=BOARD,
-#     zephyr_toolchain_variant="zephyr",
-# )
-#
-# opts = tvm.micro.default_options(f"{project_dir}/crt")
-#
-#
-# # Enable printing memory usage statistics for the runtime image generated by Zephyr
-# logging.basicConfig(level="INFO")
-
-workspace = tvm.micro.Workspace()
-micro_binary = tvm.micro.build_static_runtime(
-    workspace,
-    compiler,
-    c_mod,
-    opts,
-    # Use the microTVM memory manager. If, in your main.cc, you change TVMPlatformMemoryAllocate and
-    # TVMPlatformMemoryFree to use e.g. malloc() and free(), you can omit this extra library.
-    extra_libs=[tvm.micro.get_standalone_crt_lib("memory")],
+# For physical hardware, you can try out the Zephyr platform by using a different template project
+# and options:
+#
+# template_project_path = repo_root / "apps" / "microtvm" / "zephyr" / "template_project"
+# project_options = {"project_type": "host_driven", "zephyr_board": "nucleo_f746zg"}
+
+# Create a temporary directory
+import tvm.contrib.utils
+
+temp_dir = tvm.contrib.utils.tempdir()
+generated_project_dir = temp_dir / "generated-project"
+generated_project = tvm.micro.generate_project(
+    template_project_path, module, generated_project_dir, project_options
 )
+
+# Build and flash the project
+generated_project.build()
+generated_project.flash()
+

 ######################################################################
 # Next, establish a session with the simulated device and run the
@@ -261,14 +301,13 @@ # microcontroller, but in this tutorial, it simply launches a subprocess
 # to stand in for an attached microcontroller.

-flasher = compiler.flasher()
-with tvm.micro.Session(binary=micro_binary, flasher=flasher) as session:
+with tvm.micro.Session(transport_context_manager=generated_project.transport()) as session:
     graph_mod = tvm.micro.create_local_graph_executor(
-        graph, session.get_system_lib(), session.device
+        module.get_graph_json(), session.get_system_lib(), session.device
     )

     # Set the model parameters using the lowered parameters produced by `relay.build`.
-    graph_mod.set_input(**c_params)
+    graph_mod.set_input(**module.get_params())

     # The model consumes a single float32 value and returns a predicted sine value. To pass the
     # input value we construct a tvm.nd.array object with a single contrived number as input. For