|
5 | 5 | # SPDX-License-Identifier: Apache-2.0 |
6 | 6 |
|
7 | 7 | import os |
| 8 | +import pathlib |
8 | 9 | import sys |
9 | 10 |
|
10 | 11 | import pytest |
11 | 12 |
|
| 13 | +import numba_dpex |
12 | 14 | from numba_dpex.core import config |
13 | 15 |
|
14 | | -from .common import script_path |
15 | | - |
16 | 16 | if config.TESTING_SKIP_NO_DEBUGGING: |
17 | 17 | pexpect = pytest.importorskip("pexpect") |
18 | 18 | else: |
@@ -102,3 +102,57 @@ def set_scheduler_lock(self): |
102 | 102 | @staticmethod |
103 | 103 | def script_path(script): |
104 | 104 | return script_path(script) |
| 105 | + |
| 106 | + |
| 107 | +def script_path(script): |
| 108 | + package_path = pathlib.Path(numba_dpex.__file__).parent |
| 109 | + return str(package_path / "examples/debug" / script) |
| 110 | + |
| 111 | + |
| 112 | +def line_number(file_path, text): |
| 113 | + """Return line number of the text in the file""" |
| 114 | + with open(file_path, "r") as lines: |
| 115 | + for line_number, line in enumerate(lines): |
| 116 | + if text in line: |
| 117 | + return line_number + 1 |
| 118 | + |
| 119 | + raise RuntimeError(f"Can not find {text} in {file_path}") |
| 120 | + |
| 121 | + |
| 122 | +def breakpoint_by_mark(script, mark, offset=0): |
| 123 | + """Return breakpoint for the mark in the script |
| 124 | +
|
| 125 | + Example: breakpoint_by_mark("script.py", "Set here") -> "script.py:25" |
| 126 | + """ |
| 127 | + return f"{script}:{line_number(script_path(script), mark) + offset}" |
| 128 | + |
| 129 | + |
| 130 | +def breakpoint_by_function(script, function): |
| 131 | + """Return breakpoint for the function in the script""" |
| 132 | + return breakpoint_by_mark(script, f"def {function}", 1) |
| 133 | + |
| 134 | + |
| 135 | +def setup_breakpoint( |
| 136 | + app: gdb, |
| 137 | + breakpoint: str, |
| 138 | + script=None, |
| 139 | + expected_location=None, |
| 140 | + expected_line=None, |
| 141 | +): |
| 142 | + if not script: |
| 143 | + script = breakpoint.split(" ")[0].split(":")[0] |
| 144 | + |
| 145 | + if not expected_location: |
| 146 | + expected_location = breakpoint.split(" ")[0] |
| 147 | + if not expected_location.split(":")[-1].isnumeric(): |
| 148 | + expected_location = breakpoint_by_function( |
| 149 | + script, expected_location.split(":")[-1] |
| 150 | + ) |
| 151 | + |
| 152 | + app.breakpoint(breakpoint) |
| 153 | + app.run(script) |
| 154 | + |
| 155 | + app.child.expect(rf"Thread .* hit Breakpoint .* at {expected_location}") |
| 156 | + |
| 157 | + if expected_line: |
| 158 | + app.child.expect(expected_line) |
0 commit comments