Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce naming rules on core HDL #1235

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion amaranth/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import warnings
import linecache
import re
import unicodedata
from collections import OrderedDict
from collections.abc import Iterable


__all__ = ["flatten", "union", "final", "deprecated", "get_linter_options",
"get_linter_option"]
"get_linter_option", "validate_name"]


def flatten(i):
Expand Down Expand Up @@ -101,3 +102,25 @@ def get_linter_option(filename, name, type, default):
except ValueError:
return default
assert False

_invalid_categories = {
"Zs", # space separators
"Zl", # line separators
"Zp", # paragraph separators
"Cc", # control codepoints (e.g. \0)
"Cs", # UTF-16 surrogate pair codepoints (no thanks WTF-8)
"Cn", # unassigned codepoints
}

def validate_name(name, what, none_ok=False, empty_ok=False):
if not isinstance(name, str):
if name is None and none_ok: return
raise TypeError(f"{what} must be a string, not {name!r}")
if name == "" and not empty_ok:
raise NameError(f"{what} must be a non-empty string")

# RTLIL allows bytes >= 33. In the same spirit we allow all characters that
# Unicode does not declare as separator or control.
for c in name:
if unicodedata.category(c) in _invalid_categories:
raise NameError(f"{what} {name!r} contains whitespace/control character {c!r}")
12 changes: 4 additions & 8 deletions amaranth/hdl/_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,12 +1919,11 @@ def __init__(self, shape=None, *, name=None, init=None, reset=None, reset_less=F
attrs=None, decoder=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)

if name is not None and not isinstance(name, str):
raise TypeError(f"Name must be a string, not {name!r}")
if name is None:
self.name = tracer.get_var_name(depth=2 + src_loc_at, default="$signal")
else:
self.name = name
validate_name(self.name, "Name", empty_ok=True)

orig_shape = shape
if shape is None:
Expand Down Expand Up @@ -2127,8 +2126,7 @@ class ClockSignal(Value):
"""
def __init__(self, domain="sync", *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
if not isinstance(domain, str):
raise TypeError(f"Clock domain name must be a string, not {domain!r}")
validate_name(domain, "Clock domain name")
if domain == "comb":
raise ValueError(f"Domain '{domain}' does not have a clock")
self._domain = domain
Expand Down Expand Up @@ -2167,8 +2165,7 @@ class ResetSignal(Value):
"""
def __init__(self, domain="sync", allow_reset_less=False, *, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)
if not isinstance(domain, str):
raise TypeError(f"Clock domain name must be a string, not {domain!r}")
validate_name(domain, "Clock domain name")
if domain == "comb":
raise ValueError(f"Domain '{domain}' does not have a reset")
self._domain = domain
Expand Down Expand Up @@ -2850,8 +2847,7 @@ class IOPort(IOValue):
def __init__(self, width, *, name=None, attrs=None, metadata=None, src_loc_at=0):
super().__init__(src_loc_at=src_loc_at)

if name is not None and not isinstance(name, str):
raise TypeError(f"Name must be a string, not {name!r}")
validate_name(name, "Name", none_ok=True)
self.name = name or tracer.get_var_name(depth=2 + src_loc_at)

self._width = operator.index(width)
Expand Down
3 changes: 3 additions & 0 deletions amaranth/hdl/_cd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .. import tracer
from ._ast import Signal
from .._utils import validate_name


__all__ = ["ClockDomain", "DomainError"]
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(self, name=None, *, clk_edge="pos", reset_less=False, async_reset=F
name = tracer.get_var_name()
except tracer.NameNotFound:
raise ValueError("Clock domain name must be specified explicitly")
validate_name(name, "Clock domain name")
if name.startswith("cd_"):
name = name[3:]
if name == "comb":
Expand All @@ -78,6 +80,7 @@ def __init__(self, name=None, *, clk_edge="pos", reset_less=False, async_reset=F
self.local = local

def rename(self, new_name):
validate_name(new_name, "Clock domain name")
self.name = new_name
self.clk.name = self._name_for(new_name, "clk")
if self.rst is not None:
Expand Down
5 changes: 4 additions & 1 deletion amaranth/hdl/_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
import sys

from .._utils import flatten
from .._utils import flatten, validate_name
from ..utils import bits_for
from .. import tracer
from ._ast import *
Expand Down Expand Up @@ -426,6 +426,8 @@ def FSM(self, init=None, domain="sync", name="fsm", *, reset=None):
warnings.warn("`reset=` is deprecated, use `init=` instead",
DeprecationWarning, stacklevel=2)
init = reset
validate_name(name, "FSM name")
validate_name(domain, "FSM clock domain")
fsm_data = self._set_ctrl("FSM", {
"name": name,
"init": init,
Expand Down Expand Up @@ -609,6 +611,7 @@ def _add_submodule(self, submodule, name=None, src_loc=None):
if name == None:
self._anon_submodules.append((submodule, src_loc))
else:
validate_name(name, "Submodule name")
if name in self._named_submodules:
raise NameError(f"Submodule named '{name}' already exists")
self._named_submodules[name] = (submodule, src_loc)
Expand Down
6 changes: 5 additions & 1 deletion amaranth/hdl/_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import enum
import warnings

from .._utils import flatten
from .._utils import flatten, validate_name
from .. import tracer, _unused
from . import _ast, _cd, _ir, _nir

Expand Down Expand Up @@ -309,11 +309,13 @@ class Instance(Fragment):
def __init__(self, type, *args, src_loc=None, src_loc_at=0, **kwargs):
super().__init__(src_loc=src_loc or tracer.get_src_loc(src_loc_at))

validate_name(type, "Instance type")
self.type = type
self.parameters = OrderedDict()
self.named_ports = OrderedDict()

for (kind, name, value) in args:
validate_name(name, "Instance argument name")
if kind == "a":
self.attrs[name] = value
elif kind == "p":
Expand All @@ -331,6 +333,7 @@ def __init__(self, type, *args, src_loc=None, src_loc_at=0, **kwargs):
.format((kind, name, value)))

for kw, arg in kwargs.items():
validate_name(kw, "Instance keyword argument name")
if kw.startswith("a_"):
self.attrs[kw[2:]] = arg
elif kw.startswith("p_"):
Expand Down Expand Up @@ -556,6 +559,7 @@ def _assign_port_names(self):
raise TypeError("Signals with private names cannot be used in unnamed top-level ports")
name = _add_name(assigned_names, conn.name)
assigned_names.add(name)
validate_name(name, "Top-level port name")
new_ports.append((name, conn, dir))
self.ports = new_ports

Expand Down
4 changes: 2 additions & 2 deletions amaranth/sim/pysim.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ def __init__(self, design, *, vcd_file, gtkw_file=None, traces=(), fs_per_delta=

vcd_var = None
for (*var_scope, var_name) in names:
# Shouldn't happen, but avoid producing a corrupt file.
if re.search(r"[ \t\r\n]", var_name):
raise NameError("Signal '{}.{}' contains a whitespace character"
.format(".".join(var_scope), var_name))
assert False, f"invalid name {var_name!r} made it to writer" # :nocov:

field_name = var_name
for item in repr.path:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_hdl_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,15 @@ def test_name(self):
self.assertEqual(s2.name, "sig")
s3 = Signal(name="")
self.assertEqual(s3.name, "")
s4 = Signal(name="$\\1a!\U0001F33C")
self.assertEqual(s4.name, "$\\1a!\U0001f33c") # Astral plane emoji "Blossom"

def test_wrong_name(self):
for bad in [" ", "\r", "\n", "\t", "\0", "\u009d"]: # Control character OSC
name = f"sig{bad}"
with self.assertRaises(NameError,
msg="Name {name!r} contains whitespace/control character {bad!r}"):
Signal(name=name)

def test_init(self):
s1 = Signal(4, init=0b111, reset_less=True)
Expand Down Expand Up @@ -1733,6 +1742,9 @@ def test_ioport(self):
self.assertEqual(b.metadata, ("x", "y", "z"))
self.assertEqual(b._ioports(), {b})
self.assertRepr(b, "(io-port b)")
with self.assertRaisesRegex(NameError,
r"^Name must be a non-empty string$"):
IOPort(1, name="")

def test_ioport_wrong(self):
with self.assertRaisesRegex(TypeError,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_hdl_cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def test_name(self):
ClockDomain()
cd_reset = ClockDomain(local=True)
self.assertEqual(cd_reset.local, True)
with self.assertRaisesRegex(TypeError,
r"^Clock domain name must be a string, not 1$"):
sync.rename(1)

def test_edge(self):
sync = ClockDomain()
Expand Down Expand Up @@ -64,6 +67,9 @@ def test_rename(self):
self.assertEqual(sync.name, "pix")
self.assertEqual(sync.clk.name, "pix_clk")
self.assertEqual(sync.rst.name, "pix_rst")
with self.assertRaisesRegex(TypeError,
r"^Clock domain name must be a string, not 1$"):
sync.rename(1)

def test_rename_reset_less(self):
sync = ClockDomain(reset_less=True)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_hdl_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,17 @@ def test_FSM_wrong_next(self):
with m.FSM():
m.next = "FOO"

def test_FSM_wrong_name(self):
m = Module()
with self.assertRaisesRegex(TypeError,
r"^FSM name must be a string, not 1$"):
with m.FSM(name=1):
pass
with self.assertRaisesRegex(TypeError,
r"^FSM clock domain must be a string, not 1$"):
with m.FSM(domain=1):
pass

def test_If_inside_FSM_wrong(self):
m = Module()
with m.FSM():
Expand Down Expand Up @@ -868,6 +879,16 @@ def test_submodule_named_index(self):
self.assertEqual(m1._named_submodules.keys(), {"foo"})
self.assertEqual(m1._named_submodules["foo"][0], m2)

def test_submodule_wrong_name(self):
m1 = Module()
m2 = Module()
with self.assertRaisesRegex(TypeError,
r"^Submodule name must be a string, not 1$"):
m1.submodules[1] = m2
with self.assertRaisesRegex(NameError,
r"^Submodule name must be a non-empty string$"):
m1.submodules[""] = m2

def test_submodule_wrong(self):
m = Module()
with self.assertRaisesRegex(TypeError,
Expand Down
17 changes: 17 additions & 0 deletions tests/test_hdl_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,17 @@ def test_construct(self):
("io6", (io6, "io")),
]))

def test_wrong_name(self):
with self.assertRaisesRegex(TypeError,
r"^Instance type must be a string, not 1$"):
Instance(1)
with self.assertRaisesRegex(TypeError,
r"^Instance argument name must be a string, not 1$"):
Instance("foo", ("a", 1, 2))
with self.assertRaisesRegex(NameError,
r"^Instance keyword argument name ' ' contains whitespace/control character ' '$"):
Instance("foo", **{" ": 2})

def test_cast_ports(self):
inst = Instance("foo",
("i", "s1", 1),
Expand Down Expand Up @@ -1043,6 +1054,12 @@ def test_assign_names_to_fragments_duplicate(self):
self.assertEqual(design.fragments[a1_f].name, ("top", "a"))
self.assertEqual(design.fragments[a2_f].name, ("top", "a$1"))

def test_port_wrong_name(self):
f = Fragment()
with self.assertRaisesRegex(NameError,
r"^Top-level port name must be a non-empty string$"):
design = Design(f, ports=[("", Signal(), None)], hierarchy=("top",))


class ElaboratesTo(Elaboratable):
def __init__(self, lower):
Expand Down
13 changes: 0 additions & 13 deletions tests/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,19 +1230,6 @@ def process():
sim.add_testbench(process)
sim.run()

def test_bug_595(self):
dut = Module()
dummy = Signal()
with dut.FSM(name="name with space"):
with dut.State(0):
dut.d.comb += dummy.eq(1)
sim = Simulator(dut)
with self.assertRaisesRegex(NameError,
r"^Signal 'bench\.top\.name with space_state' contains a whitespace character$"):
with open(os.path.devnull, "w") as f:
with sim.write_vcd(f):
sim.run()

def test_bug_588(self):
dut = Module()
a = Signal(32)
Expand Down