diff --git a/amaranth/_utils.py b/amaranth/_utils.py index 886de592e..bf4b49a16 100644 --- a/amaranth/_utils.py +++ b/amaranth/_utils.py @@ -3,12 +3,13 @@ import warnings import linecache import re +import unicodedata from collections import OrderedDict from collections.abc import Iterable __all__ = ["flatten", "union", "final", "deprecated", "get_linter_options", - "get_linter_option"] + "get_linter_option", "validate_name"] def flatten(i): @@ -101,3 +102,25 @@ def get_linter_option(filename, name, type, default): except ValueError: return default assert False + +_invalid_categories = { + "Zs", # space separators + "Zl", # line separators + "Zp", # paragraph separators + "Cc", # control codepoints (e.g. \0) + "Cs", # UTF-16 surrogate pair codepoints (no thanks WTF-8) + "Cn", # unassigned codepoints +} + +def validate_name(name, what, none_ok=False, empty_ok=False): + if not isinstance(name, str): + if name is None and none_ok: return + raise TypeError(f"{what} must be a string, not {name!r}") + if name == "" and not empty_ok: + raise NameError(f"{what} must be a non-empty string") + + # RTLIL allows bytes >= 33. In the same spirit we allow all characters that + # Unicode does not declare as separator or control. + for c in name: + if unicodedata.category(c) in _invalid_categories: + raise NameError(f"{what} {name!r} contains whitespace/control character {c!r}") diff --git a/amaranth/hdl/_ast.py b/amaranth/hdl/_ast.py index 551fcc388..064c79dae 100644 --- a/amaranth/hdl/_ast.py +++ b/amaranth/hdl/_ast.py @@ -1919,12 +1919,11 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F attrs=None, decoder=None, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) - if name is not None and not isinstance(name, str): - raise TypeError(f"Name must be a string, not {name!r}") if name is None: self.name = tracer.get_var_name(depth=2 + src_loc_at, default="$signal") else: self.name = name + validate_name(self.name, "Name", empty_ok=True) orig_shape = shape if shape is None: @@ -2127,8 +2126,7 @@ class ClockSignal(Value): """ def __init__(self, domain="sync", *, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) - if not isinstance(domain, str): - raise TypeError(f"Clock domain name must be a string, not {domain!r}") + validate_name(domain, "Clock domain name") if domain == "comb": raise ValueError(f"Domain '{domain}' does not have a clock") self._domain = domain @@ -2167,8 +2165,7 @@ class ResetSignal(Value): """ def __init__(self, domain="sync", allow_reset_less=False, *, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) - if not isinstance(domain, str): - raise TypeError(f"Clock domain name must be a string, not {domain!r}") + validate_name(domain, "Clock domain name") if domain == "comb": raise ValueError(f"Domain '{domain}' does not have a reset") self._domain = domain @@ -2850,8 +2847,7 @@ class IOPort(IOValue): def __init__(self, width, *, name=None, attrs=None, metadata=None, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) - if name is not None and not isinstance(name, str): - raise TypeError(f"Name must be a string, not {name!r}") + validate_name(name, "Name", none_ok=True) self.name = name or tracer.get_var_name(depth=2 + src_loc_at) self._width = operator.index(width) diff --git a/amaranth/hdl/_cd.py b/amaranth/hdl/_cd.py index 683315e8b..c4385caee 100644 --- a/amaranth/hdl/_cd.py +++ b/amaranth/hdl/_cd.py @@ -1,5 +1,6 @@ from .. import tracer from ._ast import Signal +from .._utils import validate_name __all__ = ["ClockDomain", "DomainError"] @@ -54,6 +55,7 @@ def __init__(self, name=None, *, clk_edge="pos", reset_less=False, async_reset=F name = tracer.get_var_name() except tracer.NameNotFound: raise ValueError("Clock domain name must be specified explicitly") + validate_name(name, "Clock domain name") if name.startswith("cd_"): name = name[3:] if name == "comb": @@ -78,6 +80,7 @@ def __init__(self, name=None, *, clk_edge="pos", reset_less=False, async_reset=F self.local = local def rename(self, new_name): + validate_name(new_name, "Clock domain name") self.name = new_name self.clk.name = self._name_for(new_name, "clk") if self.rst is not None: diff --git a/amaranth/hdl/_dsl.py b/amaranth/hdl/_dsl.py index 2322a831c..e43509031 100644 --- a/amaranth/hdl/_dsl.py +++ b/amaranth/hdl/_dsl.py @@ -5,7 +5,7 @@ import warnings import sys -from .._utils import flatten +from .._utils import flatten, validate_name from ..utils import bits_for from .. import tracer from ._ast import * @@ -426,6 +426,8 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None): warnings.warn("`reset=` is deprecated, use `init=` instead", DeprecationWarning, stacklevel=2) init = reset + validate_name(name, "FSM name") + validate_name(domain, "FSM clock domain") fsm_data = self._set_ctrl("FSM", { "name": name, "init": init, @@ -609,6 +611,7 @@ def _add_submodule(self, submodule, name=None, src_loc=None): if name == None: self._anon_submodules.append((submodule, src_loc)) else: + validate_name(name, "Submodule name") if name in self._named_submodules: raise NameError(f"Submodule named '{name}' already exists") self._named_submodules[name] = (submodule, src_loc) diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index f3e9fc9ae..e435deac3 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -3,7 +3,7 @@ import enum import warnings -from .._utils import flatten +from .._utils import flatten, validate_name from .. import tracer, _unused from . import _ast, _cd, _ir, _nir @@ -309,11 +309,13 @@ class Instance(Fragment): def __init__(self, type, *args, src_loc=None, src_loc_at=0, **kwargs): super().__init__(src_loc=src_loc or tracer.get_src_loc(src_loc_at)) + validate_name(type, "Instance type") self.type = type self.parameters = OrderedDict() self.named_ports = OrderedDict() for (kind, name, value) in args: + validate_name(name, "Instance argument name") if kind == "a": self.attrs[name] = value elif kind == "p": @@ -331,6 +333,7 @@ def __init__(self, type, *args, src_loc=None, src_loc_at=0, **kwargs): .format((kind, name, value))) for kw, arg in kwargs.items(): + validate_name(kw, "Instance keyword argument name") if kw.startswith("a_"): self.attrs[kw[2:]] = arg elif kw.startswith("p_"): @@ -556,6 +559,7 @@ def _assign_port_names(self): raise TypeError("Signals with private names cannot be used in unnamed top-level ports") name = _add_name(assigned_names, conn.name) assigned_names.add(name) + validate_name(name, "Top-level port name") new_ports.append((name, conn, dir)) self.ports = new_ports diff --git a/amaranth/sim/pysim.py b/amaranth/sim/pysim.py index d13eb8271..f8c3b36c2 100644 --- a/amaranth/sim/pysim.py +++ b/amaranth/sim/pysim.py @@ -117,9 +117,9 @@ def __init__(self, design, *, vcd_file, gtkw_file=None, traces=(), fs_per_delta= vcd_var = None for (*var_scope, var_name) in names: + # Shouldn't happen, but avoid producing a corrupt file. if re.search(r"[ \t\r\n]", var_name): - raise NameError("Signal '{}.{}' contains a whitespace character" - .format(".".join(var_scope), var_name)) + assert False, f"invalid name {var_name!r} made it to writer" # :nocov: field_name = var_name for item in repr.path: diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 576aa90af..7c833da7e 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -1182,6 +1182,15 @@ def test_name(self): self.assertEqual(s2.name, "sig") s3 = Signal(name="") self.assertEqual(s3.name, "") + s4 = Signal(name="$\\1a!\U0001F33C") + self.assertEqual(s4.name, "$\\1a!\U0001f33c") # Astral plane emoji "Blossom" + + def test_wrong_name(self): + for bad in [" ", "\r", "\n", "\t", "\0", "\u009d"]: # Control character OSC + name = f"sig{bad}" + with self.assertRaises(NameError, + msg="Name {name!r} contains whitespace/control character {bad!r}"): + Signal(name=name) def test_init(self): s1 = Signal(4, init=0b111, reset_less=True) @@ -1733,6 +1742,9 @@ def test_ioport(self): self.assertEqual(b.metadata, ("x", "y", "z")) self.assertEqual(b._ioports(), {b}) self.assertRepr(b, "(io-port b)") + with self.assertRaisesRegex(NameError, + r"^Name must be a non-empty string$"): + IOPort(1, name="") def test_ioport_wrong(self): with self.assertRaisesRegex(TypeError, diff --git a/tests/test_hdl_cd.py b/tests/test_hdl_cd.py index 5a9fa736b..25df3f1f0 100644 --- a/tests/test_hdl_cd.py +++ b/tests/test_hdl_cd.py @@ -23,6 +23,9 @@ def test_name(self): ClockDomain() cd_reset = ClockDomain(local=True) self.assertEqual(cd_reset.local, True) + with self.assertRaisesRegex(TypeError, + r"^Clock domain name must be a string, not 1$"): + sync.rename(1) def test_edge(self): sync = ClockDomain() @@ -64,6 +67,9 @@ def test_rename(self): self.assertEqual(sync.name, "pix") self.assertEqual(sync.clk.name, "pix_clk") self.assertEqual(sync.rst.name, "pix_rst") + with self.assertRaisesRegex(TypeError, + r"^Clock domain name must be a string, not 1$"): + sync.rename(1) def test_rename_reset_less(self): sync = ClockDomain(reset_less=True) diff --git a/tests/test_hdl_dsl.py b/tests/test_hdl_dsl.py index 8048e5458..65bbf936d 100644 --- a/tests/test_hdl_dsl.py +++ b/tests/test_hdl_dsl.py @@ -792,6 +792,17 @@ def test_FSM_wrong_next(self): with m.FSM(): m.next = "FOO" + def test_FSM_wrong_name(self): + m = Module() + with self.assertRaisesRegex(TypeError, + r"^FSM name must be a string, not 1$"): + with m.FSM(name=1): + pass + with self.assertRaisesRegex(TypeError, + r"^FSM clock domain must be a string, not 1$"): + with m.FSM(domain=1): + pass + def test_If_inside_FSM_wrong(self): m = Module() with m.FSM(): @@ -868,6 +879,16 @@ def test_submodule_named_index(self): self.assertEqual(m1._named_submodules.keys(), {"foo"}) self.assertEqual(m1._named_submodules["foo"][0], m2) + def test_submodule_wrong_name(self): + m1 = Module() + m2 = Module() + with self.assertRaisesRegex(TypeError, + r"^Submodule name must be a string, not 1$"): + m1.submodules[1] = m2 + with self.assertRaisesRegex(NameError, + r"^Submodule name must be a non-empty string$"): + m1.submodules[""] = m2 + def test_submodule_wrong(self): m = Module() with self.assertRaisesRegex(TypeError, diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index 133c32fd4..50c760456 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -734,6 +734,17 @@ def test_construct(self): ("io6", (io6, "io")), ])) + def test_wrong_name(self): + with self.assertRaisesRegex(TypeError, + r"^Instance type must be a string, not 1$"): + Instance(1) + with self.assertRaisesRegex(TypeError, + r"^Instance argument name must be a string, not 1$"): + Instance("foo", ("a", 1, 2)) + with self.assertRaisesRegex(NameError, + r"^Instance keyword argument name ' ' contains whitespace/control character ' '$"): + Instance("foo", **{" ": 2}) + def test_cast_ports(self): inst = Instance("foo", ("i", "s1", 1), @@ -1043,6 +1054,12 @@ def test_assign_names_to_fragments_duplicate(self): self.assertEqual(design.fragments[a1_f].name, ("top", "a")) self.assertEqual(design.fragments[a2_f].name, ("top", "a$1")) + def test_port_wrong_name(self): + f = Fragment() + with self.assertRaisesRegex(NameError, + r"^Top-level port name must be a non-empty string$"): + design = Design(f, ports=[("", Signal(), None)], hierarchy=("top",)) + class ElaboratesTo(Elaboratable): def __init__(self, lower): diff --git a/tests/test_sim.py b/tests/test_sim.py index 9be42990e..fc85b88c8 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -1230,19 +1230,6 @@ def process(): sim.add_testbench(process) sim.run() - def test_bug_595(self): - dut = Module() - dummy = Signal() - with dut.FSM(name="name with space"): - with dut.State(0): - dut.d.comb += dummy.eq(1) - sim = Simulator(dut) - with self.assertRaisesRegex(NameError, - r"^Signal 'bench\.top\.name with space_state' contains a whitespace character$"): - with open(os.path.devnull, "w") as f: - with sim.write_vcd(f): - sim.run() - def test_bug_588(self): dut = Module() a = Signal(32)