Skip to content

Commit

Permalink
Additional PR comments
Browse files Browse the repository at this point in the history
PR comments and Python 3.6 support

Linting fix

Re-add test onnx file

Test Arduino cli bug workaround

Support new hardware targets

Temporary fix for tests

Formatting issue

Spelling fix

Add test case for exact FQBN matching
  • Loading branch information
guberti committed Aug 16, 2021
1 parent 3248709 commit 7b14e83
Show file tree
Hide file tree
Showing 21 changed files with 123 additions and 95 deletions.
2 changes: 1 addition & 1 deletion apps/microtvm/arduino/example_project/src/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ tvm_workspace_t app_workspace;

// Blink code for debugging purposes
void TVMPlatformAbort(tvm_crt_error_t error) {
TVMLogf("TVMPlatformAbort: %08x\n", error);
TVMLogf("TVMPlatformAbort: 0x%08x\n", error);
for (;;) {
#ifdef LED_BUILTIN
digitalWrite(LED_BUILTIN, HIGH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* \file tvm/runtime/crt/host/crt_config.h
* \brief CRT configuration for the host-linked CRT.
*/
#ifndef TVM_RUNTIME_MICRO_CRT_CONFIG_H_
Expand Down
2 changes: 1 addition & 1 deletion apps/microtvm/arduino/host_driven/project.ino
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void loop() {
int to_read = min(Serial.available(), 128);

uint8_t data[to_read];
size_t bytes_remaining = Serial.readBytes(data, to_read);
size_t bytes_remaining = Serial.readBytes((char*) data, to_read);
uint8_t* arr_ptr = data;
while (bytes_remaining > 0) {
// Pass the received bytes to the RPC server.
Expand Down
7 changes: 4 additions & 3 deletions apps/microtvm/arduino/host_driven/src/model_support.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
* under the License.
*/

#include "stdarg.h"
#include "standalone_crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h"
#include "stdarg.h"

// Blink code for debugging purposes
void TVMPlatformAbort(tvm_crt_error_t error) {
TVMLogf("TVMPlatformAbort: %08x\n", error);
for (;;);
TVMLogf("TVMPlatformAbort: 0x%08x\n", error);
for (;;)
;
}

size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* \file tvm/runtime/crt/host/crt_config.h
* \brief CRT configuration for the host-linked CRT.
*/
#ifndef TVM_RUNTIME_MICRO_CRT_CONFIG_H_
Expand Down
115 changes: 65 additions & 50 deletions apps/microtvm/arduino/template_project/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class BoardAutodetectFailed(Exception):
"due": {
"package": "arduino",
"architecture": "sam",
"board": "arduino_due_x",
"board": "arduino_due_x_dbg",
},
# Due to the way the Feather S2 bootloader works, compilation
# behaves fine but uploads cannot be done automatically
Expand Down Expand Up @@ -94,6 +94,11 @@ class BoardAutodetectFailed(Exception):
"architecture": "avr",
"board": "teensy41",
},
"wioterminal": {
"package": "Seeeduino",
"architecture": "samd",
"board": "seeed_wio_terminal",
},
}

PROJECT_TYPES = ["example_project", "host_driven"]
Expand Down Expand Up @@ -133,10 +138,23 @@ def server_info_query(self, tvm_version):
)

def _copy_project_files(self, api_server_dir, project_dir, project_type):
"""Copies the files for project_type into project_dir.
Notes
-----
template_dir is NOT a project type, and that directory is never copied
in this function. template_dir only holds this file and its unit tests,
so this file is copied separately in generate_project.
"""
project_types_folder = api_server_dir.parents[0]
shutil.copytree(
project_types_folder / project_type / "src", project_dir / "src", dirs_exist_ok=True
)
for item in (project_types_folder / project_type / "src").iterdir():
dest = project_dir / "src" / item.name
if item.is_dir():
shutil.copytree(item, dest)
else:
shutil.copy2(item, dest)

# Arduino requires the .ino file have the same filename as its containing folder
shutil.copy2(
project_types_folder / project_type / "project.ino",
Expand All @@ -146,7 +164,6 @@ def _copy_project_files(self, api_server_dir, project_dir, project_type):
CRT_COPY_ITEMS = ("include", "src")

def _copy_standalone_crt(self, source_dir, standalone_crt_dir):
# Copy over the standalone_crt directory
output_crt_dir = source_dir / "standalone_crt"
for item in self.CRT_COPY_ITEMS:
src_path = os.path.join(standalone_crt_dir, item)
Expand Down Expand Up @@ -200,9 +217,9 @@ def _template_model_header(self, source_dir, metadata):
with open(source_dir / "model.h", "r") as f:
model_h_template = Template(f.read())

# The structure of the "memory" key depends on the style -
# only style="full-model" works with AOT, so we'll check that
assert metadata["style"] == "full-model"
assert (
metadata["style"] == "full-model"
), "when generating AOT, expect only full-model Model Library Format"

template_values = {
"workspace_size_bytes": metadata["memory"]["functions"]["main"][0][
Expand All @@ -225,20 +242,20 @@ def _change_cpp_file_extensions(self, source_dir):
for filename in source_dir.rglob(f"*.inc"):
filename.rename(filename.with_suffix(".h"))

"""Arduino only supports includes relative to the top-level project, so this
finds each time we #include a file and changes the path to be relative to the
top-level project.ino file. For example, the line:
#include <tvm/runtime/crt/platform.h>
Might be changed to (depending on the source file's location):
def _convert_includes(self, project_dir, source_dir):
"""Changes all #include statements in project_dir to be relevant to their
containing file's location.
#include "../../../../include/tvm/runtime/crt/platform.h"
Arduino only supports includes relative to a file's location, so this
function finds each time we #include a file and changes the path to
be relative to the file location. Does not do this for standard C
libraries. Also changes angle brackets syntax to double quotes syntax.
We also need to leave standard library includes as-is.
"""
See Also
-----
https://www.arduino.cc/reference/en/language/structure/further-syntax/include/
def _convert_includes(self, project_dir, source_dir):
"""
for ext in ("c", "h", "cpp"):
for filename in source_dir.rglob(f"*.{ext}"):
with filename.open() as file:
Expand All @@ -263,27 +280,19 @@ def _convert_includes(self, project_dir, source_dir):
# be added in the future.
POSSIBLE_BASE_PATHS = ["src/standalone_crt/include/", "src/standalone_crt/crt_config/"]

"""Takes a single #include path, and returns the new location
it should point to (as described above). For example, one of the
includes for "src/standalone_crt/src/runtime/crt/common/ndarray.c" is:
#include <tvm/runtime/crt/platform.h>
For that line, _convert_includes might call _find_modified_include_path
with the arguments:
project_dir = "/path/to/project/dir"
file_path = "/path/to/project/dir/src/standalone_crt/src/runtime/crt/common/ndarray.c"
include_path = "tvm/runtime/crt/platform.h"
Given these arguments, _find_modified_include_path should return:
"../../../../../../src/standalone_crt/include/tvm/runtime/crt/platform.h"
See unit test in ./tests/test_arduino_microtvm_api_server.py
"""

def _find_modified_include_path(self, project_dir, file_path, include_path):
"""Takes a single #include path, and returns the location it should point to.
Examples
--------
>>> _find_modified_include_path(
... "/path/to/project/dir"
... "/path/to/project/dir/src/standalone_crt/src/runtime/crt/common/ndarray.c"
... "tvm/runtime/crt/platform.h"
... )
"../../../../../../src/standalone_crt/include/tvm/runtime/crt/platform.h"
"""
if include_path.endswith(".inc"):
include_path = re.sub(r"\.[a-z]+$", ".h", include_path)

Expand Down Expand Up @@ -314,8 +323,7 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec
source_dir = project_dir / "src"
source_dir.mkdir()

# Copies files from the template folder to project_dir. model.h is copied here,
# but will also need to be templated later.
# Copies files from the template folder to project_dir
shutil.copy2(API_SERVER_DIR / "microtvm_api_server.py", project_dir)
self._copy_project_files(API_SERVER_DIR, project_dir, options["project_type"])

Expand Down Expand Up @@ -359,17 +367,24 @@ def build(self, options):
# Specify project to compile
subprocess.run(compile_cmd)

"""We run the command `arduino-cli board list`, which produces
outputs of the form:
Port Type Board Name FQBN Core
/dev/ttyS4 Serial Port Unknown
/dev/ttyUSB0 Serial Port (USB) Spresense SPRESENSE:spresense:spresense SPRESENSE:spresense
"""

BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core")

def _parse_boards_tabular_str(self, tabular_str):
"""Parses the tabular output from `arduino-cli board list` into a 2D array
Examples
--------
>>> list(_parse_boards_tabular_str(bytes(
... "Port Type Board Name FQBN Core \n"
... "/dev/ttyS4 Serial Port Unknown \n"
... "/dev/ttyUSB0 Serial Port (USB) Spresense SPRESENSE:spresense:spresense SPRESENSE:spresense\n"
... "\n",
... "utf-8")))
[['/dev/ttys4', 'Serial Port', 'Unknown', '', ''], ['/dev/ttyUSB0', 'Serial Port (USB)',
'Spresense', 'SPRESENSE:spresense:spresense', 'SPRESENSE:spresense']]
"""

str_rows = tabular_str.split("\n")[:-2]
header = str_rows[0]
indices = [header.index(h) for h in self.BOARD_LIST_HEADERS] + [len(header)]
Expand All @@ -387,7 +402,7 @@ def _parse_boards_tabular_str(self, tabular_str):

def _auto_detect_port(self, options):
list_cmd = [options["arduino_cli_cmd"], "board", "list"]
list_cmd_output = subprocess.run(list_cmd, capture_output=True).stdout.decode("utf-8")
list_cmd_output = subprocess.run(list_cmd, stdout=subprocess.PIPE).stdout.decode("utf-8")

desired_fqbn = self._get_fqbn(options)
for line in self._parse_boards_tabular_str(list_cmd_output):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.

import subprocess
import sys
from pathlib import Path
from unittest import mock
Expand Down Expand Up @@ -64,8 +65,8 @@ def test_find_modified_include_path(self, mock_pathlib_path):

BOARD_CONNECTED_OUTPUT = bytes(
"Port Type Board Name FQBN Core \n"
"/dev/ttyACM1 Serial Port (USB) Wrong Arduino arduino:mbed_nano:nano33 arduino:mbed_nano\n"
"/dev/ttyACM0 Serial Port (USB) Arduino Nano 33 BLE arduino:mbed_nano:nano33ble arduino:mbed_nano\n"
"/dev/ttyACM1 Serial Port (USB) Arduino Nano 33 arduino:mbed_nano:nano33 arduino:mbed_nano\n"
"/dev/ttyS4 Serial Port Unknown \n"
"\n",
"utf-8",
Expand All @@ -77,37 +78,38 @@ def test_find_modified_include_path(self, mock_pathlib_path):
"utf-8",
)

@mock.patch("subprocess.check_output")
def test_auto_detect_port(self, mock_subprocess_check_output):
@mock.patch("subprocess.run")
def test_auto_detect_port(self, mock_subprocess_run):
process_mock = mock.Mock()
handler = microtvm_api_server.Handler()

# Test it returns the correct port when a board is connected
mock_subprocess_check_output.return_value = self.BOARD_CONNECTED_OUTPUT
detected_port = handler._auto_detect_port(self.DEFAULT_OPTIONS)
assert detected_port == "/dev/ttyACM0"
mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT
assert handler._auto_detect_port(self.DEFAULT_OPTIONS) == "/dev/ttyACM0"

# Test it raises an exception when no board is connected
mock_subprocess_check_output.return_value = self.BOARD_DISCONNECTED_OUTPUT
mock_subprocess_run.return_value.stdout = self.BOARD_DISCONNECTED_OUTPUT
with pytest.raises(microtvm_api_server.BoardAutodetectFailed):
handler._auto_detect_port(self.DEFAULT_OPTIONS)

@mock.patch("subprocess.check_call")
def test_flash(self, mock_subprocess_check_call):
# Test that the FQBN needs to match EXACTLY
handler._get_fqbn = mock.MagicMock(return_value="arduino:mbed_nano:nano33")
mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT
assert (
handler._auto_detect_port({**self.DEFAULT_OPTIONS, "arduino_board": "nano33"})
== "/dev/ttyACM1"
)

@mock.patch("subprocess.run")
def test_flash(self, mock_subprocess_run):
handler = microtvm_api_server.Handler()
handler._port = "/dev/ttyACM0"

# Test no exception thrown when code 0 returned
mock_subprocess_check_call.return_value = 0
# Test no exception thrown when command works
handler.flash(self.DEFAULT_OPTIONS)
mock_subprocess_check_call.assert_called_once()

# Test InvalidPortException raised when port incorrect
mock_subprocess_check_call.return_value = 2
with pytest.raises(microtvm_api_server.InvalidPortException):
handler.flash(self.DEFAULT_OPTIONS)
mock_subprocess_run.assert_called_once()

# Test SketchUploadException raised for other issues
mock_subprocess_check_call.return_value = 1
with pytest.raises(microtvm_api_server.SketchUploadException):
# Test exception raised when `arduino-cli upload` returns error code
mock_subprocess_run.side_effect = subprocess.CalledProcessError(2, [])
with pytest.raises(subprocess.CalledProcessError):
handler.flash(self.DEFAULT_OPTIONS)
8 changes: 4 additions & 4 deletions tests/lint/check_file_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@
# pytest config
"pytest.ini",
# microTVM tests
"tests/micro/testdata/digit-2.jpg",
"tests/micro/testdata/digit-9.jpg",
"tests/micro/testdata/mnist-8.onnx",
"tests/micro/testdata/yes_no.tflite",
"tests/micro/testdata/mnist/digit-2.jpg",
"tests/micro/testdata/mnist/digit-9.jpg",
"tests/micro/testdata/mnist/mnist-8.onnx",
"tests/micro/testdata/kws/yes_no.tflite",
# microTVM Zephyr runtime
"apps/microtvm/zephyr/template_project/CMakeLists.txt.template",
"apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-arm",
Expand Down
1 change: 1 addition & 0 deletions tests/micro/arduino/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"spresense": ("cxd5602gg", "spresense"),
"teensy40": ("imxrt1060", "teensy40"),
"teensy41": ("imxrt1060", "teensy41"),
"wioterminal": ("atsamd51", "wioterminal"),
}

TEMPLATE_PROJECT_DIR = (
Expand Down
8 changes: 4 additions & 4 deletions tests/micro/arduino/test_arduino_rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,17 @@ def test_onnx(platform, arduino_cli_cmd, tvm_debug, workspace_dir):

# Load test images.
this_dir = pathlib.Path(__file__).parent
testdata_dir = this_dir.parent / "testdata"
digit_2 = Image.open(testdata_dir / "digit-2.jpg").resize((28, 28))
mnist_testdata = this_dir.parent / "testdata" / "mnist"
digit_2 = Image.open(mnist_testdata / "digit-2.jpg").resize((28, 28))
digit_2 = np.asarray(digit_2).astype("float32")
digit_2 = np.expand_dims(digit_2, axis=0)

digit_9 = Image.open(testdata_dir / "digit-9.jpg").resize((28, 28))
digit_9 = Image.open(mnist_testdata / "digit-9.jpg").resize((28, 28))
digit_9 = np.asarray(digit_9).astype("float32")
digit_9 = np.expand_dims(digit_9, axis=0)

# Load ONNX model and convert to Relay.
onnx_model = onnx.load(testdata_dir / "mnist-8.onnx")
onnx_model = onnx.load(mnist_testdata / "mnist-8.onnx")
shape = {"Input3": (1, 1, 28, 28)}
relay_mod, params = relay.frontend.from_onnx(onnx_model, shape=shape, freeze_params=True)
relay_mod = relay.transform.DynamicToStatic()(relay_mod)
Expand Down
Loading

0 comments on commit 7b14e83

Please sign in to comment.