From ad48c5b4d2d1286fb8e6b0b3fe3ce6ebdd91e5f7 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Wed, 14 Jun 2023 19:53:15 -0400 Subject: [PATCH 01/40] Add typing for prosemirror.utils and add JSON alias --- prosemirror/utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/prosemirror/utils.py b/prosemirror/utils.py index 021ef47..110bde1 100644 --- a/prosemirror/utils.py +++ b/prosemirror/utils.py @@ -1,2 +1,13 @@ -def text_length(text): +from collections.abc import Mapping, Sequence +from typing import Union + +from typing_extensions import TypeAlias + +JSONDict: TypeAlias = Mapping[str, "JSON"] +JSONList: TypeAlias = Sequence["JSON"] + +JSON: TypeAlias = Union[JSONDict, JSONList, str, int, float, bool, None] + + +def text_length(text: str) -> int: return len(text.encode("utf-16-le")) // 2 From 2d7411c42bd6e15674c1061dfd14e1756febad73 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Wed, 14 Jun 2023 20:02:51 -0400 Subject: [PATCH 02/40] Add typing for prosemirror.model.comparedeep --- prosemirror/model/comparedeep.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/prosemirror/model/comparedeep.py b/prosemirror/model/comparedeep.py index ba1178d..8b1fa9e 100644 --- a/prosemirror/model/comparedeep.py +++ b/prosemirror/model/comparedeep.py @@ -1,2 +1,5 @@ -def compare_deep(a, b): +from prosemirror.utils import JSON + + +def compare_deep(a: JSON, b: JSON) -> bool: return a == b From 927409bde458578e7205c0956f7ebe1d5fbae33d Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Wed, 14 Jun 2023 23:02:39 -0400 Subject: [PATCH 03/40] Typing for prosemirror.model.content --- prosemirror/model/content.py | 341 +++++++++++++++++++---------- prosemirror/model/from_dom.py | 4 +- prosemirror/transform/transform.py | 2 +- 3 files changed, 233 insertions(+), 114 deletions(-) diff --git a/prosemirror/model/content.py b/prosemirror/model/content.py index 105b777..8977bfd 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -1,20 +1,60 @@ import re +from dataclasses import dataclass, field from functools import cmp_to_key, reduce -from typing import ClassVar +from typing import ( + TYPE_CHECKING, + ClassVar, + Dict, + List, + Literal, + NamedTuple, + NoReturn, + Optional, + TypedDict, + Union, + cast, +) from .fragment import Fragment +if TYPE_CHECKING: + from .schema import NodeType + +@dataclass +class MatchEdge: + type: "NodeType" + next: "ContentMatch" + + +@dataclass +class WrapCacheEntry: + target: "NodeType" + computed: Optional[List["NodeType"]] + + +class Active(TypedDict): + match: "ContentMatch" + type: Optional["NodeType"] + via: Optional["Active"] + + +@dataclass(eq=False) class ContentMatch: - empty: ClassVar["ContentMatch"] + """ + Instances of this class represent a match state of a node type's + [content expression](#model.NodeSpec.content), and can be used to + find out whether further content matches here, and whether a given + position is a valid end of the node. 
+ """ - def __init__(self, valid_end): - self.valid_end = valid_end - self.next = [] - self.wrap_cache = [] + empty: ClassVar["ContentMatch"] + valid_end: bool + next: List[MatchEdge] = field(default_factory=list, init=False) + wrap_cache: List[WrapCacheEntry] = field(default_factory=list, init=False) @classmethod - def parse(cls, string, node_types): + def parse(cls, string: str, node_types: Dict[str, "NodeType"]) -> "ContentMatch": stream = TokenStream(string, node_types) if stream.next is None: return ContentMatch.empty @@ -25,16 +65,18 @@ def parse(cls, string, node_types): check_for_dead_ends(match, stream) return match - def match_type(self, type, *args): - for i in range(0, len(self.next), 2): - if self.next[i].name == type.name: - return self.next[i + 1] + def match_type(self, type: "NodeType") -> Optional["ContentMatch"]: + for next in self.next: + if next.type.name == type.name: + return next.next return None - def match_fragment(self, frag, start=0, end=None): + def match_fragment( + self, frag: Fragment, start: int = 0, end: Optional[int] = None + ) -> Optional["ContentMatch"]: if end is None: end = frag.child_count - cur = self + cur: Optional["ContentMatch"] = self i = start while cur and i < end: cur = cur.match_type(frag.child(i).type) @@ -42,56 +84,57 @@ def match_fragment(self, frag, start=0, end=None): return cur @property - def inline_content(self): - if not self.next: - return None - first = self.next[0] - return first.is_inline if first else False + def inline_content(self) -> bool: + return bool(self.next) and self.next[0].type.is_inline @property - def default_type(self): - for i in range(0, len(self.next), 2): - type = self.next[i] + def default_type(self) -> Optional["NodeType"]: + for next in self.next: + type = next.type if not (type.is_text or type.has_required_attrs()): return type + return None - def compatible(self, other): - for i in range(0, len(self.next), 2): - for j in range(0, len(other.next), 2): - if self.next[i].name == other.next[j].name: + def compatible(self, other: "ContentMatch") -> bool: + for i in self.next: + for j in other.next: + if i.type.name == j.type.name: return True return False - def fill_before(self, after, to_end=False, start_index=0): + def fill_before( + self, after: Fragment, to_end: bool = False, start_index: int = 0 + ) -> Optional[Fragment]: seen = [self] - def search(match, types): + def search(match: ContentMatch, types: List["NodeType"]) -> Optional[Fragment]: nonlocal seen finished = match.match_fragment(after, start_index) if finished and (not to_end or finished.valid_end): return Fragment.from_([tp.create_and_fill() for tp in types]) - for i in range(0, len(match.next), 2): - type = match.next[i] - next = match.next[i + 1] + for i in match.next: + type = i.type + next = i.next if not (type.is_text or type.has_required_attrs()) and next not in seen: seen.append(next) found = search(next, [*types, type]) if found: return found + return None return search(self, []) - def find_wrapping(self, target): - for i in range(0, len(self.wrap_cache), 2): - if self.wrap_cache[i].name == target.name: - return self.wrap_cache[i + 1] + def find_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: + for entry in self.wrap_cache: + if entry.target.name == target.name: + return entry.computed computed = self.compute_wrapping(target) - self.wrap_cache.extend([target, computed]) + self.wrap_cache.append(WrapCacheEntry(target, computed)) return computed - def compute_wrapping(self, target): + def compute_wrapping(self, target: 
"NodeType") -> Optional[List["NodeType"]]: seen = {} - active = [{"match": self, "type": None, "via": None}] + active: List[Active] = [{"match": self, "type": None, "via": None}] while len(active): current = active.pop(0) match = current["match"] @@ -100,50 +143,59 @@ def compute_wrapping(self, target): obj = current while obj["type"]: result.append(obj["type"]) - obj = obj["via"] + obj = cast(Active, obj["via"]) return list(reversed(result)) - for i in range(0, len(match.next), 2): - type = match.next[i] + for i in range(len(match.next)): + type = match.next[i].type if ( not type.is_leaf and not type.has_required_attrs() and type.name not in seen - and (not current["type"] or match.next[i + 1].valid_end) + and (not current["type"] or match.next[i].next.valid_end) ): active.append( - {"match": type.content_match, "via": current, "type": type} + { + "match": type.content_match, + "via": current, + "type": type, + } ) seen[type.name] = True + return None @property - def edge_count(self): - return len(self.next) >> 1 + def edge_count(self) -> int: + return len(self.next) - def edge(self, n): - i = n << 1 - if i >= len(self.next): + def edge(self, n: int) -> MatchEdge: + if n >= len(self.next): raise ValueError(f"There's no {n}th edge in this content match") - return {"type": self.next[i], "next": self.next[i + 1]} + return self.next[n] - def __str__(self): + def __str__(self) -> str: seen = [] - def scan(m): + def scan(m: "ContentMatch") -> None: nonlocal seen - for i in range(1, len(m.next), 2): - if m.next[i] in seen: - scan(m.next[i]) + seen.append(m) + for i in m.next: + if i.next not in seen: + scan(i.next) scan(self) - def iteratee(m, i): - out = i + ("*" if m.valid_end else " ") + " " - for i in range(0, len(m.next), 2): + def iteratee(m: "ContentMatch", i: int) -> str: + out = str(i) + ("*" if m.valid_end else " ") + " " + for i in range(len(m.next)): out += ( - (", " if i else "") + m.next(i) + "->" + seen.index(m.next[i + 1]) + (", " if i else "") + + m.next[i].type.name + + "->" + + str(seen.index(m.next[i].next)) ) + return out - return "\n".join((iteratee(m, i)) for m, i in enumerate(seen)) + return "\n".join((iteratee(m, i)) for i, m in enumerate(seen)) ContentMatch.empty = ContentMatch(True) @@ -153,7 +205,7 @@ def iteratee(m, i): class TokenStream: - def __init__(self, string, node_types): + def __init__(self, string: str, node_types: Dict[str, "NodeType"]) -> None: self.string = string self.node_types = node_types self.inline = None @@ -161,13 +213,13 @@ def __init__(self, string, node_types): self.tokens = [i for i in TOKEN_REGEX.findall(string) if i.strip()] @property - def next(self): + def next(self) -> Optional[str]: try: return self.tokens[self.pos] except IndexError: return None - def eat(self, tok): + def eat(self, tok: str) -> Union[int, bool]: if self.next == tok: pos = self.pos self.pos += 1 @@ -175,11 +227,51 @@ def eat(self, tok): else: return False - def err(self, str): + def err(self, str: str) -> NoReturn: raise SyntaxError(f'{str} (in content expression) "{self.string}"') -def parse_expr(stream): +class ChoiceExpr(TypedDict): + type: Literal["choice"] + exprs: List["Expr"] + + +class SeqExpr(TypedDict): + type: Literal["seq"] + exprs: List["Expr"] + + +class PlusExpr(TypedDict): + type: Literal["plus"] + expr: "Expr" + + +class StarExpr(TypedDict): + type: Literal["star"] + expr: "Expr" + + +class OptExpr(TypedDict): + type: Literal["opt"] + expr: "Expr" + + +class RangeExpr(TypedDict): + type: Literal["range"] + min: int + max: int + expr: "Expr" + + +class 
NameExpr(TypedDict): + type: Literal["name"] + value: "NodeType" + + +Expr = Union[ChoiceExpr, SeqExpr, PlusExpr, StarExpr, OptExpr, RangeExpr, NameExpr] + + +def parse_expr(stream: TokenStream) -> Expr: exprs = [] while True: exprs.append(parse_expr_seq(stream)) @@ -190,7 +282,7 @@ def parse_expr(stream): return {"type": "choice", "exprs": exprs} -def parse_expr_seq(stream): +def parse_expr_seq(stream: TokenStream) -> Expr: exprs = [] while True: exprs.append(parse_expr_subscript(stream)) @@ -201,7 +293,7 @@ def parse_expr_seq(stream): return {"type": "seq", "exprs": exprs} -def parse_expr_subscript(stream): +def parse_expr_subscript(stream: TokenStream) -> Expr: expr = parse_expr_atom(stream) while True: if stream.eat("+"): @@ -220,15 +312,17 @@ def parse_expr_subscript(stream): NUMBER_REGEX = re.compile(r"\D") -def parse_num(stream: TokenStream): - if NUMBER_REGEX.match(stream.next): - stream.err(f'Expected number, got "{stream.next}"') - result = int(stream.next) +def parse_num(stream: TokenStream) -> int: + next = stream.next + assert next is not None + if NUMBER_REGEX.match(next): + stream.err(f'Expected number, got "{next}"') + result = int(next) stream.pos += 1 return result -def parse_expr_range(stream: TokenStream, expr): +def parse_expr_range(stream: TokenStream, expr: Expr) -> Expr: min_ = parse_num(stream) max_ = min_ if stream.eat(","): @@ -241,7 +335,7 @@ def parse_expr_range(stream: TokenStream, expr): return {"type": "range", "min": min_, "max": max_, "expr": expr} -def resolve_name(stream: TokenStream, name): +def resolve_name(stream: TokenStream, name: str) -> List["NodeType"]: types = stream.node_types type = types.get(name) if type: @@ -255,15 +349,17 @@ def resolve_name(stream: TokenStream, name): return result -def parse_expr_atom(stream: TokenStream): +def parse_expr_atom( + stream: TokenStream, +) -> Expr: if stream.eat("("): expr = parse_expr(stream) if not stream.eat(")"): stream.err("missing closing patren") return expr - elif not re.match(r"\W", stream.next): + elif not re.match(r"\W", cast(str, stream.next)): - def iteratee(type): + def iteratee(type: "NodeType") -> Expr: nonlocal stream if stream.inline is None: stream.inline = type.is_inline @@ -271,7 +367,9 @@ def iteratee(type): stream.err("Mixing inline and block content") return {"type": "name", "value": type} - exprs = [iteratee(type) for type in resolve_name(stream, stream.next)] + exprs = [ + iteratee(type) for type in resolve_name(stream, cast(str, stream.next)) + ] stream.pos += 1 if len(exprs) == 1: return exprs[0] @@ -280,37 +378,50 @@ def iteratee(type): stream.err(f'Unexpected token "{stream.next}"') -def nfa(expr): - nfa_ = [[]] +class Edge(TypedDict): + term: Optional["NodeType"] + to: Optional[int] + + +def nfa( + expr: Expr, +) -> List[List[Edge]]: + nfa_: List[List[Edge]] = [[]] - def node(): + def node() -> int: nonlocal nfa_ nfa_.append([]) return len(nfa_) - 1 - def edge(from_, to=None, term=None): + def edge( + from_: int, to: Optional[int] = None, term: Optional["NodeType"] = None + ) -> Edge: nonlocal nfa_ - edge = {"term": term, "to": to} + edge: Edge = {"term": term, "to": to} nfa_[from_].append(edge) return edge - def connect(edges, to): + def connect(edges: List[Edge], to: int) -> None: for edge in edges: edge["to"] = to - def compile(expr, from_): + def compile(expr: Expr, from_: int) -> List[Edge]: if expr["type"] == "choice": return list( - reduce(lambda out, expr: out + compile(expr, from_), expr["exprs"], []) + reduce( + lambda out, expr: [*out, *compile(expr, from_)], + 
expr["exprs"], + cast(List[Edge], []), + ) ) elif expr["type"] == "seq": i = 0 while True: - next = compile(expr["exprs"][i], from_) + nxt = compile(expr["exprs"][i], from_) if i == len(expr["exprs"]) - 1: - return next + return nxt from_ = node() - connect(next, from_) + connect(nxt, from_) i += 1 elif expr["type"] == "star": loop = node() @@ -346,73 +457,81 @@ def compile(expr, from_): return nfa_ -def cmp(a, b): +def cmp(a: int, b: int) -> int: return b - a -def null_from(nfa, node): +def null_from( + nfa: List[List[Edge]], + node: int, +) -> List[int]: result = [] - def scan(n): + def scan(n: int) -> None: nonlocal result edges = nfa[n] if len(edges) == 1 and not edges[0].get("term"): - return scan(edges[0].get("to")) + return scan(cast(int, edges[0].get("to"))) result.append(n) for edge in edges: term, to = edge.get("term"), edge.get("to") if not term and to not in result: - scan(to) + scan(cast(int, to)) scan(node) return sorted(result) -def dfa(nfa): +class DFAState(NamedTuple): + state: "NodeType" + next: List[int] + + +def dfa(nfa: List[List[Edge]]) -> ContentMatch: labeled = {} - def explore(states): + def explore(states: List[int]) -> ContentMatch: nonlocal labeled - out = [] + out: List[DFAState] = [] for node in states: for item in nfa[node]: term, to = item.get("term"), item.get("to") if not term: continue - known = term in out - if known: - set = out[out.index(term) + 1] - else: - set = False - for n in null_from(nfa, to): - if not set: + set: Optional[List[int]] = None + for t in out: + if t[0] == term: + set = t[1] + for n in null_from(nfa, cast(int, to)): + if set is None: set = [] - out.extend([term, set]) + out.append(DFAState(term, set)) if n not in set: set.append(n) state = ContentMatch((len(nfa) - 1) in states) labeled[",".join([str(s) for s in states])] = state - for i in range(0, len(out), 2): - out[i + 1].sort(key=cmp_to_key(cmp)) - states = out[i + 1] + for i in range(len(out)): + out[i][1].sort(key=cmp_to_key(cmp)) + states = out[i][1] find_by_key = ",".join(str(s) for s in states) - items_to_extend = [out[i], labeled.get(find_by_key) or explore(states)] - state.next.extend(items_to_extend) + state.next.append( + MatchEdge(out[i][0], labeled.get(find_by_key) or explore(states)) + ) return state return explore(null_from(nfa, 0)) -def check_for_dead_ends(match, stream): +def check_for_dead_ends(match: ContentMatch, stream: TokenStream) -> None: work = [match] i = 0 while i < len(work): state = work[i] dead = not state.valid_end nodes = [] - for j in range(0, len(state.next), 2): - node = state.next[j] - next = state.next[j + 1] + for j in range(len(state.next)): + node = state.next[j].type + next = state.next[j].next nodes.append(node.name) if dead and not (node.is_text or node.has_required_attrs()): dead = False diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index 818d75d..5f4f7fe 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -1081,8 +1081,8 @@ def scan(match: ContentMatch) -> bool: i = 0 while i < match.edge_count: result = match.edge(i) - _type = result["type"] - _next = result["next"] + _type = result.type + _next = result.next if _type == node_type: return True diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index 131d03b..4fc5862 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -161,7 +161,7 @@ def clear_incompatible(self, pos, parent_type, match=None): for i in range(node.child_count): child = node.child(i) end = 
cur + child.node_size - allowed = match.match_type(child.type, child.attrs) + allowed = match.match_type(child.type) if not allowed: del_steps.append(ReplaceStep(cur, end, Slice.empty)) else: From 7380bae8e1bd010bfa157928acfedf92e211d1a2 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Thu, 15 Jun 2023 09:19:27 -0400 Subject: [PATCH 04/40] Typing for prosemirror.model.diff --- prosemirror/model/diff.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/prosemirror/model/diff.py b/prosemirror/model/diff.py index a606bdc..db72f33 100644 --- a/prosemirror/model/diff.py +++ b/prosemirror/model/diff.py @@ -1,7 +1,17 @@ +from typing import TYPE_CHECKING, Optional, TypedDict + from prosemirror.utils import text_length +if TYPE_CHECKING: + from prosemirror.model.fragment import Fragment + + +class Diff(TypedDict): + a: int + b: int + -def find_diff_start(a, b, pos): +def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> Optional[int]: i = 0 while True: if a.child_count == i or b.child_count == i: @@ -37,7 +47,9 @@ def find_diff_start(a, b, pos): i += 1 -def find_diff_end(a, b, pos_a, pos_b): +def find_diff_end( + a: "Fragment", b: "Fragment", pos_a: int, pos_b: int +) -> Optional[Diff]: i_a, i_b = a.child_count, b.child_count while True: if i_a == 0 or i_b == 0: From f9a7e30c5ae5278ca7a44769d74b5850adc6eef3 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Thu, 15 Jun 2023 11:23:42 -0400 Subject: [PATCH 05/40] Typing for schema + fixes uncovered --- prosemirror/model/content.py | 2 + prosemirror/model/fragment.py | 2 +- prosemirror/model/from_dom.py | 8 +- prosemirror/model/schema.py | 321 +++++++++++++++++------ prosemirror/schema/basic/schema_basic.py | 8 +- prosemirror/test_builder/__init__.py | 2 +- 6 files changed, 248 insertions(+), 95 deletions(-) diff --git a/prosemirror/model/content.py b/prosemirror/model/content.py index 8977bfd..680e2ff 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -205,6 +205,8 @@ def iteratee(m: "ContentMatch", i: int) -> str: class TokenStream: + inline: Optional[bool] + def __init__(self, string: str, node_types: Dict[str, "NodeType"]) -> None: self.string = string self.node_types = node_types diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index 4de31e7..2d03119 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -236,7 +236,7 @@ def from_array(cls, array): return cls(joined or array, size) @classmethod - def from_(cls, nodes): + def from_(cls, nodes) -> "Fragment": if not nodes: return cls.empty if isinstance(nodes, cls): diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index 5f4f7fe..e225c8b 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -225,7 +225,7 @@ def insert(rule: ParseRule) -> None: return for name in schema.marks: - rules = schema.marks[name].spec["parseDOM"] + rules = schema.marks[name].spec.get("parseDOM") if rules: for rule in rules: @@ -986,15 +986,15 @@ def textblock_from_context(self) -> Optional[NodeType]: if ( default is not None - and default.is_textblock + and default.is_text_block and default.default_attrs ): return default d -= 1 - for name, type_ in self.parser.schema.nodes.iteritems(): - if type_.is_textblock and type_.default_attrs: + for name, type_ in self.parser.schema.nodes.items(): + if type_.is_text_block and type_.default_attrs: return type_ return None diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py 
index c871b73..136b292 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -1,15 +1,34 @@ -from collections import OrderedDict -from typing import Any, Dict, Literal, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + List, + Literal, + Optional, + TypeAlias, + TypeVar, + Union, + cast, +) + +from typing_extensions import NotRequired, TypedDict + +from prosemirror.model.fragment import Fragment +from prosemirror.model.mark import Mark +from prosemirror.model.node import Node, TextNode +from prosemirror.utils import JSON, JSONDict from .content import ContentMatch -from .fragment import Fragment -from .mark import Mark -from .node import Node, TextNode -Attrs = Dict[str, Any] +if TYPE_CHECKING: + pass +Attrs: TypeAlias = JSONDict -def default_attrs(attrs): + +def default_attrs(attrs: "Attributes") -> Optional[Attrs]: defaults = {} for attr_name, attr in attrs.items(): if attr.has_default: @@ -18,14 +37,14 @@ def default_attrs(attrs): return defaults -def compute_attrs(attrs, value): +def compute_attrs(attrs: "Attributes", value: Optional[Attrs]) -> Attrs: built = {} for name in attrs: given = None if value: given = value.get(name) if given is None: - attr = attrs.get(name) + attr = attrs[name] if attr.has_default: given = attr.default else: @@ -34,7 +53,7 @@ def compute_attrs(attrs, value): return built -def init_attrs(attrs): +def init_attrs(attrs: Optional["AttributeSpecs"]) -> "Attributes": result = {} if attrs: for name in attrs: @@ -43,34 +62,53 @@ def init_attrs(attrs): class NodeType: - def __init__(self, name, schema, spec): + """ + Node types are objects allocated once per `Schema` and used to + [tag](#model.Node.type) `Node` instances. They contain information + about the node type, such as its name and what kind of node it + represents. 
+ """ + + name: str + + schema: "Schema[Any, Any]" + + spec: "NodeSpec" + + inline_content: bool + + content_match: "ContentMatch" + + mark_set: Optional[List["MarkType"]] + + def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> None: self.name = name self.schema = schema self.spec = spec - self.groups = spec.get("group").split(" ") if "group" in spec else [] + self.groups = spec["group"].split(" ") if "group" in spec else [] self.attrs = init_attrs(spec.get("attrs")) self.default_attrs = default_attrs(self.attrs) - self.content_match = None + self.content_match = None # type: ignore[assignment] self.mark_set = None - self.inline_content = None + self.inline_content = None # type: ignore[assignment] self.is_block = not (spec.get("inline") or name == "text") self.is_text = name == "text" @property - def is_inline(self): + def is_inline(self) -> bool: return not self.is_block @property - def is_text_block(self): + def is_text_block(self) -> bool: # FIXME: name is wrong, should be is_textblock return self.is_block and self.inline_content @property - def is_leaf(self): + def is_leaf(self) -> bool: return self.content_match == ContentMatch.empty @property - def is_atom(self): - return self.is_leaf or self.spec.get("atom") + def is_atom(self) -> bool: + return self.is_leaf or bool(self.spec.get("atom")) @property def whitespace(self) -> Literal["pre", "normal"]: @@ -78,21 +116,26 @@ def whitespace(self) -> Literal["pre", "normal"]: "pre" if self.spec.get("code") else "normal" ) - def has_required_attrs(self): + def has_required_attrs(self) -> bool: for n in self.attrs: if self.attrs[n].is_required: return True return False - def compatible_content(self, other): + def compatible_content(self, other: "NodeType") -> bool: return self == other or (self.content_match.compatible(other.content_match)) - def compute_attrs(self, attrs): + def compute_attrs(self, attrs: Optional[Attrs]) -> Attrs: if not attrs and self.default_attrs: return self.default_attrs return compute_attrs(self.attrs, attrs) - def create(self, attrs=None, content=None, marks=None): + def create( + self, + attrs: Optional[Attrs] = None, + content: Optional[Union[Fragment, Node, List[Node]]] = None, + marks: Optional[List[Mark]] = None, + ) -> Node: if self.is_text: raise ValueError("NodeType.create cannot construct text nodes") return Node( @@ -102,28 +145,39 @@ def create(self, attrs=None, content=None, marks=None): Mark.set_from(marks), ) - def create_checked(self, attrs=None, content=None, marks=None): + def create_checked( + self, + attrs: Optional[Attrs] = None, + content: Optional[Union[Fragment, Node, List[Node]]] = None, + marks: Optional[List[Mark]] = None, + ) -> Node: content = Fragment.from_(content) if not self.valid_content(content): raise ValueError("Invalid content for node " + self.name) return Node(self, self.compute_attrs(attrs), content, Mark.set_from(marks)) - def create_and_fill(self, attrs=None, content=None, marks=None): + def create_and_fill( + self, + attrs: Optional[Attrs] = None, + content: Optional[Union[Fragment, Node, List[Node]]] = None, + marks: Optional[List[Mark]] = None, + ) -> Optional[Node]: attrs = self.compute_attrs(attrs) - content = Fragment.from_(content) - if content.size: - before = self.content_match.fill_before(content) + frag = Fragment.from_(content) + if frag.size: + before = self.content_match.fill_before(frag) if not before: return None - content = before.append(content) - after = self.content_match.match_fragment(content).fill_before( - Fragment.empty, 
True - ) + frag = before.append(frag) + matched = self.content_match.match_fragment(frag) + if not matched: + return None + after = matched.fill_before(Fragment.empty, True) if not after: return None - return Node(self, attrs, content.append(after), Mark.set_from(marks)) + return Node(self, attrs, frag.append(after), Mark.set_from(marks)) - def valid_content(self, content): + def valid_content(self, content: Fragment) -> bool: result = self.content_match.match_fragment(content) if not result or not result.valid_end: return False @@ -132,15 +186,15 @@ def valid_content(self, content): return False return True - def allows_mark_type(self, mark_type): + def allows_mark_type(self, mark_type: "MarkType") -> bool: return self.mark_set is None or mark_type in self.mark_set - def allows_marks(self, marks): + def allows_marks(self, marks: List[Mark]) -> bool: if self.mark_set is None: return True return all(self.allows_mark_type(mark.type) for mark in marks) - def allowed_marks(self, marks): + def allowed_marks(self, marks: List[Mark]) -> List[Mark]: if self.mark_set is None: return marks copy = None @@ -155,61 +209,76 @@ def allowed_marks(self, marks): elif len(copy): return copy else: - return Mark.empty + return Mark.none @classmethod - def compile(cls, nodes: OrderedDict, schema): - result = {} + def compile( + cls, nodes: Dict["Nodes", "NodeSpec"], schema: "Schema[Nodes, Marks]" + ) -> Dict["Nodes", "NodeType"]: + result: Dict["Nodes", "NodeType"] = {} for name, spec in nodes.items(): result[name] = NodeType(name, schema, spec) - top_node = schema.spec.get("topNode") or "doc" + top_node = cast(Nodes, schema.spec.get("topNode") or "doc") if not result.get(top_node): raise ValueError(f"Schema is missing its top node type {top_node}") - if not result.get("text"): + if not result.get(cast(Nodes, "text")): raise ValueError("every schema needs a 'text' type") - if getattr(result.get("text"), "attrs", {}): + if result[cast(Nodes, "text")].attrs: raise ValueError("the text node type should not have attributes") return result - def __str__(self): + def __str__(self) -> str: return f"" - def __repr__(self): + def __repr__(self) -> str: return self.__str__() +Attributes: TypeAlias = Dict[str, "Attribute"] + + class Attribute: - def __init__(self, options): + def __init__(self, options: "AttributeSpec") -> None: self.has_default = "default" in options self.default = options["default"] if self.has_default else None @property - def is_required(self): + def is_required(self) -> bool: return not self.has_default class MarkType: - def __init__(self, name, rank, schema, spec): + excluded: List["MarkType"] + instance: Optional[Mark] + + def __init__( + self, name: str, rank: int, schema: "Schema", spec: "MarkSpec" + ) -> None: self.name = name self.schema = schema self.spec = spec self.attrs = init_attrs(spec.get("attrs")) self.rank = rank - self.excluded = None + self.excluded = None # type: ignore[assignment] defaults = default_attrs(self.attrs) - self.instance = False + self.instance = None if defaults: self.instance = Mark(self, defaults) - def create(self, attrs=None): + def create( + self, + attrs: Optional[Attrs] = None, + ) -> Mark: if not attrs and self.instance: return self.instance return Mark(self, compute_attrs(self.attrs, attrs)) @classmethod - def compile(cls, marks: OrderedDict, schema): + def compile( + cls, marks: Dict["Marks", "MarkSpec"], schema: "Schema[Nodes, Marks]" + ) -> Dict["Marks", "MarkType"]: result = {} rank = 0 for name, spec in marks.items(): @@ -217,23 +286,96 @@ def 
compile(cls, marks: OrderedDict, schema): rank += 1 return result - def remove_from_set(self, set_): + def remove_from_set(self, set_: List["Mark"]) -> List["Mark"]: return [item for item in set_ if item.type != self] - def is_in_set(self, set): + def is_in_set(self, set: List[Mark]) -> Optional[Mark]: return next((item for item in set if item.type == self), None) - def excludes(self, other): + def excludes(self, other: "MarkType") -> bool: return any(other.name == e.name for e in self.excluded) -class Schema: - def __init__(self, spec): - self.spec = {**spec} - self.spec["nodes"] = OrderedDict(self.spec["nodes"]) - self.spec["marks"] = OrderedDict(self.spec.get("marks", {})) +Nodes = TypeVar("Nodes", bound=str) +Marks = TypeVar("Marks", bound=str) + + +class SchemaSpec(TypedDict, Generic[Nodes, Marks]): + """ + An object describing a schema, as passed to the [`Schema`](#model.Schema) + constructor. + """ + + # The node types in this schema. Maps names to + # [`NodeSpec`](#model.NodeSpec) objects that describe the node type + # associated with that name. Their order is significant—it + # determines which [parse rules](#model.NodeSpec.parseDOM) take + # precedence by default, and which nodes come first in a given + # [group](#model.NodeSpec.group). + nodes: Dict[Nodes, "NodeSpec"] + + # The mark types that exist in this schema. The order in which they + # are provided determines the order in which [mark + # sets](#model.Mark.addToSet) are sorted and in which [parse + # rules](#model.MarkSpec.parseDOM) are tried. + marks: NotRequired[Dict[Marks, "MarkSpec"]] + + # The name of the default top-level node for the schema. Defaults + # to `"doc"`. + topNode: NotRequired[str] + + +class NodeSpec(TypedDict, total=False): + """ + A description of a node type, used when defining a schema. 
+ """ + + content: str + marks: str + group: str + inline: bool + atom: bool + attrs: "AttributeSpecs" + selectable: bool + draggable: bool + code: bool + whitespace: Literal["pre", "normal"] + definingAsContext: bool + definingForContent: bool + defining: bool + isolating: bool + toDOM: Callable[[Node], Any] # FIXME: add types + parseDOM: List[Dict[str, Any]] # FIXME: add types + + +AttributeSpecs: TypeAlias = Dict[str, "AttributeSpec"] + + +class MarkSpec(TypedDict, total=False): + attrs: AttributeSpecs + inclusive: bool + excludes: str + group: str + spanning: bool + toDOM: Callable[[Mark, bool], Any] # FIXME: add types + parseDOM: List[Dict[str, Any]] # FIXME: add types + + +class AttributeSpec(TypedDict, total=False): + default: JSON + + +class Schema(Generic[Nodes, Marks]): + spec: SchemaSpec + + nodes: Dict[Nodes, "NodeType"] + + marks: Dict[Marks, "MarkType"] + + def __init__(self, spec: SchemaSpec) -> None: + self.spec = spec self.nodes = NodeType.compile(self.spec["nodes"], self) - self.marks = MarkType.compile(self.spec["marks"], self) + self.marks = MarkType.compile(self.spec.get("marks", {}), self) content_expr_cache = {} for prop in self.nodes: if prop in self.marks: @@ -243,7 +385,7 @@ def __init__(self, spec): mark_expr = type.spec.get("marks") if content_expr not in content_expr_cache: content_expr_cache[content_expr] = ContentMatch.parse( - content_expr, self.nodes + content_expr, cast(Dict[str, "NodeType"], self.nodes) ) type.content_match = content_expr_cache[content_expr] @@ -256,25 +398,25 @@ def __init__(self, spec): type.mark_set = [] else: type.mark_set = None - # type.mark_set = None if mark_expr == "_" else { - # gather_marks(self, mark_expr.split(" ")) if mark_expr else ( - # [] if (mark_expr == "" or not type.inline_content) else None - # ) - # } - for prop in self.marks: - type = self.marks.get(prop) - excl = type.spec.get("excludes") - type.excluded = ( - [type] + for mark in self.marks.values(): + excl = mark.spec.get("excludes") + mark.excluded = ( + [mark] if excl is None else ([] if excl == "" else (gather_marks(self, excl.split(" ")))) ) - self.top_node_type = self.nodes.get((self.spec.get("topNode") or "doc")) - self.cached = {} + self.top_node_type = self.nodes[cast(Nodes, self.spec.get("topNode") or "doc")] + self.cached: Dict[str, Any] = {} self.cached["wrappings"] = {} - def node(self, type, attrs=None, content=None, marks=None): + def node( + self, + type: Union[str, NodeType], + attrs: Optional[Attrs] = None, + content: Optional[Union[Fragment, Node]] = None, + marks: Optional[List[Mark]] = None, + ) -> Node: if isinstance(type, str): type = self.node_type(type) elif not isinstance(type, NodeType): @@ -283,29 +425,37 @@ def node(self, type, attrs=None, content=None, marks=None): raise ValueError(f"Node type from different schema used ({type.name})") return type.create_checked(attrs, content, marks) - def text(self, text, marks=None): - type = self.nodes.get("text") + def text(self, text: str, marks: Optional[List[Mark]] = None) -> TextNode: + type = self.nodes[cast(Nodes, "text")] return TextNode(type, type.default_attrs, text, Mark.set_from(marks)) - def mark(self, type, attrs=None): + def mark( + self, + type: Union[str, MarkType], + attrs: Optional[ + Union[Dict[str, Optional[str]], Dict[str, str], Dict[str, int]] + ] = None, + ) -> Mark: if isinstance(type, str): - type = self.marks[type] + type = self.marks[cast(Marks, type)] return type.create(attrs) - def node_from_json(self, json_data): + def node_from_json(self, json_data: JSON) -> 
Union[Node, TextNode]: return Node.from_json(self, json_data) - def mark_from_json(self, json_data): + def mark_from_json( + self, json_data: Dict[str, Union[str, Dict[str, Optional[str]], Dict[str, int]]] + ) -> Mark: return Mark.from_json(self, json_data) - def node_type(self, name): - found = self.nodes.get(name) + def node_type(self, name: str) -> NodeType: + found = self.nodes.get(cast(Nodes, name)) if not found: raise ValueError(f"Unknown node type: {name}") return found -def gather_marks(schema, marks): +def gather_marks(schema: Schema, marks: List[str]) -> List[MarkType]: found = [] for name in marks: mark = schema.marks.get(name) @@ -313,8 +463,7 @@ def gather_marks(schema, marks): if mark: found.append(mark) else: - for prop in schema.marks: - mark = schema.marks.get(prop) + for mark in schema.marks.values(): if name == "_" or ( mark.spec.get("group") and name in mark.spec["group"].split(" ") ): diff --git a/prosemirror/schema/basic/schema_basic.py b/prosemirror/schema/basic/schema_basic.py index 76b7ff3..46cd993 100644 --- a/prosemirror/schema/basic/schema_basic.py +++ b/prosemirror/schema/basic/schema_basic.py @@ -1,4 +1,6 @@ +from typing import Dict from prosemirror.model import Schema +from prosemirror.model.schema import NodeSpec, MarkSpec p_dom = ["p", 0] blockquote_dom = ["blockquote", 0] @@ -6,7 +8,7 @@ pre_dom = ["pre", ["code", 0]] br_dom = ["br"] -nodes = { +nodes: Dict[str, NodeSpec] = { "doc": {"content": "block+"}, "paragraph": { "content": "inline*", @@ -87,7 +89,7 @@ strong_dom = ["strong", 0] code_dom = ["code", 0] -marks = { +marks: Dict[str, MarkSpec] = { "link": { "attrs": {"href": {}, "title": {"default": None}}, "inclusive": False, @@ -110,4 +112,4 @@ } -schema = Schema({"nodes": nodes, "marks": marks}) +schema: Schema[str, str] = Schema({"nodes": nodes, "marks": marks}) diff --git a/prosemirror/test_builder/__init__.py b/prosemirror/test_builder/__init__.py index 1771a33..e500353 100644 --- a/prosemirror/test_builder/__init__.py +++ b/prosemirror/test_builder/__init__.py @@ -4,7 +4,7 @@ from .build import builders -test_schema = Schema( +test_schema: Schema[str, str] = Schema( { "nodes": add_list_nodes(_schema.spec["nodes"], "paragraph block*", "block"), "marks": _schema.spec["marks"], From a1160cf08669725435943de3ae745fce9fd08ffb Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Thu, 15 Jun 2023 13:15:30 -0400 Subject: [PATCH 06/40] Typing for prosemirror.model.fragment --- prosemirror/model/content.py | 5 +- prosemirror/model/diff.py | 70 ++++++------ prosemirror/model/fragment.py | 134 +++++++++++++++-------- prosemirror/model/from_dom.py | 4 +- prosemirror/model/node.py | 11 ++ prosemirror/model/resolvedpos.py | 90 ++++++++------- prosemirror/schema/basic/schema_basic.py | 3 +- prosemirror/transform/replace.py | 27 +++-- 8 files changed, 216 insertions(+), 128 deletions(-) diff --git a/prosemirror/model/content.py b/prosemirror/model/content.py index 680e2ff..01792d1 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -18,6 +18,7 @@ from .fragment import Fragment if TYPE_CHECKING: + from .node import Node from .schema import NodeType @@ -111,7 +112,9 @@ def search(match: ContentMatch, types: List["NodeType"]) -> Optional[Fragment]: nonlocal seen finished = match.match_fragment(after, start_index) if finished and (not to_end or finished.valid_end): - return Fragment.from_([tp.create_and_fill() for tp in types]) + return Fragment.from_( + [cast("Node", tp.create_and_fill()) for tp in types] + ) for i in match.next: 
type = i.type next = i.next diff --git a/prosemirror/model/diff.py b/prosemirror/model/diff.py index db72f33..e4f7c1a 100644 --- a/prosemirror/model/diff.py +++ b/prosemirror/model/diff.py @@ -2,6 +2,8 @@ from prosemirror.utils import text_length +from . import node as pm_node + if TYPE_CHECKING: from prosemirror.model.fragment import Fragment @@ -22,23 +24,26 @@ def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> Optional[int]: continue if not child_a.same_markup(child_b): return pos - if child_a.is_text and child_a.text != child_b.text: - if child_b.text.startswith(child_a.text): - return pos + text_length(child_a.text) - if child_a.text.startswith(child_b.text): - return pos + text_length(child_b.text) - next_index = next( - ( - index_a - for ((index_a, char_a), (_, char_b)) in zip( - enumerate(child_a.text), enumerate(child_b.text) - ) - if char_a != char_b - ), - None, - ) - if next_index is not None: - return pos + next_index + if child_a.is_text: + assert isinstance(child_a, pm_node.TextNode) + assert isinstance(child_b, pm_node.TextNode) + if child_a.text != child_b.text: + if child_b.text.startswith(child_a.text): + return pos + text_length(child_a.text) + if child_a.text.startswith(child_b.text): + return pos + text_length(child_b.text) + next_index = next( + ( + index_a + for ((index_a, char_a), (_, char_b)) in zip( + enumerate(child_a.text), enumerate(child_b.text) + ) + if char_a != char_b + ), + None, + ) + if next_index is not None: + return pos + next_index if child_a.content.size or child_b.content.size: inner = find_diff_start(child_a.content, child_b.content, pos + 1) if inner: @@ -69,20 +74,23 @@ def find_diff_end( if not child_a.same_markup(child_b): return {"a": pos_a, "b": pos_b} - if child_a.is_text and child_a.text != child_b.text: - same, min_size = ( - 0, - min(text_length(child_a.text), text_length(child_b.text)), - ) - while ( - same < min_size - and child_a.text[text_length(child_a.text) - same - 1] - == child_b.text[text_length(child_b.text) - same - 1] - ): - same += 1 - pos_a -= 1 - pos_b -= 1 - return {"a": pos_a, "b": pos_b} + if child_a.is_text: + assert isinstance(child_a, pm_node.TextNode) + assert isinstance(child_b, pm_node.TextNode) + if child_a.text != child_b.text: + same, min_size = ( + 0, + min(text_length(child_a.text), text_length(child_b.text)), + ) + while ( + same < min_size + and child_a.text[text_length(child_a.text) - same - 1] + == child_b.text[text_length(child_b.text) - same - 1] + ): + same += 1 + pos_a -= 1 + pos_b -= 1 + return {"a": pos_a, "b": pos_b} if child_a.content.size or child_b.content.size: inner = find_diff_end( diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index 2d03119..e332ede 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -1,27 +1,46 @@ -from typing import TYPE_CHECKING, ClassVar, Iterable, cast - -from prosemirror.utils import text_length - -from .diff import find_diff_end, find_diff_start +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Iterable, + List, + Optional, + Union, + cast, +) + +from prosemirror.utils import JSON, text_length if TYPE_CHECKING: + from prosemirror.model.schema import Schema + + from .diff import Diff from .node import Node, TextNode -def retIndex(index, offset): +def retIndex(index: int, offset: int) -> Dict[str, int]: return {"index": index, "offset": offset} class Fragment: empty: ClassVar["Fragment"] + content: List["Node"] + size: int - def __init__(self, content, size=None): + def 
__init__(self, content: List["Node"], size: Optional[int] = None) -> None: self.content = content - self.size = size - if size is None: - self.size = sum(c.node_size for c in content) - - def nodes_between(self, from_, to, f, node_start=0, parent=None): + self.size = size if size is not None else sum(c.node_size for c in content) + + def nodes_between( + self, + from_: int, + to: int, + f: Callable[["Node", int, Optional["Node"], int], Optional[bool]], + node_start: int = 0, + parent: Optional["Node"] = None, + ) -> None: i = 0 pos = 0 while pos < to: @@ -42,14 +61,24 @@ def nodes_between(self, from_, to, f, node_start=0, parent=None): pos = end i += 1 - def descendants(self, f): + def descendants( + self, f: Callable[["Node", int, Optional["Node"], int], Optional[bool]] + ) -> None: self.nodes_between(0, self.size, f) - def text_between(self, from_, to, block_separator="", leaf_text=""): + def text_between( + self, + from_: int, + to: int, + block_separator: str = "", + leaf_text: Union[Callable, str] = "", + ) -> str: text = [] separated = True - def iteratee(node: "Node", pos, *args): + def iteratee( + node: "Node", pos: int, _parent: Optional["Node"], _to: int + ) -> None: nonlocal text nonlocal separated if node.is_text: @@ -69,7 +98,7 @@ def iteratee(node: "Node", pos, *args): self.nodes_between(from_, to, iteratee, 0) return "".join(text) - def append(self, other): + def append(self, other: "Fragment") -> "Fragment": if not other.size: return self if not self.size: @@ -80,7 +109,9 @@ def append(self, other): self.content.copy(), 0, ) - if last.is_text and last.same_markup(first): + assert last is not None and first is not None + if pm_node.is_text(last) and last.same_markup(first): + assert isinstance(first, pm_node.TextNode) content[len(content) - 1] = last.with_text(last.text + first.text) i = 1 while i < len(other.content): @@ -88,12 +119,13 @@ def append(self, other): i += 1 return Fragment(content, self.size + other.size) - def cut(self, from_, to=None): + def cut(self, from_: int, to: Optional[int] = None) -> "Fragment": if to is None: to = self.size if from_ == 0 and to == self.size: return self - result, size = [], 0 + result: List["Node"] = [] + size = 0 if to <= from_: return Fragment(result, size) i, pos = 0, 0 @@ -102,7 +134,7 @@ def cut(self, from_, to=None): end = pos + child.node_size if end > from_: if pos < from_ or end > to: - if child.is_text: + if pm_node.is_text(child): child = child.cut( max(0, from_ - pos), min(text_length(child.text), to - pos) ) @@ -117,14 +149,14 @@ def cut(self, from_, to=None): i += 1 return Fragment(result, size) - def cut_by_index(self, from_, to=None): + def cut_by_index(self, from_: int, to: Optional[int] = None) -> "Fragment": if from_ == to: return Fragment.empty if from_ == 0 and to == len(self.content): return self return Fragment(self.content[from_:to]) - def replace_child(self, index, node): + def replace_child(self, index: int, node: "Node") -> "Fragment": current = self.content[index] if current == node: return self @@ -133,39 +165,39 @@ def replace_child(self, index, node): copy[index] = node return Fragment(copy, size) - def add_to_start(self, node): + def add_to_start(self, node: "Node") -> "Fragment": return Fragment([node, *self.content], self.size + node.node_size) - def add_to_end(self, node): + def add_to_end(self, node: "Node") -> "Fragment": return Fragment([*self.content, node], self.size + node.node_size) - def eq(self, other): + def eq(self, other: "Fragment") -> bool: if len(self.content) != len(other.content): 
return False return all(a.eq(b) for (a, b) in zip(self.content, other.content)) @property - def first_child(self): + def first_child(self) -> Optional["Node"]: return self.content[0] if self.content else None @property - def last_child(self): + def last_child(self) -> Optional["Node"]: return self.content[-1] if self.content else None @property - def child_count(self): + def child_count(self) -> int: return len(self.content) - def child(self, index): + def child(self, index: int) -> "Node": return self.content[index] - def maybe_child(self, index): + def maybe_child(self, index: int) -> Optional["Node"]: try: return self.content[index] except IndexError: return None - def for_each(self, f): + def for_each(self, f: Callable) -> None: i = 0 p = 0 while i < len(self.content): @@ -174,17 +206,26 @@ def for_each(self, f): p += child.node_size i += 1 - def find_diff_start(self, other, pos=0): + def find_diff_start(self, other: "Fragment", pos: int = 0) -> Optional[int]: + from .diff import find_diff_start + return find_diff_start(self, other, pos) - def find_diff_end(self, other, pos=None, other_pos=None): + def find_diff_end( + self, + other: "Fragment", + pos: Optional[int] = None, + other_pos: Optional[int] = None, + ) -> Optional["Diff"]: + from .diff import find_diff_end + if pos is None: pos = self.size if other_pos is None: other_pos = other.size return find_diff_end(self, other, pos, other_pos) - def find_index(self, pos, round=-1): + def find_index(self, pos: int, round: int = -1) -> Dict[str, int]: if pos == 0: return retIndex(0, pos) if pos == self.size: @@ -203,12 +244,13 @@ def find_index(self, pos, round=-1): i += 1 cur_pos = end - def to_json(self): + def to_json(self) -> JSON: if self.content: return [item.to_json() for item in self.content] + return None @classmethod - def from_json(cls, schema, value): + def from_json(cls, schema: "Schema", value: Any) -> "Fragment": if not value: return cls.empty if isinstance(value, str): @@ -220,26 +262,28 @@ def from_json(cls, schema, value): return cls([schema.node_from_json(item) for item in value]) @classmethod - def from_array(cls, array): + def from_array(cls, array: List["Node"]) -> "Fragment": if not array: return cls.empty joined, size = None, 0 for i in range(len(array)): node = array[i] size += node.node_size - if i and node.is_text and array[i - 1].same_markup(node): + if i and pm_node.is_text(node) and array[i - 1].same_markup(node): if not joined: joined = array[0:i] - joined[-1] = node.with_text(joined[-1].text + node.text) + last = joined[-1] + assert isinstance(last, pm_node.TextNode) + joined[-1] = node.with_text(last.text + node.text) elif joined: joined.append(node) return cls(joined or array, size) @classmethod - def from_(cls, nodes) -> "Fragment": + def from_(cls, nodes: Union["Fragment", "Node", List["Node"], None]) -> "Fragment": if not nodes: return cls.empty - if isinstance(nodes, cls): + if isinstance(nodes, Fragment): return nodes if isinstance(nodes, Iterable): return cls.from_array(list(nodes)) @@ -247,14 +291,16 @@ def from_(cls, nodes) -> "Fragment": return cls([nodes], nodes.node_size) raise ValueError(f"cannot convert {nodes!r} to a fragment") - def to_string_inner(self): + def to_string_inner(self) -> str: return ", ".join([str(i) for i in self.content]) - def __str__(self): + def __str__(self) -> str: return f"<{self.to_string_inner()}>" - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.__str__()}>" Fragment.empty = Fragment([], 0) + +from . 
import node as pm_node diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index e225c8b..969ac7b 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -418,7 +418,9 @@ def finish(self, open_end: bool) -> Union[Node, Fragment]: content = Fragment.from_(self.content) if not open_end and self.match is not None: - content = content.append(self.match.fill_before(Fragment.empty, True)) + content = content.append( + cast(Fragment, self.match.fill_before(Fragment.empty, True)) + ) return ( self.type.create(self.attrs, content, self.marks) if self.type else content diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index 1cae455..cf09c8f 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,3 +1,5 @@ +from typing_extensions import TypeGuard + from prosemirror.utils import text_length from .comparedeep import compare_deep @@ -345,3 +347,12 @@ def wrap_marks(marks, str): str = marks[i].type.name + "(" + str + ")" i -= 1 return str + + +def is_text(node: Node) -> TypeGuard[TextNode]: + """ + Helper function to check if a node is a text node, but with + type narrowing. (TypeGuard cannot narrow the type of `self`; see + https://mypy.readthedocs.io/en/stable/type_narrowing.html#typeguards-as-methods) + """ + return node.is_text diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index af823bd..53c53da 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,93 +1,104 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Union, cast + from .mark import Mark +if TYPE_CHECKING: + from prosemirror.model.mark import Mark + from prosemirror.model.node import Node, TextNode + class ResolvedPos: - def __init__(self, pos, path, parent_offset): + def __init__( + self, pos: int, path: List[Union["Node", int]], parent_offset: int + ) -> None: self.pos = pos self.path = path self.depth = int(len(path) / 3 - 1) self.parent_offset = parent_offset - def resolve_depth(self, val=None): + def resolve_depth(self, val: Optional[int] = None) -> int: if val is None: return self.depth return self.depth + val if val < 0 else val @property - def parent(self): + def parent(self) -> "Node": return self.node(self.depth) @property - def doc(self): + def doc(self) -> "Node": return self.node(0) - def node(self, depth): - return self.path[self.resolve_depth(depth) * 3] + def node(self, depth: int) -> "Node": + return cast("Node", self.path[self.resolve_depth(depth) * 3]) - def index(self, depth=None): - return self.path[self.resolve_depth(depth) * 3 + 1] + def index(self, depth: Optional[int] = None) -> int: + return cast(int, self.path[self.resolve_depth(depth) * 3 + 1]) - def index_after(self, depth): + def index_after(self, depth: int) -> int: depth = self.resolve_depth(depth) return self.index(depth) + ( 0 if depth == self.depth and not self.text_offset else 1 ) - def start(self, depth=None): + def start(self, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) - return 0 if depth == 0 else self.path[depth * 3 - 1] + 1 + return 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 - def end(self, depth=None): + def end(self, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) return self.start(depth) + self.node(depth).content.size - def before(self, depth=None): + def before(self, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) if not depth: raise ValueError("There is no position before the top level node") 
- return self.pos if depth == self.depth + 1 else self.path[depth * 3 - 1] + return ( + self.pos if depth == self.depth + 1 else cast(int, self.path[depth * 3 - 1]) + ) - def after(self, depth=None): + def after(self, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) if not depth: raise ValueError("There is no position after the top level node") return ( self.pos if depth == self.depth + 1 - else self.path[depth * 3 - 1] + self.path[depth * 3].node_size + else cast(int, self.path[depth * 3 - 1]) + + cast("Node", self.path[depth * 3]).node_size ) @property - def text_offset(self): - return self.pos - self.path[-1] + def text_offset(self) -> int: + return self.pos - cast(int, self.path[-1]) @property - def node_after(self): + def node_after(self) -> Optional["Node"]: parent = self.parent index = self.index(self.depth) if index == parent.child_count: return None - d_off = self.pos - self.path[-1] + d_off = self.pos - cast(int, self.path[-1]) child = parent.child(index) return parent.child(index).cut(d_off) if d_off else child @property - def node_before(self): + def node_before(self) -> Optional["Node"]: index = self.index(self.depth) - d_off = self.pos - self.path[-1] + d_off = self.pos - cast(int, self.path[-1]) if d_off: return self.parent.child(index).cut(0, d_off) return None if index == 0 else self.parent.child(index - 1) - def pos_at_index(self, index, depth=None): + def pos_at_index(self, index: int, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) - node = self.path[depth * 3] - pos = 0 if depth == 0 else self.path[depth * 3 - 1] + 1 + node = cast("Node", self.path[depth * 3]) + pos = 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 for i in range(index): pos += node.child(i).node_size return pos - def marks(self): + def marks(self) -> List["Mark"]: parent = self.parent index = self.index() if parent.content.size == 0: @@ -125,7 +136,7 @@ def marks_across(self, end): i += 1 return marks - def shared_depth(self, pos): + def shared_depth(self, pos: int) -> int: depth = self.depth while depth > 0: if self.start(depth) <= pos and self.end(depth) >= pos: @@ -133,7 +144,9 @@ def shared_depth(self, pos): depth -= 1 return 0 - def block_range(self, other=None, pred=None): + def block_range( + self, other: Optional["ResolvedPos"] = None, pred: None = None + ) -> Optional["NodeRange"]: if other is None: other = self if other.pos < self.pos: @@ -145,6 +158,7 @@ def block_range(self, other=None, pred=None): if other.pos <= self.end(d) and (not pred or pred(self.node(d))): return NodeRange(self, other, d) d -= 1 + return None def same_parent(self, other): return self.pos - self.parent_offset == other.pos - other.parent_offset @@ -155,7 +169,7 @@ def max(self, other): def min(self, other): return other if other.pos < self.pos else self - def __str__(self): + def __str__(self) -> str: path = "/".join( [ f"{self.node(i).type.name}_{self.index(i - 1)}" @@ -165,10 +179,10 @@ def __str__(self): return f"{path}:{self.parent_offset}" @classmethod - def resolve(cls, doc, pos): + def resolve(cls, doc: "Node", pos: int) -> "ResolvedPos": if not (pos >= 0 and pos <= doc.content.size): raise ValueError(f"Position {pos} out of range") - path = [] + path: List[Union["Node", int]] = [] start = 0 parent_offset = pos node = doc @@ -187,33 +201,33 @@ def resolve(cls, doc, pos): return cls(pos, path, parent_offset) @classmethod - def resolve_cached(cls, doc, pos): + def resolve_cached(cls, doc: "Node", pos: int) -> "ResolvedPos": # no cache for now return 
cls.resolve(doc, pos) class NodeRange: - def __init__(self, from_, to, depth): + def __init__(self, from_: ResolvedPos, to: ResolvedPos, depth: int) -> None: self.from_ = from_ self.to = to self.depth = depth @property - def start(self): + def start(self) -> int: return self.from_.before(self.depth + 1) @property - def end(self): + def end(self) -> int: return self.to.after(self.depth + 1) @property - def parent(self): + def parent(self) -> "Node": return self.from_.node(self.depth) @property - def start_index(self): + def start_index(self) -> int: return self.from_.index(self.depth) @property - def end_index(self): + def end_index(self) -> int: return self.to.index_after(self.depth) diff --git a/prosemirror/schema/basic/schema_basic.py b/prosemirror/schema/basic/schema_basic.py index 46cd993..f98cc5e 100644 --- a/prosemirror/schema/basic/schema_basic.py +++ b/prosemirror/schema/basic/schema_basic.py @@ -1,6 +1,7 @@ from typing import Dict + from prosemirror.model import Schema -from prosemirror.model.schema import NodeSpec, MarkSpec +from prosemirror.model.schema import MarkSpec, NodeSpec p_dom = ["p", 0] blockquote_dom = ["blockquote", 0] diff --git a/prosemirror/transform/replace.py b/prosemirror/transform/replace.py index 88622c4..c4fa503 100644 --- a/prosemirror/transform/replace.py +++ b/prosemirror/transform/replace.py @@ -1,6 +1,6 @@ -from typing import List, Optional +from typing import List, Optional, cast -from prosemirror.model import Fragment, ResolvedPos, Slice +from prosemirror.model import Fragment, Node, ResolvedPos, Slice from .replace_step import ReplaceAroundStep, ReplaceStep, Step @@ -98,7 +98,9 @@ def fit(self) -> Optional[Step]: open_start = from__.depth open_end = to_.depth while open_start and open_end and content.child_count == 1: - content = content.first_child.content + first_child = content.first_child + assert first_child + content = first_child.content open_start -= 1 open_end -= 1 @@ -137,6 +139,7 @@ def find_fittable(self) -> Optional[_Fittable]: parent = content_at( self.unplaced.content, slice_depth - 1 ).first_child + assert parent fragment = parent.content else: parent = None @@ -165,7 +168,7 @@ def _lazy_wrap(): if pass_ == 1 and ( (match.match_type(first.type) or _lazy_inject()) if first - else type_.compatible_content(parent.type) + else parent and type_.compatible_content(parent.type) ): return _Fittable( slice_depth, @@ -189,7 +192,7 @@ def open_more(self) -> bool: open_start = self.unplaced.open_start open_end = self.unplaced.open_end inner = content_at(content, open_start) - if not inner.child_count or inner.first_child.is_leaf: + if not inner.child_count or cast("Node", inner.first_child).is_leaf: return False self.unplaced = Slice( content, @@ -404,28 +407,28 @@ def close_frontier_node(self): def drop_from_fragment(fragment: Fragment, depth: int, count: int) -> Fragment: if depth == 0: return fragment.cut_by_index(count) + first_child = fragment.first_child + assert first_child return fragment.replace_child( 0, - fragment.first_child.copy( - drop_from_fragment(fragment.first_child.content, depth - 1, count) - ), + first_child.copy(drop_from_fragment(first_child.content, depth - 1, count)), ) def add_to_fragment(fragment: Fragment, depth: int, content: Fragment) -> Fragment: if depth == 0: return fragment.append(content) + last_child = fragment.last_child + assert last_child return fragment.replace_child( fragment.child_count - 1, - fragment.last_child.copy( - add_to_fragment(fragment.last_child.content, depth - 1, content) - ), + 
last_child.copy(add_to_fragment(last_child.content, depth - 1, content)), ) def content_at(fragment: Fragment, depth: int) -> Fragment: for _ in range(depth): - fragment = fragment.first_child.content + fragment = cast(Node, fragment.first_child).content return fragment From 6426a8059cc80fa8ba6b9a6a8b4bd64f3e3455ce Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Thu, 15 Jun 2023 20:40:07 -0400 Subject: [PATCH 07/40] Add typings for node --- prosemirror/model/node.py | 149 +++++++++++++++++++++++------------- prosemirror/model/schema.py | 2 + 2 files changed, 99 insertions(+), 52 deletions(-) diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index cf09c8f..f82bb69 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,6 +1,14 @@ +from typing import Any, Callable, Dict, List, Optional, Union + from typing_extensions import TypeGuard -from prosemirror.utils import text_length +from prosemirror.model.content import ContentMatch +from prosemirror.model.fragment import Fragment +from prosemirror.model.mark import Mark +from prosemirror.model.replace import Slice +from prosemirror.model.resolvedpos import ResolvedPos +from prosemirror.model.schema import NodeType, Schema +from prosemirror.utils import JSON, JSONDict, text_length from .comparedeep import compare_deep from .fragment import Fragment @@ -12,83 +20,104 @@ class Node: - def __init__(self, type, attrs, content: Fragment, marks): + def __init__( + self, + type: NodeType, + attrs: JSONDict, + content: Optional[Fragment], + marks: List[Mark], + ) -> None: self.type = type self.attrs = attrs self.content = content or Fragment.empty self.marks = marks or Mark.none @property - def node_size(self): + def node_size(self) -> int: return 1 if self.is_leaf else 2 + self.content.size @property - def child_count(self): + def child_count(self) -> int: return self.content.child_count - def child(self, index): + def child(self, index: int) -> "Node": return self.content.child(index) - def maybe_child(self, index): + def maybe_child(self, index: int) -> Optional["Node"]: return self.content.maybe_child(index) def for_each(self, f): self.content.for_each(f) - def nodes_between(self, from_, to, f, start_pos=0): + def nodes_between( + self, from_: int, to: int, f: Callable, start_pos: int = 0 + ) -> None: self.content.nodes_between(from_, to, f, start_pos, self) def descendants(self, f): self.nodes_between(0, self.content.size, f) @property - def text_content(self): + def text_content(self) -> str: if self.is_leaf and self.type.spec.get("leafText") is not None: return self.type.spec["leafText"](self) return self.text_between(0, self.content.size, "") - def text_between(self, from_, to, block_separator="", leaf_text=""): + def text_between( + self, + from_: int, + to: int, + block_separator: str = "", + leaf_text: Union[Callable, str] = "", + ) -> str: return self.content.text_between(from_, to, block_separator, leaf_text) @property - def first_child(self): + def first_child(self) -> Optional["Node"]: return self.content.first_child @property - def last_child(self): + def last_child(self) -> Optional["Node"]: return self.content.last_child - def eq(self, other: "Node"): + def eq(self, other: "Node") -> bool: return self == other or ( self.same_markup(other) and self.content.eq(other.content) ) - def same_markup(self, other: "Node"): + def same_markup(self, other: "Node") -> bool: return self.has_markup(other.type, other.attrs, other.marks) - def has_markup(self, type, attrs, marks=None): + def has_markup( + 
self, + type: NodeType, + attrs: Optional[JSONDict] = None, + marks: Optional[List[Mark]] = None, + ) -> bool: return ( self.type.name == type.name and (compare_deep(self.attrs, attrs or type.default_attrs or empty_attrs)) and (Mark.same_set(self.marks, marks or Mark.none)) ) - def copy(self, content=None): + def copy(self, content: Optional[Fragment] = None) -> "Node": if content == self.content: return self return self.__class__(self.type, self.attrs, content, self.marks) - def mark(self, marks): + def mark(self, marks: List[Mark]) -> "Node": if marks == self.marks: return self return self.__class__(self.type, self.attrs, self.content, marks) - def cut(self, from_, to): + def cut(self, from_: int, to: Optional[int]) -> "Node": if from_ == 0 and to == self.content.size: return self return self.copy(self.content.cut(from_, to)) - def slice(self, from_, to=None, include_parents=False): + def slice( + self, from_: int, to: Optional[int] = None, include_parents: bool = False + ) -> Slice: if to is None: to = self.content.size if from_ == to: @@ -101,17 +130,18 @@ def slice(self, from_, to=None, include_parents=False): content = node.content.cut(from__.pos - start, to_.pos - start) return Slice(content, from__.depth - depth, to_.depth - depth) - def replace(self, from_, to, slice): + def replace(self, from_: int, to: int, slice: Slice) -> "Node": return replace(self.resolve(from_), self.resolve(to), slice) - def node_at(self, pos): + def node_at(self, pos: int) -> Optional["Node"]: node = self while True: index_info = node.content.find_index(pos) index, offset = index_info["index"], index_info["offset"] - node = node.maybe_child(index) - if not node: + next_node = node.maybe_child(index) + if not next_node: return None + node = next_node if offset == pos or node.is_text: return node pos -= offset + 1 @@ -135,10 +165,10 @@ def child_before(self, pos): node = self.content.child(index - 1) return {"node": node, "index": index - 1, "offset": offset - node.node_size} - def resolve(self, pos): + def resolve(self, pos: int) -> ResolvedPos: return ResolvedPos.resolve_cached(self, pos) - def resolve_no_cache(self, pos): + def resolve_no_cache(self, pos: int) -> ResolvedPos: return ResolvedPos.resolve(self, pos) def range_has_mark(self, from_, to, type): @@ -159,30 +189,30 @@ def is_block(self): return self.type.is_block @property - def is_text_block(self): + def is_text_block(self) -> bool: return self.type.is_text_block @property - def inline_content(self): + def inline_content(self) -> bool: return self.type.inline_content @property - def is_inline(self): + def is_inline(self) -> bool: return self.type.is_inline @property - def is_text(self): + def is_text(self) -> bool: return self.type.is_text @property - def is_leaf(self): + def is_leaf(self) -> bool: return self.type.is_leaf @property - def is_atom(self): + def is_atom(self) -> bool: return self.type.is_atom - def __str__(self): + def __str__(self) -> str: to_debug_string = ( self.type.spec["toDebugString"] if "toDebugString" in self.type.spec @@ -198,17 +228,24 @@ def __str__(self): def __repr__(self): return f"<{self.__class__.__name__} {self.__str__()}>" - def content_match_at(self, index): + def content_match_at(self, index: int) -> ContentMatch: match = self.type.content_match.match_fragment(self.content, 0, index) if not match: raise ValueError("Called contentMatchAt on a node with invalid content") return match - def can_replace(self, from_, to, replacement=Fragment.empty, start=0, end=None): + def can_replace( + self, + from_: int, + to: 
int, + replacement: Fragment = Fragment.empty, + start: int = 0, + end: Optional[int] = None, + ) -> bool: if end is None: end = replacement.child_count one = self.content_match_at(from_).match_fragment(replacement, start, end) - two = False + two: Optional[ContentMatch] = None if one: two = one.match_fragment(self.content, to) if not two or not two.valid_end: @@ -218,11 +255,13 @@ def can_replace(self, from_, to, replacement=Fragment.empty, start=0, end=None): return False return True - def can_replace_with(self, from_, to, type, marks=None): + def can_replace_with( + self, from_: int, to: int, type: NodeType, marks: None = None + ) -> bool: if marks and not self.type.allows_marks(marks): return False start = self.content_match_at(from_).match_type(type) - end = False + end: Optional[ContentMatch] = None if start: end = start.match_fragment(self.content, to) return end.valid_end if end else False @@ -248,8 +287,8 @@ def check(self): ) return self.content.for_each(lambda node, *args: node.check()) - def to_json(self): - obj = {"type": self.type.name} + def to_json(self) -> JSONDict: + obj: Dict[str, JSON] = {"type": self.type.name} for _ in self.attrs: obj["attrs"] = self.attrs break @@ -260,7 +299,7 @@ def to_json(self): return obj @classmethod - def from_json(cls, schema, json_data): + def from_json(cls, schema: Schema, json_data: Any) -> "Node": if isinstance(json_data, str): import json @@ -281,13 +320,19 @@ def from_json(cls, schema, json_data): class TextNode(Node): - def __init__(self, type, attrs, content, marks): + def __init__( + self, + type: NodeType, + attrs: JSONDict, + content: str, + marks: List[Mark], + ) -> None: super().__init__(type, attrs, None, marks) if not content: raise ValueError("Empty text nodes are not allowed") self.text = content - def __str__(self): + def __str__(self) -> str: import json to_debug_string = ( @@ -300,29 +345,29 @@ def __str__(self): return wrap_marks(self.marks, json.dumps(self.text)) @property - def text_content(self): + def text_content(self) -> str: return self.text def text_between(self, from_, to): return self.text[from_:to] @property - def node_size(self): + def node_size(self) -> int: return text_length(self.text) - def mark(self, marks): + def mark(self, marks: List[Mark]) -> "TextNode": return ( self if marks == self.marks else TextNode(self.type, self.attrs, self.text, marks) ) - def with_text(self, text): + def with_text(self, text: str) -> "TextNode": if text == self.text: return self return TextNode(self.type, self.attrs, text, self.marks) - def cut(self, from_=0, to=None): + def cut(self, from_: int = 0, to: Optional[int] = None) -> "TextNode": if to is None: to = text_length(self.text) if from_ == 0 and to == text_length(self.text): @@ -332,16 +377,16 @@ def cut(self, from_=0, to=None): ) return self.with_text(substring) - def eq(self, other): + def eq(self, other: Node) -> bool: return self.same_markup(other) and self.text == getattr(other, "text", None) - def to_json(self): - base = super().to_json() - base["text"] = self.text - return base + def to_json( + self, + ) -> JSONDict: + return {**super().to_json(), "text": self.text} -def wrap_marks(marks, str): +def wrap_marks(marks: List[Mark], str: str) -> str: i = len(marks) - 1 while i >= 0: str = marks[i].type.name + "(" + str + ")" diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index 136b292..3896711 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -346,6 +346,8 @@ class NodeSpec(TypedDict, total=False): isolating: bool 
toDOM: Callable[[Node], Any] # FIXME: add types parseDOM: List[Dict[str, Any]] # FIXME: add types + toDebugString: Callable[[Node], str] + leafText: Callable[[Node], str] AttributeSpecs: TypeAlias = Dict[str, "AttributeSpec"] From 690f7d5fb0d702bc82dca2744999e973a2b9bdca Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Thu, 15 Jun 2023 21:03:54 -0400 Subject: [PATCH 08/40] Finish typing for node --- prosemirror/model/fragment.py | 4 +- prosemirror/model/from_dom.py | 6 ++- prosemirror/model/node.py | 79 ++++++++++++++++++++------------ prosemirror/model/resolvedpos.py | 7 ++- prosemirror/model/schema.py | 6 ++- 5 files changed, 63 insertions(+), 39 deletions(-) diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index e332ede..91b54f4 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -197,7 +197,7 @@ def maybe_child(self, index: int) -> Optional["Node"]: except IndexError: return None - def for_each(self, f: Callable) -> None: + def for_each(self, f: Callable[["Node", int, int], Any]) -> None: i = 0 p = 0 while i < len(self.content): @@ -303,4 +303,4 @@ def __repr__(self) -> str: Fragment.empty = Fragment([], 0) -from . import node as pm_node +from . import node as pm_node # noqa: E402 diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index 969ac7b..9d8bc22 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -7,6 +7,8 @@ from lxml.cssselect import CSSSelector from lxml.html import HtmlElement as DOMNode +from prosemirror.utils import JSONDict + from .content import ContentMatch from .fragment import Fragment from .mark import Mark @@ -707,7 +709,7 @@ def add_element_by_rule( elif rule.get_content is not None: self.find_inside(dom_) rule.get_content(dom_, self.parser.schema).for_each( - lambda node: self.insert_node(node) + lambda node, offset, index: self.insert_node(node) ) else: content_dom = dom_ @@ -1169,7 +1171,7 @@ def get_node_type(element: DOMNode) -> int: return 8 -def from_html(schema: Schema, html: str) -> Dict[str, Any]: +def from_html(schema: Schema, html: str) -> JSONDict: fragment = lxml.html.fragment_fromstring(html, create_parent="document-fragment") prose_doc = DOMParser.from_schema(schema).parse(fragment) diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index f82bb69..9929530 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,13 +1,7 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypedDict, Union from typing_extensions import TypeGuard -from prosemirror.model.content import ContentMatch -from prosemirror.model.fragment import Fragment -from prosemirror.model.mark import Mark -from prosemirror.model.replace import Slice -from prosemirror.model.resolvedpos import ResolvedPos -from prosemirror.model.schema import NodeType, Schema from prosemirror.utils import JSON, JSONDict, text_length from .comparedeep import compare_deep @@ -16,13 +10,24 @@ from .replace import Slice, replace from .resolvedpos import ResolvedPos +if TYPE_CHECKING: + from .content import ContentMatch + from .schema import MarkType, NodeType, Schema + + empty_attrs: dict = {} +class ChildInfo(TypedDict): + node: Optional["Node"] + index: int + offset: int + + class Node: def __init__( self, - type: NodeType, + type: "NodeType", attrs: JSONDict, content: Optional[Fragment], marks: List[Mark], @@ -46,15 +51,21 @@ def child(self, index: int) -> "Node": def 
maybe_child(self, index: int) -> Optional["Node"]: return self.content.maybe_child(index) - def for_each(self, f): + def for_each(self, f: Callable[["Node", int, int], None]) -> None: self.content.for_each(f) def nodes_between( - self, from_: int, to: int, f: Callable, start_pos: int = 0 + self, + from_: int, + to: int, + f: Callable[["Node", int, Optional["Node"], int], Optional[bool]], + start_pos: int = 0, ) -> None: self.content.nodes_between(from_, to, f, start_pos, self) - def descendants(self, f): + def descendants( + self, f: Callable[["Node", int, Optional["Node"], int], Optional[bool]] + ) -> None: self.nodes_between(0, self.content.size, f) @property @@ -90,7 +101,7 @@ def same_markup(self, other: "Node") -> bool: def has_markup( self, - type: NodeType, + type: "NodeType", attrs: Optional[JSONDict] = None, marks: Optional[List[Mark]] = None, ) -> bool: @@ -110,7 +121,7 @@ def mark(self, marks: List[Mark]) -> "Node": return self return self.__class__(self.type, self.attrs, self.content, marks) - def cut(self, from_: int, to: Optional[int]) -> "Node": + def cut(self, from_: int, to: Optional[int] = None) -> "Node": if from_ == 0 and to == self.content.size: return self return self.copy(self.content.cut(from_, to)) @@ -146,7 +157,7 @@ def node_at(self, pos: int) -> Optional["Node"]: return node pos -= offset + 1 - def child_after(self, pos): + def child_after(self, pos: int) -> ChildInfo: index_info = self.content.find_index(pos) index, offset = index_info["index"], index_info["offset"] return { @@ -155,7 +166,7 @@ def child_after(self, pos): "offset": offset, } - def child_before(self, pos): + def child_before(self, pos: int) -> ChildInfo: if pos == 0: return {"node": None, "index": 0, "offset": 0} index_info = self.content.find_index(pos) @@ -171,11 +182,15 @@ def resolve(self, pos: int) -> ResolvedPos: def resolve_no_cache(self, pos: int) -> ResolvedPos: return ResolvedPos.resolve(self, pos) - def range_has_mark(self, from_, to, type): + def range_has_mark( + self, from_: int, to: int, type: Union[Mark, "MarkType"] + ) -> bool: found = False if to > from_: - def iteratee(node): + def iteratee( + node: "Node", pos: int, parent: Optional["Node"], index: int + ) -> bool: nonlocal found if type.is_in_set(node.marks): found = True @@ -185,7 +200,7 @@ def iteratee(node): return found @property - def is_block(self): + def is_block(self) -> bool: return self.type.is_block @property @@ -225,10 +240,10 @@ def __str__(self) -> str: name += f"({self.content.to_string_inner()})" return wrap_marks(self.marks, name) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.__str__()}>" - def content_match_at(self, index: int) -> ContentMatch: + def content_match_at(self, index: int) -> "ContentMatch": match = self.type.content_match.match_fragment(self.content, 0, index) if not match: raise ValueError("Called contentMatchAt on a node with invalid content") @@ -245,7 +260,7 @@ def can_replace( if end is None: end = replacement.child_count one = self.content_match_at(from_).match_fragment(replacement, start, end) - two: Optional[ContentMatch] = None + two: Optional["ContentMatch"] = None if one: two = one.match_fragment(self.content, to) if not two or not two.valid_end: @@ -256,23 +271,23 @@ def can_replace( return True def can_replace_with( - self, from_: int, to: int, type: NodeType, marks: None = None + self, from_: int, to: int, type: "NodeType", marks: None = None ) -> bool: if marks and not self.type.allows_marks(marks): return False start = 
self.content_match_at(from_).match_type(type) - end: Optional[ContentMatch] = None + end: Optional["ContentMatch"] = None if start: end = start.match_fragment(self.content, to) return end.valid_end if end else False - def can_append(self, other): + def can_append(self, other: "Node") -> bool: if other.content.size: - return self.can_replace(self.child_count, self.child_before, other.content) + return self.can_replace(self.child_count, self.child_count, other.content) else: return self.type.compatible_content(other.type) - def check(self): + def check(self) -> None: if not self.type.valid_content(self.content): raise ValueError( f"Invalid content for node {self.type.name}: {str(self.content)[:50]}" @@ -299,7 +314,7 @@ def to_json(self) -> JSONDict: return obj @classmethod - def from_json(cls, schema: Schema, json_data: Any) -> "Node": + def from_json(cls, schema: "Schema", json_data: Any) -> "Node": if isinstance(json_data, str): import json @@ -322,7 +337,7 @@ def from_json(cls, schema: Schema, json_data: Any) -> "Node": class TextNode(Node): def __init__( self, - type: NodeType, + type: "NodeType", attrs: JSONDict, content: str, marks: List[Mark], @@ -348,7 +363,13 @@ def __str__(self) -> str: def text_content(self) -> str: return self.text - def text_between(self, from_, to): + def text_between( + self, + from_: int, + to: int, + block_separator: str = "", + leaf_text: Union[Callable, str] = "", + ) -> str: return self.text[from_:to] @property diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index 53c53da..11bddb7 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,10 +1,9 @@ -from typing import TYPE_CHECKING, Any, List, Optional, Union, cast +from typing import TYPE_CHECKING, List, Optional, Union, cast from .mark import Mark if TYPE_CHECKING: - from prosemirror.model.mark import Mark - from prosemirror.model.node import Node, TextNode + from .node import Node class ResolvedPos: @@ -109,7 +108,7 @@ def marks(self) -> List["Mark"]: other = parent.maybe_child(index) if not main: main, other = other, main - marks = main.marks + marks = cast("Node", main).marks i = 0 while i < len(marks): if marks[i].type.spec.get("inclusive") is False and ( diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index 3896711..1f96fb6 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -31,7 +31,7 @@ def default_attrs(attrs: "Attributes") -> Optional[Attrs]: defaults = {} for attr_name, attr in attrs.items(): - if attr.has_default: + if not attr.has_default: return None defaults[attr_name] = attr.default return defaults @@ -429,7 +429,9 @@ def node( def text(self, text: str, marks: Optional[List[Mark]] = None) -> TextNode: type = self.nodes[cast(Nodes, "text")] - return TextNode(type, type.default_attrs, text, Mark.set_from(marks)) + return TextNode( + type, cast(Attrs, type.default_attrs), text, Mark.set_from(marks) + ) def mark( self, From 662170e159081a1cd1bf1b063cecb3a61267992a Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 10:58:54 -0400 Subject: [PATCH 09/40] Typing for prosemirror.model.replace --- prosemirror/model/replace.py | 100 +++++++++++++++++++------------ prosemirror/model/resolvedpos.py | 8 +-- prosemirror/transform/replace.py | 2 +- 3 files changed, 68 insertions(+), 42 deletions(-) diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index 3112e40..0ae9fed 100644 --- a/prosemirror/model/replace.py +++ 
b/prosemirror/model/replace.py @@ -1,22 +1,30 @@ -from typing import ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union, cast +from prosemirror.utils import JSON, JSONDict from .fragment import Fragment +if TYPE_CHECKING: + from .node import Node, TextNode + from .resolvedpos import ResolvedPos + from .schema import Schema + + class ReplaceError(ValueError): pass -def remove_range(content, from_, to): +def remove_range(content: Fragment, from_: int, to: int) -> Fragment: from_index_info = content.find_index(from_) index, offset = from_index_info["index"], from_index_info["offset"] child = content.maybe_child(index) to_index_info = content.find_index(to) index_to, offset_to = to_index_info["index"], to_index_info["offset"] - if offset == from_ or child.is_text: + if offset == from_ or cast("Node", child).is_text: if offset_to != to and not content.child(index_to).is_text: raise ValueError("removing non-flat range") return content.cut(0, from_).append(content.cut(to)) + assert child if index != index_to: raise ValueError("removing non-flat range") return content.replace_child( @@ -25,14 +33,17 @@ def remove_range(content, from_, to): ) -def insert_into(content, dist, insert, parent): +def insert_into( + content: Fragment, dist: int, insert: Fragment, parent: None +) -> Optional[Fragment]: a = content.find_index(dist) index, offset = a["index"], a["offset"] child = content.maybe_child(index) - if offset == dist or child.is_text: + if offset == dist or cast("Node", child).is_text: if parent and not parent.can_replace(index, index, insert): return None return content.cut(0, dist).append(insert).append(content.cut(dist)) + assert child inner = insert_into(child.content, dist - offset - 1, insert, None) if inner: return content.replace_child(index, child.copy(inner)) @@ -42,38 +53,39 @@ def insert_into(content, dist, insert, parent): class Slice: empty: ClassVar["Slice"] - def __init__(self, content, open_start, open_end): + def __init__(self, content: Fragment, open_start: int, open_end: int) -> None: self.content = content self.open_start = open_start self.open_end = open_end @property - def size(self): + def size(self) -> int: return self.content.size - self.open_start - self.open_end - def insert_at(self, pos, fragment): + def insert_at(self, pos: int, fragment: Fragment) -> Optional["Slice"]: content = insert_into(self.content, pos + self.open_start, fragment, None) if content: return Slice(content, self.open_start, self.open_end) + return None - def remove_between(self, from_, to): + def remove_between(self, from_: int, to: int) -> "Slice": return Slice( remove_range(self.content, from_ + self.open_start, to + self.open_start), self.open_start, self.open_end, ) - def eq(self, other): + def eq(self, other: "Slice") -> bool: return ( self.content.eq(other.content) and self.open_start == other.open_start and self.open_end == other.open_end ) - def __str__(self): + def __str__(self) -> str: return f"{self.content}({self.open_start},{self.open_end})" - def to_json(self): + def to_json(self) -> JSON: if not self.content.size: return None json = {"content": self.content.to_json()} @@ -84,11 +96,7 @@ def to_json(self): return json @classmethod - def from_json(cls, schema, json_data): - if isinstance(json_data, str): - import json - - json_data = json.loads(json) + def from_json(cls, schema: "Schema", json_data: JSONDict) -> "Slice": if not json_data: return cls.empty open_start = json_data.get("openStart", 0) or 0 @@ -102,7 +110,7 @@ def from_json(cls, schema, 
json_data): ) @classmethod - def max_open(cls, fragment: Fragment, open_isolating=True): + def max_open(cls, fragment: Fragment, open_isolating: bool = True) -> "Slice": open_start = 0 open_end = 0 n = fragment.first_child @@ -119,7 +127,7 @@ def max_open(cls, fragment: Fragment, open_isolating=True): Slice.empty = Slice(Fragment.empty, 0, 0) -def replace(from_, to, slice): +def replace(from_: "ResolvedPos", to: "ResolvedPos", slice: Slice) -> "Node": if slice.open_start > from_.depth: raise ReplaceError("Inserted content deeper than insertion position") if from_.depth - slice.open_start != to.depth - slice.open_end: @@ -127,7 +135,9 @@ def replace(from_, to, slice): return replace_outer(from_, to, slice, 0) -def replace_outer(from_, to, slice: Slice, depth): +def replace_outer( + from_: "ResolvedPos", to: "ResolvedPos", slice: Slice, depth: int +) -> "Node": index = from_.index(depth) node = from_.node(depth) if index == to.index(depth) and depth < from_.depth - slice.open_start: @@ -155,27 +165,32 @@ def replace_outer(from_, to, slice: Slice, depth): return close(node, replace_three_way(from_, start, end, to, depth)) -def check_join(main, sub): +def check_join(main: "Node", sub: "Node") -> None: if not sub.type.compatible_content(main.type): raise ReplaceError(f"Cannot join {sub.type.name} onto {main.type.name}") -def joinable(before, after, depth): +def joinable(before: "ResolvedPos", after: "ResolvedPos", depth: int) -> "Node": node = before.node(depth) check_join(node, after.node(depth)) return node -def add_node(child, target): +def add_node(child: "Node", target: List["Node"]) -> None: last = len(target) - 1 - if last >= 0 and child.is_text and child.same_markup(target[last]): - target[last] = child.with_text(target[last].text + child.text) + if last >= 0 and pm_node.is_text(child) and child.same_markup(target[last]): + target[last] = child.with_text(cast("TextNode", target[last]).text + child.text) else: target.append(child) -def add_range(start, end, depth, target): - node = (end or start).node(depth) +def add_range( + start: Optional["ResolvedPos"], + end: Optional["ResolvedPos"], + depth: int, + target: List["Node"], +) -> None: + node = cast("ResolvedPos", end or start).node(depth) start_index = 0 end_index = end.index(depth) if end else node.child_count if start: @@ -183,26 +198,32 @@ def add_range(start, end, depth, target): if start.depth > depth: start_index += 1 elif start.text_offset: - add_node(start.node_after, target) + add_node(cast("Node", start.node_after), target) start_index += 1 i = start_index while i < end_index: add_node(node.child(i), target) i += 1 if end and end.depth == depth and end.text_offset: - add_node(end.node_before, target) + add_node(cast("Node", end.node_before), target) -def close(node, content): +def close(node: "Node", content: Fragment) -> "Node": if not node.type.valid_content(content): raise ReplaceError(f"Invalid content for node {node.type.name}") return node.copy(content) -def replace_three_way(from_, start, end, to, depth): - open_start = joinable(from_, start, depth + 1) if from_.depth > depth else False - open_end = joinable(end, to, depth + 1) if to.depth > depth else False - content = [] +def replace_three_way( + from_: "ResolvedPos", + start: "ResolvedPos", + end: "ResolvedPos", + to: "ResolvedPos", + depth: int, +) -> Fragment: + open_start = joinable(from_, start, depth + 1) if from_.depth > depth else None + open_end = joinable(end, to, depth + 1) if to.depth > depth else None + content: List["Node"] = [] add_range(None, from_, 
depth, content) if open_start and open_end and start.index(depth) == end.index(depth): check_join(open_start, open_end) @@ -222,8 +243,8 @@ def replace_three_way(from_, start, end, to, depth): return Fragment(content) -def replace_two_way(from_, to, depth): - content = [] +def replace_two_way(from_: "ResolvedPos", to: "ResolvedPos", depth: int) -> Fragment: + content: List["Node"] = [] add_range(None, from_, depth, content) if from_.depth > depth: type = joinable(from_, to, depth + 1) @@ -232,7 +253,9 @@ def replace_two_way(from_, to, depth): return Fragment(content) -def prepare_slice_for_replace(slice: Slice, along): +def prepare_slice_for_replace( + slice: Slice, along: "ResolvedPos" +) -> Dict[str, "ResolvedPos"]: extra = along.depth - slice.open_start parent = along.node(extra) node = parent.copy(slice.content) @@ -242,3 +265,6 @@ def prepare_slice_for_replace(slice: Slice, along): "start": node.resolve_no_cache(slice.open_start + extra), "end": node.resolve_no_cache(node.content.size - slice.open_end - extra), } + + +from . import node as pm_node # noqa: E402 diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index 11bddb7..a26e4a2 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -119,7 +119,7 @@ def marks(self) -> List["Mark"]: i += 1 return marks - def marks_across(self, end): + def marks_across(self, end: "ResolvedPos") -> Optional[List["Mark"]]: after = self.parent.maybe_child(self.index()) if not after or not after.is_inline: return None @@ -159,13 +159,13 @@ def block_range( d -= 1 return None - def same_parent(self, other): + def same_parent(self, other: "ResolvedPos") -> bool: return self.pos - self.parent_offset == other.pos - other.parent_offset - def max(self, other): + def max(self, other: "ResolvedPos") -> "ResolvedPos": return other if other.pos > self.pos else self - def min(self, other): + def min(self, other: "ResolvedPos") -> "ResolvedPos": return other if other.pos < self.pos else self def __str__(self) -> str: diff --git a/prosemirror/transform/replace.py b/prosemirror/transform/replace.py index c4fa503..437958f 100644 --- a/prosemirror/transform/replace.py +++ b/prosemirror/transform/replace.py @@ -123,7 +123,7 @@ def find_fittable(self) -> Optional[_Fittable]: cur = self.unplaced.content open_end = self.unplaced.open_end for d in range(start_depth): - node = cur.first_child + node = cast("Node", cur.first_child) if cur.child_count > 1: open_end = 0 if node.type.spec.get("isolating") and open_end <= d: From 54614bc9fdb227fcc0708da8cb0dab10f37aa6dc Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 11:08:42 -0400 Subject: [PATCH 10/40] Typing for prosemirror.model.mark --- prosemirror/model/mark.py | 35 ++++++++++++++++++++--------------- prosemirror/model/replace.py | 4 ++-- prosemirror/model/to_dom.py | 12 +++++------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index 956c55d..c3fc007 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -1,14 +1,19 @@ -from typing import Final, List +from typing import TYPE_CHECKING, Any, Final, List, Optional, cast + +from prosemirror.utils import JSONDict + +if TYPE_CHECKING: + from .schema import MarkType, Schema class Mark: none: Final[List["Mark"]] = [] - def __init__(self, type, attrs): + def __init__(self, type: "MarkType", attrs: JSONDict) -> None: self.type = type self.attrs = attrs - def add_to_set(self, set): + def add_to_set(self, 
set: List["Mark"]) -> List["Mark"]: copy = None placed = False for i in range(len(set)): @@ -34,36 +39,36 @@ def add_to_set(self, set): copy.append(self) return copy - def remove_from_set(self, set): + def remove_from_set(self, set: List["Mark"]) -> List["Mark"]: return [item for item in set if not item.eq(self)] - def is_in_set(self, set): + def is_in_set(self, set: List["Mark"]) -> bool: return any(item.eq(self) for item in set) - def eq(self, other): + def eq(self, other: "Mark") -> bool: if self == other: return True return self.type.name == other.type.name and self.attrs == other.attrs - def to_json(self): + def to_json(self) -> JSONDict: return {"type": self.type.name, "attrs": self.attrs} @classmethod - def from_json(cls, schema, json_data): - if isinstance(json_data, str): - import json - - json_data = json.loads(json_data) + def from_json( + cls, + schema: "Schema[Any, Any]", + json_data: JSONDict, + ) -> "Mark": if not json_data: raise ValueError("Invalid input for Mark.fromJSON") name = json_data["type"] type = schema.marks.get(name) if not type: raise ValueError(f"There is no mark type {name} in this schema") - return type.create(json_data.get("attrs")) + return type.create(cast(Optional[JSONDict], json_data.get("attrs"))) @classmethod - def same_set(cls, a, b): + def same_set(cls, a: List["Mark"], b: List["Mark"]) -> bool: if a == b: return True if len(a) != len(b): @@ -71,7 +76,7 @@ def same_set(cls, a, b): return all(item_a.eq(item_b) for (item_a, item_b) in zip(a, b)) @classmethod - def set_from(cls, marks): + def set_from(cls, marks: Optional[List["Mark"]]) -> List["Mark"]: if not marks: return cls.none if isinstance(marks, cls): diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index 0ae9fed..56f0cf8 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -1,8 +1,8 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, cast from prosemirror.utils import JSON, JSONDict -from .fragment import Fragment +from .fragment import Fragment if TYPE_CHECKING: from .node import Node, TextNode diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index a9627d1..d2fecf4 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -68,7 +68,7 @@ def serialize_fragment( tgt: DocumentFragment = target or DocumentFragment(children=[]) top = tgt - active: Optional[List[DocumentFragment]] = None + active: Optional[List[Tuple[Mark, DocumentFragment]]] = None def each(node: Node, *_): nonlocal top, active @@ -84,22 +84,20 @@ def each(node: Node, *_): rendered += 1 continue if ( - not next.eq(active[keep]) + not next.eq(active[keep][0]) or next.type.spec.get("spanning") is False ): break - keep += 2 + keep += 1 rendered += 1 while keep < len(active): - top = active.pop() - active.pop() + top = active.pop()[1] while rendered < len(node.marks): add = node.marks[rendered] rendered += 1 mark_dom = self.serialize_mark(add, node.is_inline) if mark_dom: - active.append(add) # type: ignore - active.append(top) + active.append((add, top)) top.children.append(mark_dom[0]) top = cast(DocumentFragment, mark_dom[1] or mark_dom[0]) top.children.append(self.serialize_node_inner(node)) From 63caf8645932a6dd5e3a2fa479a9fad2e855d2e1 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 11:20:15 -0400 Subject: [PATCH 11/40] typing for to_dom --- prosemirror/model/to_dom.py | 24 +++++++++++++++--------- 1 file 
changed, 15 insertions(+), 9 deletions(-) diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index d2fecf4..cc59e67 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -2,15 +2,19 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast from . import Fragment, Mark, Node, Schema +from prosemirror.model.fragment import Fragment +from prosemirror.model.mark import Mark +from prosemirror.model.node import Node +from prosemirror.model.schema import Schema HTMLNode = Union["Element", str] class DocumentFragment: - def __init__(self, children: List[HTMLNode]): + def __init__(self, children: List[HTMLNode]) -> None: self.children = children - def __str__(self): + def __str__(self) -> str: return "".join([str(c) for c in self.children]) @@ -35,12 +39,14 @@ class Element(DocumentFragment): ] ) - def __init__(self, name: str, attrs: Dict[str, str], children: List[HTMLNode]): + def __init__( + self, name: str, attrs: Dict[str, str], children: List[HTMLNode] + ) -> None: self.name = name self.attrs = attrs super().__init__(children) - def __str__(self): + def __str__(self) -> str: attrs_str = " ".join([f'{k}="{html.escape(v)}"' for k, v in self.attrs.items()]) open_tag_str = " ".join([s for s in [self.name, attrs_str] if s]) if self.name in self.self_closing_elements: @@ -58,7 +64,7 @@ def __init__( self, nodes: Dict[str, Callable[[Node], HTMLOutputSpec]], marks: Dict[str, Callable[[Mark, bool], HTMLOutputSpec]], - ): + ) -> None: self.nodes = nodes self.marks = marks @@ -70,7 +76,7 @@ def serialize_fragment( top = tgt active: Optional[List[Tuple[Mark, DocumentFragment]]] = None - def each(node: Node, *_): + def each(node: Node, offset: int, index: int) -> None: nonlocal top, active if active or node.marks: @@ -175,18 +181,18 @@ def from_schema(cls, schema: Schema) -> "DOMSerializer": return cls(cls.nodes_from_schema(schema), cls.marks_from_schema(schema)) @classmethod - def nodes_from_schema(cls, schema: Schema): + def nodes_from_schema(cls, schema: Schema) -> Dict[str, Callable]: result = gather_to_dom(schema.nodes) if "text" not in result: result["text"] = lambda node: node.text return result @classmethod - def marks_from_schema(cls, schema: Schema): + def marks_from_schema(cls, schema: Schema) -> Dict[str, Callable]: return gather_to_dom(schema.marks) -def gather_to_dom(obj: Dict[str, Any]): +def gather_to_dom(obj: Dict[str, Any]) -> Dict[str, Callable]: result = {} for name in obj: to_dom = obj[name].spec.get("toDOM") From aa202943d5bacd2c06f37cb912267211da932a7a Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 11:44:03 -0400 Subject: [PATCH 12/40] Fix strict issues --- prosemirror/model/content.py | 1 + prosemirror/model/fragment.py | 4 ++-- prosemirror/model/from_dom.py | 25 ++++++++++++------------- prosemirror/model/node.py | 8 ++++---- prosemirror/model/replace.py | 2 +- prosemirror/model/schema.py | 12 ++++++------ prosemirror/model/to_dom.py | 35 ++++++++++++++++++++++++----------- 7 files changed, 50 insertions(+), 37 deletions(-) diff --git a/prosemirror/model/content.py b/prosemirror/model/content.py index 01792d1..80aa837 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -209,6 +209,7 @@ def iteratee(m: "ContentMatch", i: int) -> str: class TokenStream: inline: Optional[bool] + tokens: List[str] def __init__(self, string: str, node_types: Dict[str, "NodeType"]) -> None: self.string = string diff --git a/prosemirror/model/fragment.py 
b/prosemirror/model/fragment.py index 91b54f4..bc830d8 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -71,7 +71,7 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable, str] = "", + leaf_text: Union[Callable[["Node"], str], str] = "", ) -> str: text = [] separated = True @@ -250,7 +250,7 @@ def to_json(self) -> JSON: return None @classmethod - def from_json(cls, schema: "Schema", value: Any) -> "Fragment": + def from_json(cls, schema: "Schema[str, str]", value: Any) -> "Fragment": if not value: return cls.empty if isinstance(value, str): diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index 9d8bc22..f297f9b 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -57,7 +57,7 @@ class ParseRule: attrs: Optional[Attrs] get_attrs: Optional[Callable[[DOMNode], Union[None, Attrs, Literal[False]]]] content_element: Union[str, DOMNode, Callable[[DOMNode], DOMNode], None] - get_content: Optional[Callable[[DOMNode, Schema], Fragment]] + get_content: Optional[Callable[[DOMNode, Schema[str, str]], Fragment]] preserve_whitespace: WSType @classmethod @@ -88,10 +88,10 @@ class DOMParser: _styles: List[ParseRule] _normalize_lists: bool - schema: Schema + schema: Schema[str, str] rules: List[ParseRule] - def __init__(self, schema: Schema, rules: List[ParseRule]) -> None: + def __init__(self, schema: Schema[str, str], rules: List[ParseRule]) -> None: self.schema = schema self.rules = rules self._tags = [rule for rule in rules if rule.tag is not None] @@ -209,7 +209,7 @@ def match_style( return None @classmethod - def schema_rules(cls, schema: Schema) -> List[ParseRule]: + def schema_rules(cls, schema: Schema[str, str]) -> List[ParseRule]: result: List[ParseRule] = [] def insert(rule: ParseRule) -> None: @@ -253,13 +253,13 @@ def insert(rule: ParseRule) -> None: return result @classmethod - def from_schema(cls, schema: Schema) -> "DOMParser": + def from_schema(cls, schema: Schema[str, str]) -> "DOMParser": if "dom_parser" not in schema.cached: schema.cached["dom_parser"] = DOMParser( schema, DOMParser.schema_rules(schema) ) - return schema.cached["dom_parser"] + return cast("DOMParser", schema.cached["dom_parser"]) BLOCK_TAGS: Dict[str, bool] = { @@ -412,11 +412,10 @@ def finish(self, open_end: bool) -> Union[Node, Fragment]: m = re.findall(r"[ \t\r\n\u000c]+$", last.text) if m: - text = cast(TextNode, last) if len(last.text) == len(m[0]): self.content.pop() else: - self.content[-1] = text.with_text(text.text[0 : -len(m[0])]) + self.content[-1] = last.with_text(last.text[0 : -len(m[0])]) content = Fragment.from_(self.content) if not open_end and self.match is not None: @@ -547,7 +546,7 @@ def add_text_node(self, dom_: DOMNode) -> None: if ( top.options & OPT_PRESERVE_WS_FULL - or top.inline_context(dom_) # type: ignore + or top.inline_context(dom_) or re.search(r"[^ \t\r\n\u000c]", value) is not None ): if not (top.options & OPT_PRESERVE_WS): @@ -585,7 +584,7 @@ def add_text_node(self, dom_: DOMNode) -> None: self.find_in_text(dom_) else: - self.find_inside(dom_) # type: ignore + self.find_inside(dom_) def add_element( self, dom_: DOMNode, match_after: Optional[ParseRule] = None @@ -670,7 +669,7 @@ def read_styles(self, styles: List[str]) -> Optional[Tuple[List[Mark], List[Mark remove = m.add_to_set(remove) else: add = ( - self.parser.schema.marks[rule.mark] + self.parser.schema.marks[cast(str, rule.mark)] .create(rule.attrs) .add_to_set(add) ) @@ -1041,7 +1040,7 @@ def 
normalize_list(dom_: DOMNode) -> None: prev_item = None while child is not None: - name = child.tag.lower() if get_node_type(child) == 1 else None # type: ignore + name = child.tag.lower() if get_node_type(child) == 1 else None if name and name in LIST_TAGS and prev_item: prev_item.append(child) @@ -1171,7 +1170,7 @@ def get_node_type(element: DOMNode) -> int: return 8 -def from_html(schema: Schema, html: str) -> JSONDict: +def from_html(schema: Schema[str, str], html: str) -> JSONDict: fragment = lxml.html.fragment_fromstring(html, create_parent="document-fragment") prose_doc = DOMParser.from_schema(schema).parse(fragment) diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index 9929530..65fe886 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -15,7 +15,7 @@ from .schema import MarkType, NodeType, Schema -empty_attrs: dict = {} +empty_attrs: JSONDict = {} class ChildInfo(TypedDict): @@ -79,7 +79,7 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable, str] = "", + leaf_text: Union[Callable[["Node"], str], str] = "", ) -> str: return self.content.text_between(from_, to, block_separator, leaf_text) @@ -314,7 +314,7 @@ def to_json(self) -> JSONDict: return obj @classmethod - def from_json(cls, schema: "Schema", json_data: Any) -> "Node": + def from_json(cls, schema: "Schema[str, str]", json_data: Any) -> "Node": if isinstance(json_data, str): import json @@ -368,7 +368,7 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable, str] = "", + leaf_text: Union[Callable[["Node"], str], str] = "", ) -> str: return self.text[from_:to] diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index 56f0cf8..ef451a3 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -96,7 +96,7 @@ def to_json(self) -> JSON: return json @classmethod - def from_json(cls, schema: "Schema", json_data: JSONDict) -> "Slice": + def from_json(cls, schema: "Schema[str, str]", json_data: JSONDict) -> "Slice": if not json_data: return cls.empty open_start = json_data.get("openStart", 0) or 0 diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index 1f96fb6..2d37e82 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -254,7 +254,7 @@ class MarkType: instance: Optional[Mark] def __init__( - self, name: str, rank: int, schema: "Schema", spec: "MarkSpec" + self, name: str, rank: int, schema: "Schema[str, str]", spec: "MarkSpec" ) -> None: self.name = name self.schema = schema @@ -296,8 +296,8 @@ def excludes(self, other: "MarkType") -> bool: return any(other.name == e.name for e in self.excluded) -Nodes = TypeVar("Nodes", bound=str) -Marks = TypeVar("Marks", bound=str) +Nodes = TypeVar("Nodes", bound=str, covariant=True) +Marks = TypeVar("Marks", bound=str, covariant=True) class SchemaSpec(TypedDict, Generic[Nodes, Marks]): @@ -368,13 +368,13 @@ class AttributeSpec(TypedDict, total=False): class Schema(Generic[Nodes, Marks]): - spec: SchemaSpec + spec: SchemaSpec[Nodes, Marks] nodes: Dict[Nodes, "NodeType"] marks: Dict[Marks, "MarkType"] - def __init__(self, spec: SchemaSpec) -> None: + def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: self.spec = spec self.nodes = NodeType.compile(self.spec["nodes"], self) self.marks = MarkType.compile(self.spec.get("marks", {}), self) @@ -459,7 +459,7 @@ def node_type(self, name: str) -> NodeType: return found -def gather_marks(schema: Schema, marks: List[str]) -> 
List[MarkType]: +def gather_marks(schema: Schema[str, str], marks: List[str]) -> List[MarkType]: found = [] for name in marks: mark = schema.marks.get(name) diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index cc59e67..8f73889 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -1,11 +1,20 @@ import html -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast - -from . import Fragment, Mark, Node, Schema -from prosemirror.model.fragment import Fragment -from prosemirror.model.mark import Mark -from prosemirror.model.node import Node -from prosemirror.model.schema import Schema +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from .fragment import Fragment +from .mark import Mark +from .node import Node +from .schema import Schema HTMLNode = Union["Element", str] @@ -177,22 +186,26 @@ def render_spec( return dom, content_dom @classmethod - def from_schema(cls, schema: Schema) -> "DOMSerializer": + def from_schema(cls, schema: Schema[str, str]) -> "DOMSerializer": return cls(cls.nodes_from_schema(schema), cls.marks_from_schema(schema)) @classmethod - def nodes_from_schema(cls, schema: Schema) -> Dict[str, Callable]: + def nodes_from_schema( + cls, schema: Schema[str, str] + ) -> Dict[str, Callable[["Node"], HTMLOutputSpec]]: result = gather_to_dom(schema.nodes) if "text" not in result: result["text"] = lambda node: node.text return result @classmethod - def marks_from_schema(cls, schema: Schema) -> Dict[str, Callable]: + def marks_from_schema( + cls, schema: Schema[str, str] + ) -> Dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: return gather_to_dom(schema.marks) -def gather_to_dom(obj: Dict[str, Any]) -> Dict[str, Callable]: +def gather_to_dom(obj: Dict[str, Any]) -> Dict[str, Callable]: # type: ignore result = {} for name in obj: to_dom = obj[name].spec.get("toDOM") From 4fbf21332c9dc2e49a771cae3686847e4b2e7e94 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 11:55:25 -0400 Subject: [PATCH 13/40] Fix typing in tests --- tests/prosemirror_model/tests/test_node.py | 5 ++++- tests/prosemirror_transform/tests/test_trans.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/prosemirror_model/tests/test_node.py b/tests/prosemirror_model/tests/test_node.py index d2d3f91..a17d194 100644 --- a/tests/prosemirror_model/tests/test_node.py +++ b/tests/prosemirror_model/tests/test_node.py @@ -1,3 +1,4 @@ +from typing import Literal from prosemirror.model import Fragment, Schema from prosemirror.test_builder import eq, out from prosemirror.test_builder import test_schema as schema @@ -15,7 +16,9 @@ hr = out["hr"] img = out["img"] -custom_schema = Schema( +custom_schema: Schema[ + Literal["doc", "paragraph", "text", "contact", "hard_break"], str +] = Schema( { "nodes": { "doc": {"content": "paragraph+"}, diff --git a/tests/prosemirror_transform/tests/test_trans.py b/tests/prosemirror_transform/tests/test_trans.py index 573a3d8..c82f625 100644 --- a/tests/prosemirror_transform/tests/test_trans.py +++ b/tests/prosemirror_transform/tests/test_trans.py @@ -846,7 +846,7 @@ class TestTopLevelMarkReplace: { "nodes": { **schema.spec["nodes"], - **({"doc": {**schema.spec["nodes"].get("doc"), "marks": "_"}}), + "doc": {**schema.spec["nodes"]["doc"], "marks": "_"}, # type: ignore }, "marks": schema.spec["marks"], } @@ -889,7 +889,7 @@ class TestEnforcingHeadingAndBody: nodes_sepc = schema.spec["nodes"].copy() 
nodes_sepc.update( { - "doc": {**nodes_sepc["doc"], "content": "heading body"}, + "doc": {**nodes_sepc["doc"], "content": "heading body"}, # type: ignore "body": {"content": "block+"}, } ) From 455624aa14099d4f2207088e1c1488f00c992c53 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 12:30:25 -0400 Subject: [PATCH 14/40] Additional typing fixes --- prosemirror/model/fragment.py | 3 ++- prosemirror/model/from_dom.py | 2 +- prosemirror/model/mark.py | 2 +- prosemirror/model/node.py | 2 +- prosemirror/model/replace.py | 2 +- prosemirror/model/resolvedpos.py | 6 +++-- prosemirror/model/schema.py | 5 +--- prosemirror/model/to_dom.py | 46 +++++++++++++++++--------------- pyproject.toml | 10 +++++++ 9 files changed, 45 insertions(+), 33 deletions(-) diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index bc830d8..a0ac503 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -265,7 +265,8 @@ def from_json(cls, schema: "Schema[str, str]", value: Any) -> "Fragment": def from_array(cls, array: List["Node"]) -> "Fragment": if not array: return cls.empty - joined, size = None, 0 + joined: Optional[List["Node"]] = None + size = 0 for i in range(len(array)): node = array[i] size += node.node_size diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index f297f9b..9336d4e 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -1037,7 +1037,7 @@ def remove_pending_mark(self, mark: Mark, upto: NodeContext) -> None: def normalize_list(dom_: DOMNode) -> None: child = next(iter(dom_)) - prev_item = None + prev_item: Optional[DOMNode] = None while child is not None: name = child.tag.lower() if get_node_type(child) == 1 else None diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index c3fc007..842f316 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -14,7 +14,7 @@ def __init__(self, type: "MarkType", attrs: JSONDict) -> None: self.attrs = attrs def add_to_set(self, set: List["Mark"]) -> List["Mark"]: - copy = None + copy: Optional[List["Mark"]] = None placed = False for i in range(len(set)): other = set[i] diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index 65fe886..f6e62d8 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -271,7 +271,7 @@ def can_replace( return True def can_replace_with( - self, from_: int, to: int, type: "NodeType", marks: None = None + self, from_: int, to: int, type: "NodeType", marks: Optional[List[Mark]] = None ) -> bool: if marks and not self.type.allows_marks(marks): return False diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index ef451a3..f08dcda 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -34,7 +34,7 @@ def remove_range(content: Fragment, from_: int, to: int) -> Fragment: def insert_into( - content: Fragment, dist: int, insert: Fragment, parent: None + content: Fragment, dist: int, insert: Fragment, parent: Optional["Node"] ) -> Optional[Fragment]: a = content.find_index(dist) index, offset = a["index"], a["offset"] diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index a26e4a2..6077286 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, Union, cast +from typing import TYPE_CHECKING, List, Optional, Union, cast, Callable from .mark import Mark @@ -144,7 +144,9 @@ def 
shared_depth(self, pos: int) -> int: return 0 def block_range( - self, other: Optional["ResolvedPos"] = None, pred: None = None + self, + other: Optional["ResolvedPos"] = None, + pred: Optional[Callable[["Node"], bool]] = None, ) -> Optional["NodeRange"]: if other is None: other = self diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index 2d37e82..bb3a8d6 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -22,9 +22,6 @@ from .content import ContentMatch -if TYPE_CHECKING: - pass - Attrs: TypeAlias = JSONDict @@ -197,7 +194,7 @@ def allows_marks(self, marks: List[Mark]) -> bool: def allowed_marks(self, marks: List[Mark]) -> List[Mark]: if self.mark_set is None: return marks - copy = None + copy: Optional[List[Mark]] = None for i, mark in enumerate(marks): if not self.allows_mark_type(mark.type): if not copy: diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index 8f73889..46cb898 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -8,6 +8,7 @@ Sequence, Tuple, Union, + Set, cast, ) @@ -27,27 +28,28 @@ def __str__(self) -> str: return "".join([str(c) for c in self.children]) -class Element(DocumentFragment): - self_closing_elements = frozenset( - [ - "area", - "base", - "br", - "col", - "embed", - "hr", - "img", - "input", - "keygen", - "link", - "meta", - "param", - "source", - "track", - "wbr", - ] - ) +SELF_CLOSING_ELEMENTS = frozenset( + { + "area", + "base", + "br", + "col", + "embed", + "hr", + "img", + "input", + "keygen", + "link", + "meta", + "param", + "source", + "track", + "wbr", + } +) + +class Element(DocumentFragment): def __init__( self, name: str, attrs: Dict[str, str], children: List[HTMLNode] ) -> None: @@ -58,7 +60,7 @@ def __init__( def __str__(self) -> str: attrs_str = " ".join([f'{k}="{html.escape(v)}"' for k, v in self.attrs.items()]) open_tag_str = " ".join([s for s in [self.name, attrs_str] if s]) - if self.name in self.self_closing_elements: + if self.name in SELF_CLOSING_ELEMENTS: assert not self.children, "self-closing elements should not have children" return f"<{open_tag_str}>" children_str = "".join([str(c) for c in self.children]) @@ -157,7 +159,7 @@ def render_spec( tag_name = structure[0] if " " in tag_name[1:]: raise NotImplementedError("XML namespaces are not supported") - content_dom = None + content_dom: Optional[Element] = None dom = Element(name=tag_name, attrs={}, children=[]) attrs = structure[1] if len(structure) > 1 else None start = 1 diff --git a/pyproject.toml b/pyproject.toml index 6bea8ce..58c6b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,16 @@ select = [ "RUF", ] +[tool.mypy] +warn_return_any = true +warn_unused_configs = true + +[[tool.mypy.overrides]] +module = "prosemirror.model.*" +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_incomplete_defs = true + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" From fcdb113b81a2b17a2c90504edcfc71c92cef62be Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 15:03:01 -0400 Subject: [PATCH 15/40] A few typing and mypyc fixes, remove dataclass --- prosemirror/model/content.py | 43 ++++++++++++------- prosemirror/model/node.py | 6 ++- prosemirror/model/resolvedpos.py | 2 +- prosemirror/model/schema.py | 5 +-- prosemirror/model/to_dom.py | 1 - tests/prosemirror_model/tests/test_node.py | 1 + .../tests/test_structure.py | 2 +- 7 files changed, 37 insertions(+), 23 deletions(-) diff --git 
a/prosemirror/model/content.py b/prosemirror/model/content.py index 80aa837..d8c110d 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -1,5 +1,4 @@ import re -from dataclasses import dataclass, field from functools import cmp_to_key, reduce from typing import ( TYPE_CHECKING, @@ -22,17 +21,25 @@ from .schema import NodeType -@dataclass class MatchEdge: type: "NodeType" next: "ContentMatch" + def __init__(self, type: "NodeType", next: "ContentMatch") -> None: + self.type = type + self.next = next + -@dataclass class WrapCacheEntry: target: "NodeType" computed: Optional[List["NodeType"]] + def __init__( + self, target: "NodeType", computed: Optional[List["NodeType"]] + ) -> None: + self.target = target + self.computed = computed + class Active(TypedDict): match: "ContentMatch" @@ -40,7 +47,6 @@ class Active(TypedDict): via: Optional["Active"] -@dataclass(eq=False) class ContentMatch: """ Instances of this class represent a match state of a node type's @@ -51,16 +57,21 @@ class ContentMatch: empty: ClassVar["ContentMatch"] valid_end: bool - next: List[MatchEdge] = field(default_factory=list, init=False) - wrap_cache: List[WrapCacheEntry] = field(default_factory=list, init=False) + next: List[MatchEdge] + wrap_cache: List[WrapCacheEntry] + + def __init__(self, valid_end: bool) -> None: + self.valid_end = valid_end + self.next = [] + self.wrap_cache = [] @classmethod def parse(cls, string: str, node_types: Dict[str, "NodeType"]) -> "ContentMatch": stream = TokenStream(string, node_types) - if stream.next is None: + if stream.next() is None: return ContentMatch.empty expr = parse_expr(stream) - if stream.next: + if stream.next() is not None: stream.err("Unexpected trailing text") match = dfa(nfa(expr)) check_for_dead_ends(match, stream) @@ -218,7 +229,6 @@ def __init__(self, string: str, node_types: Dict[str, "NodeType"]) -> None: self.pos = 0 self.tokens = [i for i in TOKEN_REGEX.findall(string) if i.strip()] - @property def next(self) -> Optional[str]: try: return self.tokens[self.pos] @@ -226,7 +236,7 @@ def next(self) -> Optional[str]: return None def eat(self, tok: str) -> Union[int, bool]: - if self.next == tok: + if self.next() == tok: pos = self.pos self.pos += 1 return pos or True @@ -292,7 +302,8 @@ def parse_expr_seq(stream: TokenStream) -> Expr: exprs = [] while True: exprs.append(parse_expr_subscript(stream)) - if not (stream.next and stream.next != ")" and stream.next != "|"): + next_ = stream.next() + if not (next_ and next_ != ")" and next_ != "|"): break if len(exprs) == 1: return exprs[0] @@ -319,7 +330,7 @@ def parse_expr_subscript(stream: TokenStream) -> Expr: def parse_num(stream: TokenStream) -> int: - next = stream.next + next = stream.next() assert next is not None if NUMBER_REGEX.match(next): stream.err(f'Expected number, got "{next}"') @@ -332,7 +343,7 @@ def parse_expr_range(stream: TokenStream, expr: Expr) -> Expr: min_ = parse_num(stream) max_ = min_ if stream.eat(","): - if stream.next != "}": + if stream.next() != "}": max_ = parse_num(stream) else: max_ = -1 @@ -363,7 +374,7 @@ def parse_expr_atom( if not stream.eat(")"): stream.err("missing closing patren") return expr - elif not re.match(r"\W", cast(str, stream.next)): + elif not re.match(r"\W", cast(str, stream.next())): def iteratee(type: "NodeType") -> Expr: nonlocal stream @@ -374,14 +385,14 @@ def iteratee(type: "NodeType") -> Expr: return {"type": "name", "value": type} exprs = [ - iteratee(type) for type in resolve_name(stream, cast(str, stream.next)) + iteratee(type) for 
type in resolve_name(stream, cast(str, stream.next())) ] stream.pos += 1 if len(exprs) == 1: return exprs[0] return {"type": "choice", "exprs": exprs} else: - stream.err(f'Unexpected token "{stream.next}"') + stream.err(f'Unexpected token "{stream.next()}"') class Edge(TypedDict): diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index f6e62d8..99ab749 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -300,7 +300,11 @@ def check(self) -> None: f"Invalid collection of marks for node {self.type.name}:" f" {[m.type.name for m in self.marks]!r}" ) - return self.content.for_each(lambda node, *args: node.check()) + + def iteratee(node: "Node", offset: int, index: int) -> None: + node.check() + + return self.content.for_each(iteratee) def to_json(self) -> JSONDict: obj: Dict[str, JSON] = {"type": self.type.name} diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index 6077286..2d95f10 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, Union, cast, Callable +from typing import TYPE_CHECKING, Callable, List, Optional, Union, cast from .mark import Mark diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index bb3a8d6..fc40cc1 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -1,5 +1,4 @@ from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -87,7 +86,7 @@ def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> N self.default_attrs = default_attrs(self.attrs) self.content_match = None # type: ignore[assignment] self.mark_set = None - self.inline_content = None # type: ignore[assignment] + self.inline_content = False self.is_block = not (spec.get("inline") or name == "text") self.is_text = name == "text" @@ -413,7 +412,7 @@ def node( self, type: Union[str, NodeType], attrs: Optional[Attrs] = None, - content: Optional[Union[Fragment, Node]] = None, + content: Optional[Union[Fragment, Node, List[Node]]] = None, marks: Optional[List[Mark]] = None, ) -> Node: if isinstance(type, str): diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index 46cb898..88f241a 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -8,7 +8,6 @@ Sequence, Tuple, Union, - Set, cast, ) diff --git a/tests/prosemirror_model/tests/test_node.py b/tests/prosemirror_model/tests/test_node.py index a17d194..d613d32 100644 --- a/tests/prosemirror_model/tests/test_node.py +++ b/tests/prosemirror_model/tests/test_node.py @@ -1,4 +1,5 @@ from typing import Literal + from prosemirror.model import Fragment, Schema from prosemirror.test_builder import eq, out from prosemirror.test_builder import test_schema as schema diff --git a/tests/prosemirror_transform/tests/test_structure.py b/tests/prosemirror_transform/tests/test_structure.py index 0211f4a..b02e0bf 100644 --- a/tests/prosemirror_transform/tests/test_structure.py +++ b/tests/prosemirror_transform/tests/test_structure.py @@ -24,7 +24,7 @@ def n(name, *content): - return schema.nodes[name].create(None, content) + return schema.nodes[name].create(None, list(content)) def t(str, em=None): From 5101e22e3d73f799f53e62f1cd28f2ea6d5e6c34 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 15:19:30 -0400 Subject: [PATCH 16/40] Address comments --- prosemirror/model/content.py | 8 ++++---- prosemirror/model/to_dom.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git 
a/prosemirror/model/content.py b/prosemirror/model/content.py index d8c110d..95442d4 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -434,11 +434,11 @@ def compile(expr: Expr, from_: int) -> List[Edge]: elif expr["type"] == "seq": i = 0 while True: - nxt = compile(expr["exprs"][i], from_) + next_ = compile(expr["exprs"][i], from_) if i == len(expr["exprs"]) - 1: - return nxt + return next_ from_ = node() - connect(nxt, from_) + connect(next_, from_) i += 1 elif expr["type"] == "star": loop = node() @@ -488,7 +488,7 @@ def scan(n: int) -> None: nonlocal result edges = nfa[n] if len(edges) == 1 and not edges[0].get("term"): - return scan(cast(int, edges[0].get("to"))) + return scan(cast(int, edges[0]["to"])) result.append(n) for edge in edges: term, to = edge.get("term"), edge.get("to") diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index 88f241a..94d0dd3 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -206,7 +206,7 @@ def marks_from_schema( return gather_to_dom(schema.marks) -def gather_to_dom(obj: Dict[str, Any]) -> Dict[str, Callable]: # type: ignore +def gather_to_dom(obj: Dict[str, Any]) -> Dict[str, Callable[..., Any]]: result = {} for name in obj: to_dom = obj[name].spec.get("toDOM") From e4d3a9d0c50025b02b7db2c644f97abcaa701078 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 16:03:33 -0400 Subject: [PATCH 17/40] Fix on python 3.8 --- prosemirror/model/schema.py | 3 +-- prosemirror/utils.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index fc40cc1..3b5283c 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -6,13 +6,12 @@ List, Literal, Optional, - TypeAlias, TypeVar, Union, cast, ) -from typing_extensions import NotRequired, TypedDict +from typing_extensions import NotRequired, TypeAlias, TypedDict from prosemirror.model.fragment import Fragment from prosemirror.model.mark import Mark diff --git a/prosemirror/utils.py b/prosemirror/utils.py index 110bde1..06a7c0c 100644 --- a/prosemirror/utils.py +++ b/prosemirror/utils.py @@ -1,5 +1,4 @@ -from collections.abc import Mapping, Sequence -from typing import Union +from typing import Mapping, Sequence, Union from typing_extensions import TypeAlias From 5f5ce01d69da39d62944f38fcb8f85f1c1e92963 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Fri, 16 Jun 2023 20:38:25 -0400 Subject: [PATCH 18/40] Fix attribute computation --- prosemirror/model/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index 3b5283c..af6922d 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -121,7 +121,7 @@ def compatible_content(self, other: "NodeType") -> bool: return self == other or (self.content_match.compatible(other.content_match)) def compute_attrs(self, attrs: Optional[Attrs]) -> Attrs: - if not attrs and self.default_attrs: + if attrs is None and self.default_attrs is not None: return self.default_attrs return compute_attrs(self.attrs, attrs) From c61cf39eb0341b951d9efbceda8daf41854ba481 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Wed, 21 Jun 2023 10:32:52 -0400 Subject: [PATCH 19/40] Mutable JSON --- prosemirror/model/fragment.py | 4 ++-- prosemirror/model/mark.py | 10 +++++++--- prosemirror/model/node.py | 16 ++++++++-------- prosemirror/model/replace.py | 6 +++--- prosemirror/utils.py | 9 ++++++++- 
5 files changed, 28 insertions(+), 17 deletions(-) diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index a0ac503..4af9b3a 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -11,7 +11,7 @@ cast, ) -from prosemirror.utils import JSON, text_length +from prosemirror.utils import MutableJSONList, text_length if TYPE_CHECKING: from prosemirror.model.schema import Schema @@ -244,7 +244,7 @@ def find_index(self, pos: int, round: int = -1) -> Dict[str, int]: i += 1 cur_pos = end - def to_json(self) -> JSON: + def to_json(self) -> Optional[MutableJSONList]: if self.content: return [item.to_json() for item in self.content] return None diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index 842f316..5dcf59e 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -1,6 +1,7 @@ +import copy from typing import TYPE_CHECKING, Any, Final, List, Optional, cast -from prosemirror.utils import JSONDict +from prosemirror.utils import JSONDict, MutableJSONDict if TYPE_CHECKING: from .schema import MarkType, Schema @@ -50,8 +51,11 @@ def eq(self, other: "Mark") -> bool: return True return self.type.name == other.type.name and self.attrs == other.attrs - def to_json(self) -> JSONDict: - return {"type": self.type.name, "attrs": self.attrs} + def to_json(self) -> MutableJSONDict: + result: MutableJSONDict = {"type": self.type.name} + if self.attrs: + result["attrs"] = cast(MutableJSONDict, copy.deepcopy(self.attrs)) + return result @classmethod def from_json( diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index 99ab749..39b1c0a 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypedDict, Union +import copy +from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypedDict, Union, cast from typing_extensions import TypeGuard -from prosemirror.utils import JSON, JSONDict, text_length +from prosemirror.utils import JSONDict, MutableJSONDict, text_length from .comparedeep import compare_deep from .fragment import Fragment @@ -306,11 +307,10 @@ def iteratee(node: "Node", offset: int, index: int) -> None: return self.content.for_each(iteratee) - def to_json(self) -> JSONDict: - obj: Dict[str, JSON] = {"type": self.type.name} - for _ in self.attrs: - obj["attrs"] = self.attrs - break + def to_json(self) -> MutableJSONDict: + obj: MutableJSONDict = {"type": self.type.name} + if self.attrs: + obj["attrs"] = cast(MutableJSONDict, copy.deepcopy(self.attrs)) if getattr(self.content, "size", None): obj["content"] = self.content.to_json() if len(self.marks): @@ -407,7 +407,7 @@ def eq(self, other: Node) -> bool: def to_json( self, - ) -> JSONDict: + ) -> MutableJSONDict: return {**super().to_json(), "text": self.text} diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index f08dcda..502806c 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, cast -from prosemirror.utils import JSON, JSONDict +from prosemirror.utils import JSONDict, MutableJSONDict from .fragment import Fragment @@ -85,10 +85,10 @@ def eq(self, other: "Slice") -> bool: def __str__(self) -> str: return f"{self.content}({self.open_start},{self.open_end})" - def to_json(self) -> JSON: + def to_json(self) -> Optional[MutableJSONDict]: if not self.content.size: return None - json = {"content": self.content.to_json()} 
+ json: MutableJSONDict = {"content": self.content.to_json()} if self.open_start > 0: json["openStart"] = self.open_start if self.open_end > 0: diff --git a/prosemirror/utils.py b/prosemirror/utils.py index 06a7c0c..8af0b86 100644 --- a/prosemirror/utils.py +++ b/prosemirror/utils.py @@ -1,4 +1,4 @@ -from typing import Mapping, Sequence, Union +from typing import Mapping, MutableMapping, MutableSequence, Sequence, Union from typing_extensions import TypeAlias @@ -7,6 +7,13 @@ JSON: TypeAlias = Union[JSONDict, JSONList, str, int, float, bool, None] +MutableJSONDict: TypeAlias = MutableMapping[str, "MutableJSON"] +MutableJSONList: TypeAlias = MutableSequence["MutableJSON"] + +MutableJSON: TypeAlias = Union[ + MutableJSONDict, MutableJSONList, str, int, float, bool, None +] + def text_length(text: str) -> int: return len(text.encode("utf-16-le")) // 2 From 0f3b18e60d9cff2a27a4058398de1a0b4f3db3a7 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Mon, 13 Nov 2023 14:32:50 -0500 Subject: [PATCH 20/40] Upgrading to mypy 1.7.0. Making mypy strict for all modules. --- poetry.lock | 59 +++++++++++++++++++++++++------------------------- pyproject.toml | 11 ++-------- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0f3284e..3538de6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -462,38 +462,38 @@ files = [ [[package]] name = "mypy" -version = "1.6.0" +version = "1.7.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:091f53ff88cb093dcc33c29eee522c087a438df65eb92acd371161c1f4380ff0"}, - {file = "mypy-1.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb7ff4007865833c470a601498ba30462b7374342580e2346bf7884557e40531"}, - {file = "mypy-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49499cf1e464f533fc45be54d20a6351a312f96ae7892d8e9f1708140e27ce41"}, - {file = "mypy-1.6.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4c192445899c69f07874dabda7e931b0cc811ea055bf82c1ababf358b9b2a72c"}, - {file = "mypy-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:3df87094028e52766b0a59a3e46481bb98b27986ed6ded6a6cc35ecc75bb9182"}, - {file = "mypy-1.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c8835a07b8442da900db47ccfda76c92c69c3a575872a5b764332c4bacb5a0a"}, - {file = "mypy-1.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:24f3de8b9e7021cd794ad9dfbf2e9fe3f069ff5e28cb57af6f873ffec1cb0425"}, - {file = "mypy-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:856bad61ebc7d21dbc019b719e98303dc6256cec6dcc9ebb0b214b81d6901bd8"}, - {file = "mypy-1.6.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:89513ddfda06b5c8ebd64f026d20a61ef264e89125dc82633f3c34eeb50e7d60"}, - {file = "mypy-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:9f8464ed410ada641c29f5de3e6716cbdd4f460b31cf755b2af52f2d5ea79ead"}, - {file = "mypy-1.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:971104bcb180e4fed0d7bd85504c9036346ab44b7416c75dd93b5c8c6bb7e28f"}, - {file = "mypy-1.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ab98b8f6fdf669711f3abe83a745f67f50e3cbaea3998b90e8608d2b459fd566"}, - {file = "mypy-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a69db3018b87b3e6e9dd28970f983ea6c933800c9edf8c503c3135b3274d5ad"}, - {file = "mypy-1.6.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:dccd850a2e3863891871c9e16c54c742dba5470f5120ffed8152956e9e0a5e13"}, - {file = "mypy-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:f8598307150b5722854f035d2e70a1ad9cc3c72d392c34fffd8c66d888c90f17"}, - {file = "mypy-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fea451a3125bf0bfe716e5d7ad4b92033c471e4b5b3e154c67525539d14dc15a"}, - {file = "mypy-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e28d7b221898c401494f3b77db3bac78a03ad0a0fff29a950317d87885c655d2"}, - {file = "mypy-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4b7a99275a61aa22256bab5839c35fe8a6887781862471df82afb4b445daae6"}, - {file = "mypy-1.6.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7469545380dddce5719e3656b80bdfbb217cfe8dbb1438532d6abc754b828fed"}, - {file = "mypy-1.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:7807a2a61e636af9ca247ba8494031fb060a0a744b9fee7de3a54bed8a753323"}, - {file = "mypy-1.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d2dad072e01764823d4b2f06bc7365bb1d4b6c2f38c4d42fade3c8d45b0b4b67"}, - {file = "mypy-1.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b19006055dde8a5425baa5f3b57a19fa79df621606540493e5e893500148c72f"}, - {file = "mypy-1.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31eba8a7a71f0071f55227a8057468b8d2eb5bf578c8502c7f01abaec8141b2f"}, - {file = "mypy-1.6.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8e0db37ac4ebb2fee7702767dfc1b773c7365731c22787cb99f507285014fcaf"}, - {file = "mypy-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:c69051274762cccd13498b568ed2430f8d22baa4b179911ad0c1577d336ed849"}, - {file = "mypy-1.6.0-py3-none-any.whl", hash = "sha256:9e1589ca150a51d9d00bb839bfeca2f7a04f32cd62fad87a847bc0818e15d7dc"}, - {file = "mypy-1.6.0.tar.gz", hash = "sha256:4f3d27537abde1be6d5f2c96c29a454da333a2a271ae7d5bc7110e6d4b7beb3f"}, + {file = "mypy-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5da84d7bf257fd8f66b4f759a904fd2c5a765f70d8b52dde62b521972a0a2357"}, + {file = "mypy-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3637c03f4025f6405737570d6cbfa4f1400eb3c649317634d273687a09ffc2f"}, + {file = "mypy-1.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b633f188fc5ae1b6edca39dae566974d7ef4e9aaaae00bc36efe1f855e5173ac"}, + {file = "mypy-1.7.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d6ed9a3997b90c6f891138e3f83fb8f475c74db4ccaa942a1c7bf99e83a989a1"}, + {file = "mypy-1.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:1fe46e96ae319df21359c8db77e1aecac8e5949da4773c0274c0ef3d8d1268a9"}, + {file = "mypy-1.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:df67fbeb666ee8828f675fee724cc2cbd2e4828cc3df56703e02fe6a421b7401"}, + {file = "mypy-1.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a79cdc12a02eb526d808a32a934c6fe6df07b05f3573d210e41808020aed8b5d"}, + {file = "mypy-1.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f65f385a6f43211effe8c682e8ec3f55d79391f70a201575def73d08db68ead1"}, + {file = "mypy-1.7.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e81ffd120ee24959b449b647c4b2fbfcf8acf3465e082b8d58fd6c4c2b27e46"}, + {file = "mypy-1.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:f29386804c3577c83d76520abf18cfcd7d68264c7e431c5907d250ab502658ee"}, + {file = "mypy-1.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:87c076c174e2c7ef8ab416c4e252d94c08cd4980a10967754f91571070bf5fbe"}, + {file = "mypy-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:6cb8d5f6d0fcd9e708bb190b224089e45902cacef6f6915481806b0c77f7786d"}, + {file = "mypy-1.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93e76c2256aa50d9c82a88e2f569232e9862c9982095f6d54e13509f01222fc"}, + {file = "mypy-1.7.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cddee95dea7990e2215576fae95f6b78a8c12f4c089d7e4367564704e99118d3"}, + {file = "mypy-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:d01921dbd691c4061a3e2ecdbfbfad029410c5c2b1ee88946bf45c62c6c91210"}, + {file = "mypy-1.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:185cff9b9a7fec1f9f7d8352dff8a4c713b2e3eea9c6c4b5ff7f0edf46b91e41"}, + {file = "mypy-1.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7b1e399c47b18feb6f8ad4a3eef3813e28c1e871ea7d4ea5d444b2ac03c418"}, + {file = "mypy-1.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc9fe455ad58a20ec68599139ed1113b21f977b536a91b42bef3ffed5cce7391"}, + {file = "mypy-1.7.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d0fa29919d2e720c8dbaf07d5578f93d7b313c3e9954c8ec05b6d83da592e5d9"}, + {file = "mypy-1.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b53655a295c1ed1af9e96b462a736bf083adba7b314ae775563e3fb4e6795f5"}, + {file = "mypy-1.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c1b06b4b109e342f7dccc9efda965fc3970a604db70f8560ddfdee7ef19afb05"}, + {file = "mypy-1.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bf7a2f0a6907f231d5e41adba1a82d7d88cf1f61a70335889412dec99feeb0f8"}, + {file = "mypy-1.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551d4a0cdcbd1d2cccdcc7cb516bb4ae888794929f5b040bb51aae1846062901"}, + {file = "mypy-1.7.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:55d28d7963bef00c330cb6461db80b0b72afe2f3c4e2963c99517cf06454e665"}, + {file = "mypy-1.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:870bd1ffc8a5862e593185a4c169804f2744112b4a7c55b93eb50f48e7a77010"}, + {file = "mypy-1.7.0-py3-none-any.whl", hash = "sha256:96650d9a4c651bc2a4991cf46f100973f656d69edc7faf91844e87fe627f7e96"}, + {file = "mypy-1.7.0.tar.gz", hash = "sha256:1e280b5697202efa698372d2f39e9a6713a0395a756b1c6bd48995f8d72690dc"}, ] [package.dependencies] @@ -504,6 +504,7 @@ typing-extensions = ">=4.1.0" [package.extras] dmypy = ["psutil (>=4.0)"] install-types = ["pip"] +mypyc = ["setuptools (>=50)"] reports = ["lxml"] [[package]] @@ -804,4 +805,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4" -content-hash = "2ae3228009bd37b3f0129532d27511794ccb98f9e4cdd7b8e08c546fca82b45e" +content-hash = "9c9ced24f9cbf5a963444de64b43c3e48e95a2c17608c93557fe971d252f08d7" diff --git a/pyproject.toml b/pyproject.toml index 58c6b6b..f212025 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ coverage = "^7.0.5" flake8 = "^6.0.0" isort = "^5.11.4" lxml-stubs = "^0.4.0" -mypy = "^1.3.0" +mypy = "1.7.0" pandoc = "^2.3" pydash = "^7.0.3" pytest = "^7.2.1" @@ -43,14 +43,7 @@ select = [ ] [tool.mypy] -warn_return_any = true -warn_unused_configs = true - -[[tool.mypy.overrides]] -module = "prosemirror.model.*" -disallow_untyped_defs = true -disallow_untyped_calls = true -disallow_incomplete_defs = true +strict = true [build-system] requires = ["poetry-core>=1.0.0"] From f1bda0000d6fb2ef8b947ceba435ff31b096dd6e Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Mon, 13 Nov 2023 14:33:43 -0500 Subject: [PATCH 21/40] Removing mutable json types. Creating new immutable instances instead of mutating existing ones. 
---
 prosemirror/utils.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/prosemirror/utils.py b/prosemirror/utils.py
index 8af0b86..270b036 100644
--- a/prosemirror/utils.py
+++ b/prosemirror/utils.py
@@ -1,18 +1,11 @@
-from typing import Mapping, MutableMapping, MutableSequence, Sequence, Union
+from typing import Mapping, Sequence
 
 from typing_extensions import TypeAlias
 
 JSONDict: TypeAlias = Mapping[str, "JSON"]
 JSONList: TypeAlias = Sequence["JSON"]
 
-JSON: TypeAlias = Union[JSONDict, JSONList, str, int, float, bool, None]
-
-MutableJSONDict: TypeAlias = MutableMapping[str, "MutableJSON"]
-MutableJSONList: TypeAlias = MutableSequence["MutableJSON"]
-
-MutableJSON: TypeAlias = Union[
-    MutableJSONDict, MutableJSONList, str, int, float, bool, None
-]
+JSON: TypeAlias = JSONDict | JSONList | str | int | float | bool | None
 
 
 def text_length(text: str) -> int:
     return len(text.encode("utf-16-le")) // 2

From 47ab7535757aa96fa5c84bd2dab5176d9567c0ea Mon Sep 17 00:00:00 2001
From: Ernesto Ferro
Date: Mon, 13 Nov 2023 14:34:18 -0500
Subject: [PATCH 22/40] Using the built-in types for type annotations instead of the ones from the typing module if possible.

---
 prosemirror/model/__init__.py    |   3 +-
 prosemirror/model/content.py     |  90 +++++++---
 prosemirror/model/diff.py        |   8 +-
 prosemirror/model/fragment.py    |  57 +++++-----
 prosemirror/model/from_dom.py    | 178 +++++++++++++++----------------
 prosemirror/model/mark.py        |  35 +++---
 prosemirror/model/node.py        |  98 +++++++++--------
 prosemirror/model/replace.py     |  44 +++++---
 prosemirror/model/resolvedpos.py |  36 +++----
 prosemirror/model/schema.py      | 122 ++++++++++-----------
 prosemirror/model/to_dom.py      |  36 +++----
 11 files changed, 354 insertions(+), 353 deletions(-)

diff --git a/prosemirror/model/__init__.py b/prosemirror/model/__init__.py
index 4c119e4..f255024 100644
--- a/prosemirror/model/__init__.py
+++ b/prosemirror/model/__init__.py
@@ -5,7 +5,7 @@
 from .node import Node
 from .replace import ReplaceError, Slice
 from .resolvedpos import NodeRange, ResolvedPos
-from .schema import MarkType, NodeType, Schema
+from .schema import Attrs, MarkType, NodeType, Schema
 from .to_dom import DOMSerializer
 
 __all__ = [
@@ -16,6 +16,7 @@
     "Slice",
     "ReplaceError",
     "Mark",
+    "Attrs",
     "Schema",
     "NodeType",
     "MarkType",
diff --git a/prosemirror/model/content.py b/prosemirror/model/content.py
index 95442d4..93dc67b 100644
--- a/prosemirror/model/content.py
+++ b/prosemirror/model/content.py
@@ -3,14 +3,10 @@
 from typing import (
     TYPE_CHECKING,
     ClassVar,
-    Dict,
-    List,
     Literal,
     NamedTuple,
     NoReturn,
-    Optional,
     TypedDict,
-    Union,
     cast,
 )
 
@@ -32,19 +28,17 @@ def __init__(self, type: "NodeType", next: "ContentMatch") -> None:
 
 class WrapCacheEntry:
     target: "NodeType"
-    computed: Optional[List["NodeType"]]
+    computed: list["NodeType"] | None
 
-    def __init__(
-        self, target: "NodeType", computed: Optional[List["NodeType"]]
-    ) -> None:
+    def __init__(self, target: "NodeType", computed: list["NodeType"] | None) -> None:
         self.target = target
         self.computed = computed
 
 
 class Active(TypedDict):
     match: "ContentMatch"
-    type: Optional["NodeType"]
-    via: Optional["Active"]
+    type: "NodeType | None"
+    via: "Active | None"
 
 
 class ContentMatch:
@@ -57,8 +51,8 @@ class ContentMatch:
 
     empty: ClassVar["ContentMatch"]
     valid_end: bool
-    next: List[MatchEdge]
-    wrap_cache: List[WrapCacheEntry]
+    next: list[MatchEdge]
+    wrap_cache: list[WrapCacheEntry]
 
     def __init__(self, valid_end: bool) -> None:
         self.valid_end = valid_end
         self.next = []
         self.wrap_cache = []
@classmethod - def parse(cls, string: str, node_types: Dict[str, "NodeType"]) -> "ContentMatch": + def parse(cls, string: str, node_types: dict[str, "NodeType"]) -> "ContentMatch": stream = TokenStream(string, node_types) if stream.next() is None: return ContentMatch.empty @@ -77,18 +71,18 @@ def parse(cls, string: str, node_types: Dict[str, "NodeType"]) -> "ContentMatch" check_for_dead_ends(match, stream) return match - def match_type(self, type: "NodeType") -> Optional["ContentMatch"]: + def match_type(self, type: "NodeType") -> "ContentMatch | None": for next in self.next: if next.type.name == type.name: return next.next return None def match_fragment( - self, frag: Fragment, start: int = 0, end: Optional[int] = None - ) -> Optional["ContentMatch"]: + self, frag: Fragment, start: int = 0, end: int | None = None + ) -> "ContentMatch | None": if end is None: end = frag.child_count - cur: Optional["ContentMatch"] = self + cur: "ContentMatch | None" = self i = start while cur and i < end: cur = cur.match_type(frag.child(i).type) @@ -100,7 +94,7 @@ def inline_content(self) -> bool: return bool(self.next) and self.next[0].type.is_inline @property - def default_type(self) -> Optional["NodeType"]: + def default_type(self) -> "NodeType | None": for next in self.next: type = next.type if not (type.is_text or type.has_required_attrs()): @@ -116,10 +110,10 @@ def compatible(self, other: "ContentMatch") -> bool: def fill_before( self, after: Fragment, to_end: bool = False, start_index: int = 0 - ) -> Optional[Fragment]: + ) -> Fragment | None: seen = [self] - def search(match: ContentMatch, types: List["NodeType"]) -> Optional[Fragment]: + def search(match: ContentMatch, types: list["NodeType"]) -> Fragment | None: nonlocal seen finished = match.match_fragment(after, start_index) if finished and (not to_end or finished.valid_end): @@ -138,7 +132,7 @@ def search(match: ContentMatch, types: List["NodeType"]) -> Optional[Fragment]: return search(self, []) - def find_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: + def find_wrapping(self, target: "NodeType") -> list["NodeType"] | None: for entry in self.wrap_cache: if entry.target.name == target.name: return entry.computed @@ -146,9 +140,9 @@ def find_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: self.wrap_cache.append(WrapCacheEntry(target, computed)) return computed - def compute_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: + def compute_wrapping(self, target: "NodeType") -> list["NodeType"] | None: seen = {} - active: List[Active] = [{"match": self, "type": None, "via": None}] + active: list[Active] = [{"match": self, "type": None, "via": None}] while len(active): current = active.pop(0) match = current["match"] @@ -219,23 +213,23 @@ def iteratee(m: "ContentMatch", i: int) -> str: class TokenStream: - inline: Optional[bool] - tokens: List[str] + inline: bool | None + tokens: list[str] - def __init__(self, string: str, node_types: Dict[str, "NodeType"]) -> None: + def __init__(self, string: str, node_types: dict[str, "NodeType"]) -> None: self.string = string self.node_types = node_types self.inline = None self.pos = 0 self.tokens = [i for i in TOKEN_REGEX.findall(string) if i.strip()] - def next(self) -> Optional[str]: + def next(self) -> str | None: try: return self.tokens[self.pos] except IndexError: return None - def eat(self, tok: str) -> Union[int, bool]: + def eat(self, tok: str) -> int | bool: if self.next() == tok: pos = self.pos self.pos += 1 @@ -249,12 +243,12 @@ def err(self, str: 
str) -> NoReturn: class ChoiceExpr(TypedDict): type: Literal["choice"] - exprs: List["Expr"] + exprs: list["Expr"] class SeqExpr(TypedDict): type: Literal["seq"] - exprs: List["Expr"] + exprs: list["Expr"] class PlusExpr(TypedDict): @@ -284,7 +278,7 @@ class NameExpr(TypedDict): value: "NodeType" -Expr = Union[ChoiceExpr, SeqExpr, PlusExpr, StarExpr, OptExpr, RangeExpr, NameExpr] +Expr = ChoiceExpr | SeqExpr | PlusExpr | StarExpr | OptExpr | RangeExpr | NameExpr def parse_expr(stream: TokenStream) -> Expr: @@ -352,7 +346,7 @@ def parse_expr_range(stream: TokenStream, expr: Expr) -> Expr: return {"type": "range", "min": min_, "max": max_, "expr": expr} -def resolve_name(stream: TokenStream, name: str) -> List["NodeType"]: +def resolve_name(stream: TokenStream, name: str) -> list["NodeType"]: types = stream.node_types type = types.get(name) if type: @@ -396,39 +390,37 @@ def iteratee(type: "NodeType") -> Expr: class Edge(TypedDict): - term: Optional["NodeType"] - to: Optional[int] + term: "NodeType | None" + to: int | None def nfa( expr: Expr, -) -> List[List[Edge]]: - nfa_: List[List[Edge]] = [[]] +) -> list[list[Edge]]: + nfa_: list[list[Edge]] = [[]] def node() -> int: nonlocal nfa_ nfa_.append([]) return len(nfa_) - 1 - def edge( - from_: int, to: Optional[int] = None, term: Optional["NodeType"] = None - ) -> Edge: + def edge(from_: int, to: int | None = None, term: "NodeType | None" = None) -> Edge: nonlocal nfa_ edge: Edge = {"term": term, "to": to} nfa_[from_].append(edge) return edge - def connect(edges: List[Edge], to: int) -> None: + def connect(edges: list[Edge], to: int) -> None: for edge in edges: edge["to"] = to - def compile(expr: Expr, from_: int) -> List[Edge]: + def compile(expr: Expr, from_: int) -> list[Edge]: if expr["type"] == "choice": return list( reduce( lambda out, expr: [*out, *compile(expr, from_)], expr["exprs"], - cast(List[Edge], []), + cast(list[Edge], []), ) ) elif expr["type"] == "seq": @@ -479,9 +471,9 @@ def cmp(a: int, b: int) -> int: def null_from( - nfa: List[List[Edge]], + nfa: list[list[Edge]], node: int, -) -> List[int]: +) -> list[int]: result = [] def scan(n: int) -> None: @@ -501,21 +493,21 @@ def scan(n: int) -> None: class DFAState(NamedTuple): state: "NodeType" - next: List[int] + next: list[int] -def dfa(nfa: List[List[Edge]]) -> ContentMatch: +def dfa(nfa: list[list[Edge]]) -> ContentMatch: labeled = {} - def explore(states: List[int]) -> ContentMatch: + def explore(states: list[int]) -> ContentMatch: nonlocal labeled - out: List[DFAState] = [] + out: list[DFAState] = [] for node in states: for item in nfa[node]: term, to = item.get("term"), item.get("to") if not term: continue - set: Optional[List[int]] = None + set: list[int] | None = None for t in out: if t[0] == term: set = t[1] diff --git a/prosemirror/model/diff.py b/prosemirror/model/diff.py index e4f7c1a..098988f 100644 --- a/prosemirror/model/diff.py +++ b/prosemirror/model/diff.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, TypedDict +from typing import TYPE_CHECKING, TypedDict from prosemirror.utils import text_length @@ -13,7 +13,7 @@ class Diff(TypedDict): b: int -def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> Optional[int]: +def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> int | None: i = 0 while True: if a.child_count == i or b.child_count == i: @@ -52,9 +52,7 @@ def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> Optional[int]: i += 1 -def find_diff_end( - a: "Fragment", b: "Fragment", pos_a: int, pos_b: int -) -> 
Optional[Diff]: +def find_diff_end(a: "Fragment", b: "Fragment", pos_a: int, pos_b: int) -> Diff | None: i_a, i_b = a.child_count, b.child_count while True: if i_a == 0 or i_b == 0: diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index 4af9b3a..58bf892 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -3,15 +3,11 @@ Any, Callable, ClassVar, - Dict, Iterable, - List, - Optional, - Union, cast, ) -from prosemirror.utils import MutableJSONList, text_length +from prosemirror.utils import JSONList, text_length if TYPE_CHECKING: from prosemirror.model.schema import Schema @@ -20,16 +16,16 @@ from .node import Node, TextNode -def retIndex(index: int, offset: int) -> Dict[str, int]: +def retIndex(index: int, offset: int) -> dict[str, int]: return {"index": index, "offset": offset} class Fragment: empty: ClassVar["Fragment"] - content: List["Node"] + content: list["Node"] size: int - def __init__(self, content: List["Node"], size: Optional[int] = None) -> None: + def __init__(self, content: list["Node"], size: int | None = None) -> None: self.content = content self.size = size if size is not None else sum(c.node_size for c in content) @@ -37,9 +33,9 @@ def nodes_between( self, from_: int, to: int, - f: Callable[["Node", int, Optional["Node"], int], Optional[bool]], + f: Callable[["Node", int, "Node | None", int], bool | None], node_start: int = 0, - parent: Optional["Node"] = None, + parent: "Node | None" = None, ) -> None: i = 0 pos = 0 @@ -62,7 +58,7 @@ def nodes_between( i += 1 def descendants( - self, f: Callable[["Node", int, Optional["Node"], int], Optional[bool]] + self, f: Callable[["Node", int, "Node | None", int], bool | None] ) -> None: self.nodes_between(0, self.size, f) @@ -71,14 +67,12 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable[["Node"], str], str] = "", + leaf_text: Callable[["Node"], str] | str = "", ) -> str: text = [] separated = True - def iteratee( - node: "Node", pos: int, _parent: Optional["Node"], _to: int - ) -> None: + def iteratee(node: "Node", pos: int, _parent: "Node | None", _to: int) -> None: nonlocal text nonlocal separated if node.is_text: @@ -119,12 +113,12 @@ def append(self, other: "Fragment") -> "Fragment": i += 1 return Fragment(content, self.size + other.size) - def cut(self, from_: int, to: Optional[int] = None) -> "Fragment": + def cut(self, from_: int, to: int | None = None) -> "Fragment": if to is None: to = self.size if from_ == 0 and to == self.size: return self - result: List["Node"] = [] + result: list["Node"] = [] size = 0 if to <= from_: return Fragment(result, size) @@ -149,7 +143,7 @@ def cut(self, from_: int, to: Optional[int] = None) -> "Fragment": i += 1 return Fragment(result, size) - def cut_by_index(self, from_: int, to: Optional[int] = None) -> "Fragment": + def cut_by_index(self, from_: int, to: int | None = None) -> "Fragment": if from_ == to: return Fragment.empty if from_ == 0 and to == len(self.content): @@ -177,11 +171,11 @@ def eq(self, other: "Fragment") -> bool: return all(a.eq(b) for (a, b) in zip(self.content, other.content)) @property - def first_child(self) -> Optional["Node"]: + def first_child(self) -> "Node | None": return self.content[0] if self.content else None @property - def last_child(self) -> Optional["Node"]: + def last_child(self) -> "Node | None": return self.content[-1] if self.content else None @property @@ -191,7 +185,7 @@ def child_count(self) -> int: def child(self, index: int) -> "Node": return 
self.content[index] - def maybe_child(self, index: int) -> Optional["Node"]: + def maybe_child(self, index: int) -> "Node | None": try: return self.content[index] except IndexError: @@ -206,7 +200,7 @@ def for_each(self, f: Callable[["Node", int, int], Any]) -> None: p += child.node_size i += 1 - def find_diff_start(self, other: "Fragment", pos: int = 0) -> Optional[int]: + def find_diff_start(self, other: "Fragment", pos: int = 0) -> int | None: from .diff import find_diff_start return find_diff_start(self, other, pos) @@ -214,9 +208,9 @@ def find_diff_start(self, other: "Fragment", pos: int = 0) -> Optional[int]: def find_diff_end( self, other: "Fragment", - pos: Optional[int] = None, - other_pos: Optional[int] = None, - ) -> Optional["Diff"]: + pos: int | None = None, + other_pos: int | None = None, + ) -> "Diff | None": from .diff import find_diff_end if pos is None: @@ -225,7 +219,7 @@ def find_diff_end( other_pos = other.size return find_diff_end(self, other, pos, other_pos) - def find_index(self, pos: int, round: int = -1) -> Dict[str, int]: + def find_index(self, pos: int, round: int = -1) -> dict[str, int]: if pos == 0: return retIndex(0, pos) if pos == self.size: @@ -244,7 +238,7 @@ def find_index(self, pos: int, round: int = -1) -> Dict[str, int]: i += 1 cur_pos = end - def to_json(self) -> Optional[MutableJSONList]: + def to_json(self) -> JSONList | None: if self.content: return [item.to_json() for item in self.content] return None @@ -253,19 +247,22 @@ def to_json(self) -> Optional[MutableJSONList]: def from_json(cls, schema: "Schema[str, str]", value: Any) -> "Fragment": if not value: return cls.empty + if isinstance(value, str): import json value = json.loads(value) + if not isinstance(value, list): raise ValueError("Invalid input for Fragment.from_json") + return cls([schema.node_from_json(item) for item in value]) @classmethod - def from_array(cls, array: List["Node"]) -> "Fragment": + def from_array(cls, array: list["Node"]) -> "Fragment": if not array: return cls.empty - joined: Optional[List["Node"]] = None + joined: list["Node"] | None = None size = 0 for i in range(len(array)): node = array[i] @@ -281,7 +278,7 @@ def from_array(cls, array: List["Node"]) -> "Fragment": return cls(joined or array, size) @classmethod - def from_(cls, nodes: Union["Fragment", "Node", List["Node"], None]) -> "Fragment": + def from_(cls, nodes: "Fragment | Node | list[Node] | None") -> "Fragment": if not nodes: return cls.empty if isinstance(nodes, Fragment): diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index 9336d4e..8162c69 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -1,7 +1,7 @@ import itertools import re from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast +from typing import Any, Callable, Literal, cast import lxml from lxml.cssselect import CSSSelector @@ -17,51 +17,51 @@ from .resolvedpos import ResolvedPos from .schema import Attrs, MarkType, NodeType, Schema -WSType = Union[bool, Literal["full"], None] +WSType = bool | Literal["full"] | None @dataclass class DOMPosition: node: DOMNode offset: int - pos: Optional[int] = None + pos: int | None = None @dataclass(frozen=True) class ParseOptions: preserve_whitespace: WSType = None - find_positions: Optional[List[DOMPosition]] = None - from_: Optional[int] = None - to_: Optional[int] = None - top_node: Optional[Node] = None - top_match: Optional[ContentMatch] = None - context: Optional[ResolvedPos] = None 
- rule_from_node: Optional[Callable[[DOMNode], "ParseRule"]] = None - top_open: Optional[bool] = None + find_positions: list[DOMPosition] | None = None + from_: int | None = None + to_: int | None = None + top_node: Node | None = None + top_match: ContentMatch | None = None + context: ResolvedPos | None = None + rule_from_node: Callable[[DOMNode], "ParseRule"] | None = None + top_open: bool | None = None @dataclass class ParseRule: - tag: Optional[str] - namespace: Optional[str] - style: Optional[str] - priority: Optional[int] - consuming: Optional[bool] - context: Optional[str] - node: Optional[str] - mark: Optional[str] - clear_mark: Optional[Callable[[Mark], bool]] - ignore: Optional[bool] - close_parent: Optional[bool] - skip: Optional[bool] - attrs: Optional[Attrs] - get_attrs: Optional[Callable[[DOMNode], Union[None, Attrs, Literal[False]]]] - content_element: Union[str, DOMNode, Callable[[DOMNode], DOMNode], None] - get_content: Optional[Callable[[DOMNode, Schema[str, str]], Fragment]] + tag: str | None + namespace: str | None + style: str | None + priority: int | None + consuming: bool | None + context: str | None + node: str | None + mark: str | None + clear_mark: Callable[[Mark], bool] | None + ignore: bool | None + close_parent: bool | None + skip: bool | None + attrs: Attrs | None + get_attrs: Callable[[DOMNode], None | Attrs | Literal[False]] | None + content_element: str | DOMNode | Callable[[DOMNode], DOMNode] | None + get_content: Callable[[DOMNode, Schema[str, str]], Fragment] | None preserve_whitespace: WSType @classmethod - def from_json(cls, data: Dict[str, Any]) -> "ParseRule": + def from_json(cls, data: dict[str, Any]) -> "ParseRule": return ParseRule( data.get("tag"), data.get("namespace"), @@ -84,14 +84,14 @@ def from_json(cls, data: Dict[str, Any]) -> "ParseRule": class DOMParser: - _tags: List[ParseRule] - _styles: List[ParseRule] + _tags: list[ParseRule] + _styles: list[ParseRule] _normalize_lists: bool schema: Schema[str, str] - rules: List[ParseRule] + rules: list[ParseRule] - def __init__(self, schema: Schema[str, str], rules: List[ParseRule]) -> None: + def __init__(self, schema: Schema[str, str], rules: list[ParseRule]) -> None: self.schema = schema self.rules = rules self._tags = [rule for rule in rules if rule.tag is not None] @@ -109,7 +109,7 @@ def __init__(self, schema: Schema[str, str], rules: List[ParseRule]) -> None: ) def parse( - self, dom_: lxml.html.HtmlElement, options: Optional[ParseOptions] = None + self, dom_: lxml.html.HtmlElement, options: ParseOptions | None = None ) -> Node: if options is None: options = ParseOptions() @@ -134,9 +134,7 @@ def parse( return cast(Node, context.finish()) - def parse_slice( - self, dom_: DOMNode, options: Optional[ParseOptions] = None - ) -> Slice: + def parse_slice(self, dom_: DOMNode, options: ParseOptions | None = None) -> Slice: if options is None: options = ParseOptions(preserve_whitespace=True) @@ -147,8 +145,8 @@ def parse_slice( return Slice.max_open(cast(Fragment, context.finish())) def match_tag( - self, dom_: DOMNode, context: "ParseContext", after: Optional[ParseRule] = None - ) -> Optional[ParseRule]: + self, dom_: DOMNode, context: "ParseContext", after: ParseRule | None = None + ) -> ParseRule | None: try: i = self._tags.index(after) + 1 if after is not None else 0 except ValueError: @@ -179,8 +177,8 @@ def match_style( prop: str, value: str, context: "ParseContext", - after: Optional[ParseRule] = None, - ) -> Optional[ParseRule]: + after: ParseRule | None = None, + ) -> ParseRule | None: i = 
self._styles.index(after) + 1 if after is not None else 0 for rule in self._styles[i:]: @@ -209,8 +207,8 @@ def match_style( return None @classmethod - def schema_rules(cls, schema: Schema[str, str]) -> List[ParseRule]: - result: List[ParseRule] = [] + def schema_rules(cls, schema: Schema[str, str]) -> list[ParseRule]: + result: list[ParseRule] = [] def insert(rule: ParseRule) -> None: priority = rule.priority if rule.priority is not None else 50 @@ -262,7 +260,7 @@ def from_schema(cls, schema: Schema[str, str]) -> "DOMParser": return cast("DOMParser", schema.cached["dom_parser"]) -BLOCK_TAGS: Dict[str, bool] = { +BLOCK_TAGS: dict[str, bool] = { "address": True, "article": True, "aside": True, @@ -297,7 +295,7 @@ def from_schema(cls, schema: Schema[str, str]) -> "DOMParser": "ul": True, } -IGNORE_TAGS: Dict[str, bool] = { +IGNORE_TAGS: dict[str, bool] = { "head": True, "noscript": True, "object": True, @@ -306,7 +304,7 @@ def from_schema(cls, schema: Schema[str, str]) -> "DOMParser": "title": True, } -LIST_TAGS: Dict[str, bool] = {"ol": True, "ul": True} +LIST_TAGS: dict[str, bool] = {"ol": True, "ul": True} OPT_PRESERVE_WS = 1 @@ -315,7 +313,7 @@ def from_schema(cls, schema: Schema[str, str]) -> "DOMParser": def ws_options_for( - _type: Optional[NodeType], preserve_whitespace: WSType, base: int + _type: NodeType | None, preserve_whitespace: WSType, base: int ) -> int: if preserve_whitespace is not None: return (OPT_PRESERVE_WS if preserve_whitespace else 0) | ( @@ -330,29 +328,29 @@ def ws_options_for( class NodeContext: - match: Optional[ContentMatch] - content: List[Node] + match: ContentMatch | None + content: list[Node] - active_marks: List[Mark] - stash_marks: List[Mark] + active_marks: list[Mark] + stash_marks: list[Mark] - type: Optional[NodeType] + type: NodeType | None options: int - attrs: Optional[Attrs] - marks: List[Mark] - pending_marks: List[Mark] + attrs: Attrs | None + marks: list[Mark] + pending_marks: list[Mark] solid: bool def __init__( self, - _type: Optional[NodeType], - attrs: Optional[Attrs], - marks: List[Mark], - pending_marks: List[Mark], + _type: NodeType | None, + attrs: Attrs | None, + marks: list[Mark], + pending_marks: list[Mark], solid: bool, - match: Optional[ContentMatch], + match: ContentMatch | None, options: int, ) -> None: self.type = _type @@ -376,7 +374,7 @@ def __init__( self.active_marks = Mark.none self.stash_marks = [] - def find_wrapping(self, node: Node) -> Optional[List[NodeType]]: + def find_wrapping(self, node: Node) -> list[NodeType] | None: if not self.match: if not self.type: return [] @@ -400,10 +398,10 @@ def find_wrapping(self, node: Node) -> Optional[List[NodeType]]: return self.match.find_wrapping(node.type) - def finish(self, open_end: bool) -> Union[Node, Fragment]: + def finish(self, open_end: bool) -> Node | Fragment: if not self.options & OPT_PRESERVE_WS: try: - last: Optional[Node] = self.content[-1] + last: Node | None = self.content[-1] except IndexError: last = None @@ -427,8 +425,8 @@ def finish(self, open_end: bool) -> Union[Node, Fragment]: self.type.create(self.attrs, content, self.marks) if self.type else content ) - def pop_from_stash_mark(self, mark: Mark) -> Optional[Mark]: - found_mark: Optional[Mark] = None + def pop_from_stash_mark(self, mark: Mark) -> Mark | None: + found_mark: Mark | None = None for stash_mark in self.stash_marks[::-1]: if mark.eq(stash_mark): found_mark = stash_mark @@ -459,9 +457,9 @@ def inline_context(self, node: DOMNode) -> bool: class ParseContext: open: int = 0 - find: 
Optional[List[DOMPosition]] + find: list[DOMPosition] | None needs_block: bool - nodes: List[NodeContext] + nodes: list[NodeContext] options: ParseOptions is_open: bool parser: DOMParser @@ -586,9 +584,7 @@ def add_text_node(self, dom_: DOMNode) -> None: else: self.find_inside(dom_) - def add_element( - self, dom_: DOMNode, match_after: Optional[ParseRule] = None - ) -> None: + def add_element(self, dom_: DOMNode, match_after: ParseRule | None = None) -> None: name = dom_.tag.lower() if name in LIST_TAGS and self.parser.normalize_lists: @@ -651,12 +647,12 @@ def ignore_fallback(self, dom_: DOMNode) -> None: ): self.find_place(self.parser.schema.text("-")) - def read_styles(self, styles: List[str]) -> Optional[Tuple[List[Mark], List[Mark]]]: - add: List[Mark] = Mark.none - remove: List[Mark] = Mark.none + def read_styles(self, styles: list[str]) -> tuple[list[Mark], list[Mark]] | None: + add: list[Mark] = Mark.none + remove: list[Mark] = Mark.none for i in range(0, len(styles), 2): - after: Optional[ParseRule] = None + after: ParseRule | None = None while True: rule = self.parser.match_style(styles[i], styles[i + 1], self, after) if not rule: @@ -682,11 +678,11 @@ def read_styles(self, styles: List[str]) -> Optional[Tuple[List[Mark], List[Mark return add, remove def add_element_by_rule( - self, dom_: DOMNode, rule: ParseRule, continue_after: Optional[ParseRule] = None + self, dom_: DOMNode, rule: ParseRule, continue_after: ParseRule | None = None ) -> None: sync: bool = False - mark: Optional[Mark] = None - node_type: Optional[NodeType] = None + mark: Mark | None = None + node_type: NodeType | None = None if rule.node is not None: node_type = self.parser.schema.nodes[rule.node] @@ -732,8 +728,8 @@ def add_element_by_rule( def add_all( self, parent: DOMNode, - start_index: Optional[int] = None, - end_index: Optional[int] = None, + start_index: int | None = None, + end_index: int | None = None, ) -> None: index = start_index if start_index is not None else 0 @@ -754,8 +750,8 @@ def add_all( self.find_at_point(parent, index) def find_place(self, node: Node) -> bool: - route: Optional[List[NodeType]] = None - sync: Optional[NodeContext] = None + route: list[NodeType] | None = None + sync: NodeContext | None = None depth = self.open while depth >= 0: @@ -811,7 +807,7 @@ def insert_node(self, node: Node) -> bool: return False def enter( - self, type_: NodeType, attrs: Optional[Attrs] = None, preserve_ws: WSType = None + self, type_: NodeType, attrs: Attrs | None = None, preserve_ws: WSType = None ) -> bool: ok = self.find_place(type_.create(attrs)) if ok: @@ -822,7 +818,7 @@ def enter( def enter_inner( self, type_: NodeType, - attrs: Optional[Attrs] = None, + attrs: Attrs | None = None, solid: bool = False, preserve_ws: WSType = None, ) -> None: @@ -859,7 +855,7 @@ def close_extra(self, open_end: bool = False) -> None: self.nodes = self.nodes[: self.open + 1] - def finish(self) -> Union[Node, Fragment]: + def finish(self) -> Node | Fragment: self.open = 0 self.close_extra(self.is_open) return self.nodes[0].finish(self.is_open or bool(self.options.top_open)) @@ -954,7 +950,7 @@ def match(i: int, depth: int) -> bool: return False else: if depth > 0 or (depth == 0 and use_root): - next: Optional[NodeType] = self.nodes[depth].type + next: NodeType | None = self.nodes[depth].type elif option is not None and depth >= min_depth: next = option.node(depth - min_depth).type else: @@ -975,7 +971,7 @@ def match(i: int, depth: int) -> bool: return match(len(parts) - 1, self.open) - def 
textblock_from_context(self) -> Optional[NodeType]: + def textblock_from_context(self) -> NodeType | None: context = self.options.context if context: @@ -1037,7 +1033,7 @@ def remove_pending_mark(self, mark: Mark, upto: NodeContext) -> None: def normalize_list(dom_: DOMNode) -> None: child = next(iter(dom_)) - prev_item: Optional[DOMNode] = None + prev_item: DOMNode | None = None while child is not None: name = child.tag.lower() if get_node_type(child) == 1 else None @@ -1056,12 +1052,12 @@ def normalize_list(dom_: DOMNode) -> None: def matches(dom_: DOMNode, selector_str: str) -> bool: selector = CSSSelector(selector_str) - return bool(dom_ in selector(dom_)) # type: ignore + return bool(dom_ in selector(dom_)) # type: ignore[operator] -def parse_styles(style: str) -> List[str]: +def parse_styles(style: str) -> list[str]: regex = r"\s*([\w-]+)\s*:\s*([^;]+)" - result: List[str] = [] + result: list[str] = [] for m in re.findall(regex, style): result.append(m[0]) @@ -1077,7 +1073,7 @@ def mark_may_apply(mark_type: MarkType, node_type: NodeType) -> bool: if not parent.allows_mark_type(mark_type): continue - seen: List[ContentMatch] = [] + seen: list[ContentMatch] = [] def scan(match: ContentMatch) -> bool: seen.append(match) @@ -1101,7 +1097,7 @@ def scan(match: ContentMatch) -> bool: return False -def find_same_mark_in_set(mark: Mark, mark_set: List[Mark]) -> Optional[Mark]: +def find_same_mark_in_set(mark: Mark, mark_set: list[Mark]) -> Mark | None: for comp in mark_set: if mark.eq(comp): return comp diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index 5dcf59e..4ae33a1 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -1,21 +1,23 @@ import copy -from typing import TYPE_CHECKING, Any, Final, List, Optional, cast +from typing import TYPE_CHECKING, Any, Final, cast -from prosemirror.utils import JSONDict, MutableJSONDict +from prosemirror.utils import JSONDict + +from .schema import Attrs if TYPE_CHECKING: from .schema import MarkType, Schema class Mark: - none: Final[List["Mark"]] = [] + none: Final[list["Mark"]] = [] - def __init__(self, type: "MarkType", attrs: JSONDict) -> None: + def __init__(self, type: "MarkType", attrs: Attrs) -> None: self.type = type self.attrs = attrs - def add_to_set(self, set: List["Mark"]) -> List["Mark"]: - copy: Optional[List["Mark"]] = None + def add_to_set(self, set: list["Mark"]) -> list["Mark"]: + copy: list["Mark"] | None | None = None placed = False for i in range(len(set)): other = set[i] @@ -40,10 +42,10 @@ def add_to_set(self, set: List["Mark"]) -> List["Mark"]: copy.append(self) return copy - def remove_from_set(self, set: List["Mark"]) -> List["Mark"]: + def remove_from_set(self, set: list["Mark"]) -> list["Mark"]: return [item for item in set if not item.eq(self)] - def is_in_set(self, set: List["Mark"]) -> bool: + def is_in_set(self, set: list["Mark"]) -> bool: return any(item.eq(self) for item in set) def eq(self, other: "Mark") -> bool: @@ -51,10 +53,13 @@ def eq(self, other: "Mark") -> bool: return True return self.type.name == other.type.name and self.attrs == other.attrs - def to_json(self) -> MutableJSONDict: - result: MutableJSONDict = {"type": self.type.name} + def to_json(self) -> JSONDict: + result: JSONDict = {"type": self.type.name} if self.attrs: - result["attrs"] = cast(MutableJSONDict, copy.deepcopy(self.attrs)) + result = { + **result, + "attrs": copy.deepcopy(self.attrs), + } return result @classmethod @@ -69,10 +74,10 @@ def from_json( type = schema.marks.get(name) if not type: raise 
ValueError(f"There is no mark type {name} in this schema") - return type.create(cast(Optional[JSONDict], json_data.get("attrs"))) + return type.create(cast(JSONDict | None, json_data.get("attrs"))) @classmethod - def same_set(cls, a: List["Mark"], b: List["Mark"]) -> bool: + def same_set(cls, a: list["Mark"], b: list["Mark"]) -> bool: if a == b: return True if len(a) != len(b): @@ -80,10 +85,10 @@ def same_set(cls, a: List["Mark"], b: List["Mark"]) -> bool: return all(item_a.eq(item_b) for (item_a, item_b) in zip(a, b)) @classmethod - def set_from(cls, marks: Optional[List["Mark"]]) -> List["Mark"]: + def set_from(cls, marks: "list[Mark] | Mark | None") -> list["Mark"]: if not marks: return cls.none - if isinstance(marks, cls): + if isinstance(marks, Mark): return [marks] copy = marks[:] return sorted(copy, key=lambda item: item.type.rank) diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index 39b1c0a..bb48697 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,9 +1,9 @@ import copy -from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Callable, TypedDict, cast from typing_extensions import TypeGuard -from prosemirror.utils import JSONDict, MutableJSONDict, text_length +from prosemirror.utils import JSONDict, text_length from .comparedeep import compare_deep from .fragment import Fragment @@ -13,14 +13,14 @@ if TYPE_CHECKING: from .content import ContentMatch - from .schema import MarkType, NodeType, Schema + from .schema import Attrs, MarkType, NodeType, Schema empty_attrs: JSONDict = {} class ChildInfo(TypedDict): - node: Optional["Node"] + node: "Node | None" index: int offset: int @@ -29,9 +29,9 @@ class Node: def __init__( self, type: "NodeType", - attrs: JSONDict, - content: Optional[Fragment], - marks: List[Mark], + attrs: "Attrs", + content: Fragment | None, + marks: list[Mark], ) -> None: self.type = type self.attrs = attrs @@ -49,7 +49,7 @@ def child_count(self) -> int: def child(self, index: int) -> "Node": return self.content.child(index) - def maybe_child(self, index: int) -> Optional["Node"]: + def maybe_child(self, index: int) -> "Node | None": return self.content.maybe_child(index) def for_each(self, f: Callable[["Node", int, int], None]) -> None: @@ -59,13 +59,13 @@ def nodes_between( self, from_: int, to: int, - f: Callable[["Node", int, Optional["Node"], int], Optional[bool]], + f: Callable[["Node", int, "Node | None", int], bool | None], start_pos: int = 0, ) -> None: self.content.nodes_between(from_, to, f, start_pos, self) def descendants( - self, f: Callable[["Node", int, Optional["Node"], int], Optional[bool]] + self, f: Callable[["Node", int, "Node | None", int], bool | None] ) -> None: self.nodes_between(0, self.content.size, f) @@ -80,16 +80,16 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable[["Node"], str], str] = "", + leaf_text: Callable[["Node"], str] | str = "", ) -> str: return self.content.text_between(from_, to, block_separator, leaf_text) @property - def first_child(self) -> Optional["Node"]: + def first_child(self) -> "Node | None": return self.content.first_child @property - def last_child(self) -> Optional["Node"]: + def last_child(self) -> "Node | None": return self.content.last_child def eq(self, other: "Node") -> bool: @@ -103,8 +103,8 @@ def same_markup(self, other: "Node") -> bool: def has_markup( self, type: "NodeType", - attrs: Optional[JSONDict] = None, - marks: Optional[List[Mark]] = 
None, + attrs: "Attrs | None" = None, + marks: list[Mark] | None = None, ) -> bool: return ( self.type.name == type.name @@ -112,23 +112,23 @@ def has_markup( and (Mark.same_set(self.marks, marks or Mark.none)) ) - def copy(self, content: Optional[Fragment] = None) -> "Node": + def copy(self, content: Fragment | None = None) -> "Node": if content == self.content: return self return self.__class__(self.type, self.attrs, content, self.marks) - def mark(self, marks: List[Mark]) -> "Node": + def mark(self, marks: list[Mark]) -> "Node": if marks == self.marks: return self return self.__class__(self.type, self.attrs, self.content, marks) - def cut(self, from_: int, to: Optional[int] = None) -> "Node": + def cut(self, from_: int, to: int | None = None) -> "Node": if from_ == 0 and to == self.content.size: return self return self.copy(self.content.cut(from_, to)) def slice( - self, from_: int, to: Optional[int] = None, include_parents: bool = False + self, from_: int, to: int | None = None, include_parents: bool = False ) -> Slice: if to is None: to = self.content.size @@ -145,7 +145,7 @@ def slice( def replace(self, from_: int, to: int, slice: Slice) -> "Node": return replace(self.resolve(from_), self.resolve(to), slice) - def node_at(self, pos: int) -> Optional["Node"]: + def node_at(self, pos: int) -> "Node | None": node = self while True: index_info = node.content.find_index(pos) @@ -183,14 +183,12 @@ def resolve(self, pos: int) -> ResolvedPos: def resolve_no_cache(self, pos: int) -> ResolvedPos: return ResolvedPos.resolve(self, pos) - def range_has_mark( - self, from_: int, to: int, type: Union[Mark, "MarkType"] - ) -> bool: + def range_has_mark(self, from_: int, to: int, type: "Mark | MarkType") -> bool: found = False if to > from_: def iteratee( - node: "Node", pos: int, parent: Optional["Node"], index: int + node: "Node", pos: int, parent: "Node | None", index: int ) -> bool: nonlocal found if type.is_in_set(node.marks): @@ -256,12 +254,12 @@ def can_replace( to: int, replacement: Fragment = Fragment.empty, start: int = 0, - end: Optional[int] = None, + end: int | None = None, ) -> bool: if end is None: end = replacement.child_count one = self.content_match_at(from_).match_fragment(replacement, start, end) - two: Optional["ContentMatch"] = None + two: "ContentMatch | None" = None if one: two = one.match_fragment(self.content, to) if not two or not two.valid_end: @@ -272,12 +270,12 @@ def can_replace( return True def can_replace_with( - self, from_: int, to: int, type: "NodeType", marks: Optional[List[Mark]] = None + self, from_: int, to: int, type: "NodeType", marks: list[Mark] | None = None ) -> bool: if marks and not self.type.allows_marks(marks): return False start = self.content_match_at(from_).match_type(type) - end: Optional["ContentMatch"] = None + end: "ContentMatch | None" = None if start: end = start.match_fragment(self.content, to) return end.valid_end if end else False @@ -307,22 +305,32 @@ def iteratee(node: "Node", offset: int, index: int) -> None: return self.content.for_each(iteratee) - def to_json(self) -> MutableJSONDict: - obj: MutableJSONDict = {"type": self.type.name} + def to_json(self) -> JSONDict: + obj: JSONDict = {"type": self.type.name} if self.attrs: - obj["attrs"] = cast(MutableJSONDict, copy.deepcopy(self.attrs)) + obj = { + **obj, + "attrs": copy.deepcopy(self.attrs), + } if getattr(self.content, "size", None): - obj["content"] = self.content.to_json() + obj = { + **obj, + "content": self.content.to_json(), + } if len(self.marks): - obj["marks"] = 
[n.to_json() for n in self.marks] + obj = { + **obj, + "marks": [n.to_json() for n in self.marks], + } return obj @classmethod - def from_json(cls, schema: "Schema[str, str]", json_data: Any) -> "Node": + def from_json(cls, schema: "Schema[str, str]", json_data: JSONDict | str) -> "Node": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not json_data: raise ValueError("Invalid input for Node.from_json") marks = None @@ -331,10 +339,10 @@ def from_json(cls, schema: "Schema[str, str]", json_data: Any) -> "Node": raise ValueError("Invalid mark data for Node.fromJSON") marks = [schema.mark_from_json(item) for item in json_data["marks"]] if json_data["type"] == "text": - return schema.text(json_data["text"], marks) + return schema.text(str(json_data["text"]), marks) content = Fragment.from_json(schema, json_data.get("content")) - return schema.node_type(json_data["type"]).create( - json_data.get("attrs"), content, marks + return schema.node_type(str(json_data["type"])).create( + cast("Attrs", json_data.get("attrs")), content, marks ) @@ -342,9 +350,9 @@ class TextNode(Node): def __init__( self, type: "NodeType", - attrs: JSONDict, + attrs: "Attrs", content: str, - marks: List[Mark], + marks: list[Mark], ) -> None: super().__init__(type, attrs, None, marks) if not content: @@ -372,7 +380,7 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Union[Callable[["Node"], str], str] = "", + leaf_text: Callable[["Node"], str] | str = "", ) -> str: return self.text[from_:to] @@ -380,7 +388,7 @@ def text_between( def node_size(self) -> int: return text_length(self.text) - def mark(self, marks: List[Mark]) -> "TextNode": + def mark(self, marks: list[Mark]) -> "TextNode": return ( self if marks == self.marks @@ -392,7 +400,7 @@ def with_text(self, text: str) -> "TextNode": return self return TextNode(self.type, self.attrs, text, self.marks) - def cut(self, from_: int = 0, to: Optional[int] = None) -> "TextNode": + def cut(self, from_: int = 0, to: int | None = None) -> "TextNode": if to is None: to = text_length(self.text) if from_ == 0 and to == text_length(self.text): @@ -407,11 +415,11 @@ def eq(self, other: Node) -> bool: def to_json( self, - ) -> MutableJSONDict: + ) -> JSONDict: return {**super().to_json(), "text": self.text} -def wrap_marks(marks: List[Mark], str: str) -> str: +def wrap_marks(marks: list[Mark], str: str) -> str: i = len(marks) - 1 while i >= 0: str = marks[i].type.name + "(" + str + ")" diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index 502806c..9440f78 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -1,6 +1,6 @@ -from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, cast +from typing import TYPE_CHECKING, ClassVar, cast -from prosemirror.utils import JSONDict, MutableJSONDict +from prosemirror.utils import JSONDict from .fragment import Fragment @@ -34,8 +34,8 @@ def remove_range(content: Fragment, from_: int, to: int) -> Fragment: def insert_into( - content: Fragment, dist: int, insert: Fragment, parent: Optional["Node"] -) -> Optional[Fragment]: + content: Fragment, dist: int, insert: Fragment, parent: "Node | None" +) -> Fragment | None: a = content.find_index(dist) index, offset = a["index"], a["offset"] child = content.maybe_child(index) @@ -62,7 +62,7 @@ def __init__(self, content: Fragment, open_start: int, open_end: int) -> None: def size(self) -> int: return self.content.size - 
self.open_start - self.open_end - def insert_at(self, pos: int, fragment: Fragment) -> Optional["Slice"]: + def insert_at(self, pos: int, fragment: Fragment) -> "Slice | None": content = insert_into(self.content, pos + self.open_start, fragment, None) if content: return Slice(content, self.open_start, self.open_end) @@ -85,18 +85,28 @@ def eq(self, other: "Slice") -> bool: def __str__(self) -> str: return f"{self.content}({self.open_start},{self.open_end})" - def to_json(self) -> Optional[MutableJSONDict]: + def to_json(self) -> JSONDict | None: if not self.content.size: return None - json: MutableJSONDict = {"content": self.content.to_json()} + json: JSONDict = {"content": self.content.to_json()} if self.open_start > 0: - json["openStart"] = self.open_start + json = { + **json, + "openStart": self.open_start, + } if self.open_end > 0: - json["openEnd"] = self.open_end + json = { + **json, + "openEnd": self.open_end, + } return json @classmethod - def from_json(cls, schema: "Schema[str, str]", json_data: JSONDict) -> "Slice": + def from_json( + cls, + schema: "Schema[str, str]", + json_data: JSONDict | None, + ) -> "Slice": if not json_data: return cls.empty open_start = json_data.get("openStart", 0) or 0 @@ -176,7 +186,7 @@ def joinable(before: "ResolvedPos", after: "ResolvedPos", depth: int) -> "Node": return node -def add_node(child: "Node", target: List["Node"]) -> None: +def add_node(child: "Node", target: list["Node"]) -> None: last = len(target) - 1 if last >= 0 and pm_node.is_text(child) and child.same_markup(target[last]): target[last] = child.with_text(cast("TextNode", target[last]).text + child.text) @@ -185,10 +195,10 @@ def add_node(child: "Node", target: List["Node"]) -> None: def add_range( - start: Optional["ResolvedPos"], - end: Optional["ResolvedPos"], + start: "ResolvedPos | None", + end: "ResolvedPos | None", depth: int, - target: List["Node"], + target: list["Node"], ) -> None: node = cast("ResolvedPos", end or start).node(depth) start_index = 0 @@ -223,7 +233,7 @@ def replace_three_way( ) -> Fragment: open_start = joinable(from_, start, depth + 1) if from_.depth > depth else None open_end = joinable(end, to, depth + 1) if to.depth > depth else None - content: List["Node"] = [] + content: list["Node"] = [] add_range(None, from_, depth, content) if open_start and open_end and start.index(depth) == end.index(depth): check_join(open_start, open_end) @@ -244,7 +254,7 @@ def replace_three_way( def replace_two_way(from_: "ResolvedPos", to: "ResolvedPos", depth: int) -> Fragment: - content: List["Node"] = [] + content: list["Node"] = [] add_range(None, from_, depth, content) if from_.depth > depth: type = joinable(from_, to, depth + 1) @@ -255,7 +265,7 @@ def replace_two_way(from_: "ResolvedPos", to: "ResolvedPos", depth: int) -> Frag def prepare_slice_for_replace( slice: Slice, along: "ResolvedPos" -) -> Dict[str, "ResolvedPos"]: +) -> dict[str, "ResolvedPos"]: extra = along.depth - slice.open_start parent = along.node(extra) node = parent.copy(slice.content) diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index 2d95f10..b3cc4ff 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, List, Optional, Union, cast +from typing import TYPE_CHECKING, Callable, cast from .mark import Mark @@ -7,15 +7,13 @@ class ResolvedPos: - def __init__( - self, pos: int, path: List[Union["Node", int]], parent_offset: int - ) -> None: + def __init__(self, pos: int, path: 
list["Node | int"], parent_offset: int) -> None: self.pos = pos self.path = path self.depth = int(len(path) / 3 - 1) self.parent_offset = parent_offset - def resolve_depth(self, val: Optional[int] = None) -> int: + def resolve_depth(self, val: int | None = None) -> int: if val is None: return self.depth return self.depth + val if val < 0 else val @@ -31,7 +29,7 @@ def doc(self) -> "Node": def node(self, depth: int) -> "Node": return cast("Node", self.path[self.resolve_depth(depth) * 3]) - def index(self, depth: Optional[int] = None) -> int: + def index(self, depth: int | None = None) -> int: return cast(int, self.path[self.resolve_depth(depth) * 3 + 1]) def index_after(self, depth: int) -> int: @@ -40,15 +38,15 @@ def index_after(self, depth: int) -> int: 0 if depth == self.depth and not self.text_offset else 1 ) - def start(self, depth: Optional[int] = None) -> int: + def start(self, depth: int | None = None) -> int: depth = self.resolve_depth(depth) return 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 - def end(self, depth: Optional[int] = None) -> int: + def end(self, depth: int | None = None) -> int: depth = self.resolve_depth(depth) return self.start(depth) + self.node(depth).content.size - def before(self, depth: Optional[int] = None) -> int: + def before(self, depth: int | None = None) -> int: depth = self.resolve_depth(depth) if not depth: raise ValueError("There is no position before the top level node") @@ -56,7 +54,7 @@ def before(self, depth: Optional[int] = None) -> int: self.pos if depth == self.depth + 1 else cast(int, self.path[depth * 3 - 1]) ) - def after(self, depth: Optional[int] = None) -> int: + def after(self, depth: int | None = None) -> int: depth = self.resolve_depth(depth) if not depth: raise ValueError("There is no position after the top level node") @@ -72,7 +70,7 @@ def text_offset(self) -> int: return self.pos - cast(int, self.path[-1]) @property - def node_after(self) -> Optional["Node"]: + def node_after(self) -> "Node | None": parent = self.parent index = self.index(self.depth) if index == parent.child_count: @@ -82,14 +80,14 @@ def node_after(self) -> Optional["Node"]: return parent.child(index).cut(d_off) if d_off else child @property - def node_before(self) -> Optional["Node"]: + def node_before(self) -> "Node | None": index = self.index(self.depth) d_off = self.pos - cast(int, self.path[-1]) if d_off: return self.parent.child(index).cut(0, d_off) return None if index == 0 else self.parent.child(index - 1) - def pos_at_index(self, index: int, depth: Optional[int] = None) -> int: + def pos_at_index(self, index: int, depth: int | None = None) -> int: depth = self.resolve_depth(depth) node = cast("Node", self.path[depth * 3]) pos = 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 @@ -97,7 +95,7 @@ def pos_at_index(self, index: int, depth: Optional[int] = None) -> int: pos += node.child(i).node_size return pos - def marks(self) -> List["Mark"]: + def marks(self) -> list["Mark"]: parent = self.parent index = self.index() if parent.content.size == 0: @@ -119,7 +117,7 @@ def marks(self) -> List["Mark"]: i += 1 return marks - def marks_across(self, end: "ResolvedPos") -> Optional[List["Mark"]]: + def marks_across(self, end: "ResolvedPos") -> list["Mark"] | None: after = self.parent.maybe_child(self.index()) if not after or not after.is_inline: return None @@ -145,9 +143,9 @@ def shared_depth(self, pos: int) -> int: def block_range( self, - other: Optional["ResolvedPos"] = None, - pred: Optional[Callable[["Node"], bool]] = None, - ) -> 
Optional["NodeRange"]: + other: "ResolvedPos | None" = None, + pred: Callable[["Node"], bool] | None = None, + ) -> "NodeRange | None": if other is None: other = self if other.pos < self.pos: @@ -183,7 +181,7 @@ def __str__(self) -> str: def resolve(cls, doc: "Node", pos: int) -> "ResolvedPos": if not (pos >= 0 and pos <= doc.content.size): raise ValueError(f"Position {pos} out of range") - path: List[Union["Node", int]] = [] + path: list["Node | int"] = [] start = 0 parent_offset = pos node = doc diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index af6922d..b7622bb 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -1,29 +1,24 @@ from typing import ( Any, Callable, - Dict, Generic, - List, Literal, - Optional, TypeVar, - Union, cast, ) from typing_extensions import NotRequired, TypeAlias, TypedDict +from prosemirror.model.content import ContentMatch from prosemirror.model.fragment import Fragment from prosemirror.model.mark import Mark from prosemirror.model.node import Node, TextNode from prosemirror.utils import JSON, JSONDict -from .content import ContentMatch - Attrs: TypeAlias = JSONDict -def default_attrs(attrs: "Attributes") -> Optional[Attrs]: +def default_attrs(attrs: "Attributes") -> Attrs | None: defaults = {} for attr_name, attr in attrs.items(): if not attr.has_default: @@ -32,7 +27,7 @@ def default_attrs(attrs: "Attributes") -> Optional[Attrs]: return defaults -def compute_attrs(attrs: "Attributes", value: Optional[Attrs]) -> Attrs: +def compute_attrs(attrs: "Attributes", value: Attrs | None) -> Attrs: built = {} for name in attrs: given = None @@ -48,7 +43,7 @@ def compute_attrs(attrs: "Attributes", value: Optional[Attrs]) -> Attrs: return built -def init_attrs(attrs: Optional["AttributeSpecs"]) -> "Attributes": +def init_attrs(attrs: "AttributeSpecs | None") -> "Attributes": result = {} if attrs: for name in attrs: @@ -72,9 +67,7 @@ class NodeType: inline_content: bool - content_match: "ContentMatch" - - mark_set: Optional[List["MarkType"]] + mark_set: list["MarkType"] | None def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> None: self.name = name @@ -83,12 +76,21 @@ def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> N self.groups = spec["group"].split(" ") if "group" in spec else [] self.attrs = init_attrs(spec.get("attrs")) self.default_attrs = default_attrs(self.attrs) - self.content_match = None # type: ignore[assignment] + self._content_match: ContentMatch | None = None self.mark_set = None self.inline_content = False self.is_block = not (spec.get("inline") or name == "text") self.is_text = name == "text" + @property + def content_match(self) -> ContentMatch: + assert self._content_match is not None + return self._content_match + + @content_match.setter + def content_match(self, value: ContentMatch) -> None: + self._content_match = value + @property def is_inline(self) -> bool: return not self.is_block @@ -107,7 +109,7 @@ def is_atom(self) -> bool: @property def whitespace(self) -> Literal["pre", "normal"]: - return cast(Literal["pre", "normal"], self.spec.get("whitespace")) or ( + return self.spec.get("whitespace") or ( "pre" if self.spec.get("code") else "normal" ) @@ -120,16 +122,16 @@ def has_required_attrs(self) -> bool: def compatible_content(self, other: "NodeType") -> bool: return self == other or (self.content_match.compatible(other.content_match)) - def compute_attrs(self, attrs: Optional[Attrs]) -> Attrs: + def compute_attrs(self, attrs: Attrs | None) 
-> Attrs: if attrs is None and self.default_attrs is not None: return self.default_attrs return compute_attrs(self.attrs, attrs) def create( self, - attrs: Optional[Attrs] = None, - content: Optional[Union[Fragment, Node, List[Node]]] = None, - marks: Optional[List[Mark]] = None, + attrs: Attrs | None = None, + content: Fragment | Node | list[Node] | None = None, + marks: list[Mark] | None = None, ) -> Node: if self.is_text: raise ValueError("NodeType.create cannot construct text nodes") @@ -142,9 +144,9 @@ def create( def create_checked( self, - attrs: Optional[Attrs] = None, - content: Optional[Union[Fragment, Node, List[Node]]] = None, - marks: Optional[List[Mark]] = None, + attrs: Attrs | None = None, + content: Fragment | Node | list[Node] | None = None, + marks: list[Mark] | None = None, ) -> Node: content = Fragment.from_(content) if not self.valid_content(content): @@ -153,10 +155,10 @@ def create_checked( def create_and_fill( self, - attrs: Optional[Attrs] = None, - content: Optional[Union[Fragment, Node, List[Node]]] = None, - marks: Optional[List[Mark]] = None, - ) -> Optional[Node]: + attrs: Attrs | None = None, + content: Fragment | Node | list[Node] | None = None, + marks: list[Mark] | None = None, + ) -> Node | None: attrs = self.compute_attrs(attrs) frag = Fragment.from_(content) if frag.size: @@ -184,15 +186,15 @@ def valid_content(self, content: Fragment) -> bool: def allows_mark_type(self, mark_type: "MarkType") -> bool: return self.mark_set is None or mark_type in self.mark_set - def allows_marks(self, marks: List[Mark]) -> bool: + def allows_marks(self, marks: list[Mark]) -> bool: if self.mark_set is None: return True return all(self.allows_mark_type(mark.type) for mark in marks) - def allowed_marks(self, marks: List[Mark]) -> List[Mark]: + def allowed_marks(self, marks: list[Mark]) -> list[Mark]: if self.mark_set is None: return marks - copy: Optional[List[Mark]] = None + copy: list[Mark] | None = None for i, mark in enumerate(marks): if not self.allows_mark_type(mark.type): if not copy: @@ -208,9 +210,9 @@ def allowed_marks(self, marks: List[Mark]) -> List[Mark]: @classmethod def compile( - cls, nodes: Dict["Nodes", "NodeSpec"], schema: "Schema[Nodes, Marks]" - ) -> Dict["Nodes", "NodeType"]: - result: Dict["Nodes", "NodeType"] = {} + cls, nodes: dict["Nodes", "NodeSpec"], schema: "Schema[Nodes, Marks]" + ) -> dict["Nodes", "NodeType"]: + result: dict["Nodes", "NodeType"] = {} for name, spec in nodes.items(): result[name] = NodeType(name, schema, spec) @@ -231,7 +233,7 @@ def __repr__(self) -> str: return self.__str__() -Attributes: TypeAlias = Dict[str, "Attribute"] +Attributes: TypeAlias = dict[str, "Attribute"] class Attribute: @@ -245,8 +247,8 @@ def is_required(self) -> bool: class MarkType: - excluded: List["MarkType"] - instance: Optional[Mark] + excluded: list["MarkType"] + instance: Mark | None def __init__( self, name: str, rank: int, schema: "Schema[str, str]", spec: "MarkSpec" @@ -264,7 +266,7 @@ def __init__( def create( self, - attrs: Optional[Attrs] = None, + attrs: Attrs | None = None, ) -> Mark: if not attrs and self.instance: return self.instance @@ -272,8 +274,8 @@ def create( @classmethod def compile( - cls, marks: Dict["Marks", "MarkSpec"], schema: "Schema[Nodes, Marks]" - ) -> Dict["Marks", "MarkType"]: + cls, marks: dict["Marks", "MarkSpec"], schema: "Schema[Nodes, Marks]" + ) -> dict["Marks", "MarkType"]: result = {} rank = 0 for name, spec in marks.items(): @@ -281,16 +283,17 @@ def compile( rank += 1 return result - def remove_from_set(self, 
set_: List["Mark"]) -> List["Mark"]: + def remove_from_set(self, set_: list["Mark"]) -> list["Mark"]: return [item for item in set_ if item.type != self] - def is_in_set(self, set: List[Mark]) -> Optional[Mark]: + def is_in_set(self, set: list[Mark]) -> Mark | None: return next((item for item in set if item.type == self), None) def excludes(self, other: "MarkType") -> bool: return any(other.name == e.name for e in self.excluded) +# XXX I don't get these... Nodes = TypeVar("Nodes", bound=str, covariant=True) Marks = TypeVar("Marks", bound=str, covariant=True) @@ -307,13 +310,13 @@ class SchemaSpec(TypedDict, Generic[Nodes, Marks]): # determines which [parse rules](#model.NodeSpec.parseDOM) take # precedence by default, and which nodes come first in a given # [group](#model.NodeSpec.group). - nodes: Dict[Nodes, "NodeSpec"] + nodes: dict[Nodes, "NodeSpec"] # The mark types that exist in this schema. The order in which they # are provided determines the order in which [mark # sets](#model.Mark.addToSet) are sorted and in which [parse # rules](#model.MarkSpec.parseDOM) are tried. - marks: NotRequired[Dict[Marks, "MarkSpec"]] + marks: NotRequired[dict[Marks, "MarkSpec"]] # The name of the default top-level node for the schema. Defaults # to `"doc"`. @@ -340,12 +343,12 @@ class NodeSpec(TypedDict, total=False): defining: bool isolating: bool toDOM: Callable[[Node], Any] # FIXME: add types - parseDOM: List[Dict[str, Any]] # FIXME: add types + parseDOM: list[dict[str, Any]] # FIXME: add types toDebugString: Callable[[Node], str] leafText: Callable[[Node], str] -AttributeSpecs: TypeAlias = Dict[str, "AttributeSpec"] +AttributeSpecs: TypeAlias = dict[str, "AttributeSpec"] class MarkSpec(TypedDict, total=False): @@ -355,7 +358,7 @@ class MarkSpec(TypedDict, total=False): group: str spanning: bool toDOM: Callable[[Mark, bool], Any] # FIXME: add types - parseDOM: List[Dict[str, Any]] # FIXME: add types + parseDOM: list[dict[str, Any]] # FIXME: add types class AttributeSpec(TypedDict, total=False): @@ -365,9 +368,9 @@ class AttributeSpec(TypedDict, total=False): class Schema(Generic[Nodes, Marks]): spec: SchemaSpec[Nodes, Marks] - nodes: Dict[Nodes, "NodeType"] + nodes: dict[Nodes, "NodeType"] - marks: Dict[Marks, "MarkType"] + marks: dict[Marks, "MarkType"] def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: self.spec = spec @@ -382,7 +385,7 @@ def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: mark_expr = type.spec.get("marks") if content_expr not in content_expr_cache: content_expr_cache[content_expr] = ContentMatch.parse( - content_expr, cast(Dict[str, "NodeType"], self.nodes) + content_expr, cast(dict[str, "NodeType"], self.nodes) ) type.content_match = content_expr_cache[content_expr] @@ -404,15 +407,15 @@ def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: ) self.top_node_type = self.nodes[cast(Nodes, self.spec.get("topNode") or "doc")] - self.cached: Dict[str, Any] = {} + self.cached: dict[str, Any] = {} self.cached["wrappings"] = {} def node( self, - type: Union[str, NodeType], - attrs: Optional[Attrs] = None, - content: Optional[Union[Fragment, Node, List[Node]]] = None, - marks: Optional[List[Mark]] = None, + type: str | NodeType, + attrs: Attrs | None = None, + content: Fragment | Node | list[Node] | None = None, + marks: list[Mark] | None = None, ) -> Node: if isinstance(type, str): type = self.node_type(type) @@ -422,7 +425,7 @@ def node( raise ValueError(f"Node type from different schema used ({type.name})") return type.create_checked(attrs, content, marks) - def 
text(self, text: str, marks: Optional[List[Mark]] = None) -> TextNode: + def text(self, text: str, marks: list[Mark] | None = None) -> TextNode: type = self.nodes[cast(Nodes, "text")] return TextNode( type, cast(Attrs, type.default_attrs), text, Mark.set_from(marks) @@ -430,20 +433,19 @@ def text(self, text: str, marks: Optional[List[Mark]] = None) -> TextNode: def mark( self, - type: Union[str, MarkType], - attrs: Optional[ - Union[Dict[str, Optional[str]], Dict[str, str], Dict[str, int]] - ] = None, + type: str | MarkType, + attrs: Attrs | None = None, ) -> Mark: if isinstance(type, str): type = self.marks[cast(Marks, type)] return type.create(attrs) - def node_from_json(self, json_data: JSON) -> Union[Node, TextNode]: + def node_from_json(self, json_data: JSONDict) -> Node | TextNode: return Node.from_json(self, json_data) def mark_from_json( - self, json_data: Dict[str, Union[str, Dict[str, Optional[str]], Dict[str, int]]] + self, + json_data: JSONDict, ) -> Mark: return Mark.from_json(self, json_data) @@ -454,7 +456,7 @@ def node_type(self, name: str) -> NodeType: return found -def gather_marks(schema: Schema[str, str], marks: List[str]) -> List[MarkType]: +def gather_marks(schema: Schema[str, str], marks: list[str]) -> list[MarkType]: found = [] for name in marks: mark = schema.marks.get(name) diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index 94d0dd3..8e0b954 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -2,12 +2,8 @@ from typing import ( Any, Callable, - Dict, - List, - Optional, Sequence, - Tuple, - Union, + TypeAlias, cast, ) @@ -16,11 +12,11 @@ from .node import Node from .schema import Schema -HTMLNode = Union["Element", str] +HTMLNode: TypeAlias = "Element | str" class DocumentFragment: - def __init__(self, children: List[HTMLNode]) -> None: + def __init__(self, children: list[HTMLNode]) -> None: self.children = children def __str__(self) -> str: @@ -50,7 +46,7 @@ def __str__(self) -> str: class Element(DocumentFragment): def __init__( - self, name: str, attrs: Dict[str, str], children: List[HTMLNode] + self, name: str, attrs: dict[str, str], children: list[HTMLNode] ) -> None: self.name = name self.attrs = attrs @@ -66,25 +62,25 @@ def __str__(self) -> str: return f"<{open_tag_str}>{children_str}" -HTMLOutputSpec = Union[str, Sequence[Any], Element] +HTMLOutputSpec = str | Sequence[Any] | Element class DOMSerializer: def __init__( self, - nodes: Dict[str, Callable[[Node], HTMLOutputSpec]], - marks: Dict[str, Callable[[Mark, bool], HTMLOutputSpec]], + nodes: dict[str, Callable[[Node], HTMLOutputSpec]], + marks: dict[str, Callable[[Mark, bool], HTMLOutputSpec]], ) -> None: self.nodes = nodes self.marks = marks def serialize_fragment( - self, fragment: Fragment, target: Optional[Element] = None + self, fragment: Fragment, target: Element | None = None ) -> DocumentFragment: tgt: DocumentFragment = target or DocumentFragment(children=[]) top = tgt - active: Optional[List[Tuple[Mark, DocumentFragment]]] = None + active: list[tuple[Mark, DocumentFragment]] | None = None def each(node: Node, offset: int, index: int) -> None: nonlocal top, active @@ -141,16 +137,14 @@ def serialize_node(self, node: Node) -> HTMLNode: def serialize_mark( self, mark: Mark, inline: bool - ) -> Optional[Tuple[HTMLNode, Optional[Element]]]: + ) -> tuple[HTMLNode, Element | None] | None: to_dom = self.marks.get(mark.type.name) if to_dom: return type(self).render_spec(to_dom(mark, inline)) return None @classmethod - def render_spec( - cls, structure: 
HTMLOutputSpec - ) -> Tuple[HTMLNode, Optional[Element]]: + def render_spec(cls, structure: HTMLOutputSpec) -> tuple[HTMLNode, Element | None]: if isinstance(structure, str): return html.escape(structure), None if isinstance(structure, Element): @@ -158,7 +152,7 @@ def render_spec( tag_name = structure[0] if " " in tag_name[1:]: raise NotImplementedError("XML namespaces are not supported") - content_dom: Optional[Element] = None + content_dom: Element | None = None dom = Element(name=tag_name, attrs={}, children=[]) attrs = structure[1] if len(structure) > 1 else None start = 1 @@ -193,7 +187,7 @@ def from_schema(cls, schema: Schema[str, str]) -> "DOMSerializer": @classmethod def nodes_from_schema( cls, schema: Schema[str, str] - ) -> Dict[str, Callable[["Node"], HTMLOutputSpec]]: + ) -> dict[str, Callable[["Node"], HTMLOutputSpec]]: result = gather_to_dom(schema.nodes) if "text" not in result: result["text"] = lambda node: node.text @@ -202,11 +196,11 @@ def nodes_from_schema( @classmethod def marks_from_schema( cls, schema: Schema[str, str] - ) -> Dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: + ) -> dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: return gather_to_dom(schema.marks) -def gather_to_dom(obj: Dict[str, Any]) -> Dict[str, Callable[..., Any]]: +def gather_to_dom(obj: dict[str, Any]) -> dict[str, Callable[..., Any]]: result = {} for name in obj: to_dom = obj[name].spec.get("toDOM") From 4c08bf4d372d1cc09a768e04600a75492558b891 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Mon, 13 Nov 2023 14:35:31 -0500 Subject: [PATCH 23/40] Adding types to the transform module. Creating the Mappable abstract class (just like in the original prosemirror repo). --- prosemirror/schema/basic/schema_basic.py | 6 +- prosemirror/schema/list/schema_list.py | 34 ++-- prosemirror/test_builder/__init__.py | 6 +- prosemirror/test_builder/build.py | 24 ++- prosemirror/transform/__init__.py | 16 +- prosemirror/transform/attr_step.py | 34 ++-- prosemirror/transform/map.py | 101 ++++++---- prosemirror/transform/mark_step.py | 113 ++++++----- prosemirror/transform/replace.py | 163 +++++++++++----- prosemirror/transform/replace_step.py | 84 +++++---- prosemirror/transform/step.py | 72 ++++--- prosemirror/transform/structure.py | 85 ++++++--- prosemirror/transform/transform.py | 229 ++++++++++++++++------- 13 files changed, 643 insertions(+), 324 deletions(-) diff --git a/prosemirror/schema/basic/schema_basic.py b/prosemirror/schema/basic/schema_basic.py index f98cc5e..88be8e4 100644 --- a/prosemirror/schema/basic/schema_basic.py +++ b/prosemirror/schema/basic/schema_basic.py @@ -1,5 +1,3 @@ -from typing import Dict - from prosemirror.model import Schema from prosemirror.model.schema import MarkSpec, NodeSpec @@ -9,7 +7,7 @@ pre_dom = ["pre", ["code", 0]] br_dom = ["br"] -nodes: Dict[str, NodeSpec] = { +nodes: dict[str, NodeSpec] = { "doc": {"content": "block+"}, "paragraph": { "content": "inline*", @@ -90,7 +88,7 @@ strong_dom = ["strong", 0] code_dom = ["code", 0] -marks: Dict[str, MarkSpec] = { +marks: dict[str, MarkSpec] = { "link": { "attrs": {"href": {}, "title": {"default": None}}, "inclusive": False, diff --git a/prosemirror/schema/list/schema_list.py b/prosemirror/schema/list/schema_list.py index 8d26d1c..a0146b8 100644 --- a/prosemirror/schema/list/schema_list.py +++ b/prosemirror/schema/list/schema_list.py @@ -1,38 +1,44 @@ +from typing import cast + +from prosemirror.model.schema import Nodes, NodeSpec + OL_DOM = ["ol", 0] UL_DOM = ["ul", 0] LI_DOM = ["li", 0] -orderd_list = { - 
"attrs": {"order": {"default": 1}}, - "parseDOM": [{"tag": "ol"}], - "toDOM": lambda node: ( +orderd_list = NodeSpec( + attrs={"order": {"default": 1}}, + parseDOM=[{"tag": "ol"}], + toDOM=lambda node: ( OL_DOM if node.attrs.get("order") == 1 else ["ol", {"start": node.attrs["order"]}, 0] ), -} +) -bullet_list = {"parseDOM": [{"tag": "ul"}], "toDOM": lambda _: UL_DOM} +bullet_list = NodeSpec(parseDOM=[{"tag": "ul"}], toDOM=lambda _: UL_DOM) -list_item = {"parseDOM": [{"tag": "li"}], "defining": True, "toDOM": lambda _: LI_DOM} +list_item = NodeSpec(parseDOM=[{"tag": "li"}], defining=True, toDOM=lambda _: LI_DOM) -def add(obj, props): +def add(obj: "NodeSpec", props: "NodeSpec") -> "NodeSpec": return {**obj, **props} -def add_list_nodes(nodes, item_content, list_group): +def add_list_nodes( + nodes: dict["Nodes", "NodeSpec"], item_content: str, list_group: str +) -> dict["Nodes", "NodeSpec"]: copy = nodes.copy() copy.update( { - "ordered_list": add( - orderd_list, {"content": "list_item+", "group": list_group} + cast(Nodes, "ordered_list"): add( + orderd_list, NodeSpec(content="list_item+", group=list_group) ), - "bullet_list": add( - bullet_list, {"content": "list_item+", "group": list_group} + cast(Nodes, "bullet_list"): add( + bullet_list, NodeSpec(content="list_item+", group=list_group) ), - "list_item": add(list_item, {"content": item_content}), + cast(Nodes, "list_item"): add(list_item, NodeSpec(content=item_content)), } ) return copy diff --git a/prosemirror/test_builder/__init__.py b/prosemirror/test_builder/__init__.py index e500353..454b0be 100644 --- a/prosemirror/test_builder/__init__.py +++ b/prosemirror/test_builder/__init__.py @@ -1,4 +1,6 @@ -from prosemirror.model import Schema +# type: ignore + +from prosemirror.model import Node, Schema from prosemirror.schema.basic import schema as _schema from prosemirror.schema.list import add_list_nodes @@ -31,5 +33,5 @@ ) -def eq(a, b): +def eq(a: Node, b: Node) -> bool: return a.eq(b) diff --git a/prosemirror/test_builder/build.py b/prosemirror/test_builder/build.py index 9f22644..9e15dca 100644 --- a/prosemirror/test_builder/build.py +++ b/prosemirror/test_builder/build.py @@ -1,11 +1,19 @@ +# type: ignore + import re +from collections.abc import Callable -from prosemirror.model import Node +from prosemirror.model import Node, Schema +from prosemirror.utils import JSONDict -NO_TAG = Node.tag = {} # type: ignore +NO_TAG = Node.tag = {} -def flatten(schema, children, f): +def flatten( + schema: Schema[str, str], + children: list[Node | JSONDict | str], + f: Callable[[Node], Node], +) -> tuple[list[Node], dict[str, int]]: result, pos, tag = [], 0, NO_TAG for child in children: @@ -47,7 +55,7 @@ def flatten(schema, children, f): node = f(child) pos += node.node_size result.append(node) - return {"nodes": result, "tag": tag} + return result, tag def id(x): @@ -66,9 +74,7 @@ def result(*args): ): my_attrs.update(args[0]) args = args[1:] - flatten_res = flatten(type.schema, args, id) - nodes = flatten_res["nodes"] - tag = flatten_res["tag"] + nodes, tag = flatten(type.schema, args, id) node = type.create(my_attrs, nodes) if tag != NO_TAG: node.tag = tag @@ -102,8 +108,8 @@ def f(n): n if mark.type.is_in_set(n.marks) else n.mark(mark.add_to_set(n.marks)) ) - flatten_res = flatten(type.schema, args, f) - return {"flat": flatten_res["nodes"], "tag": flatten_res["tag"]} + nodes, tag = flatten(type.schema, args, f) + return {"flat": nodes, "tag": tag} return result diff --git a/prosemirror/transform/__init__.py 
b/prosemirror/transform/__init__.py index b4975d8..9f6044f 100644 --- a/prosemirror/transform/__init__.py +++ b/prosemirror/transform/__init__.py @@ -1,6 +1,11 @@ +from .attr_step import AttrStep from .map import Mapping, MapResult, StepMap -from .mark_step import AddMarkStep, RemoveMarkStep -from .replace import replace_step +from .mark_step import AddMarkStep, AddNodeMarkStep, RemoveMarkStep, RemoveNodeMarkStep +from .replace import ( + close_fragment, + covered_depths, + fits_trivially, +) from .replace_step import ReplaceAroundStep, ReplaceStep from .step import Step, StepResult from .structure import ( @@ -26,12 +31,19 @@ "drop_point", "lift_target", "find_wrapping", + "close_fragment", + "covered_depths", + "fits_trivially", + "replace_step", "StepMap", "MapResult", "Mapping", + "AttrStep", "AddMarkStep", + "AddNodeMarkStep", "RemoveMarkStep", "ReplaceAroundStep", + "RemoveNodeMarkStep", "ReplaceStep", "replace_step", ] diff --git a/prosemirror/transform/attr_step.py b/prosemirror/transform/attr_step.py index 76f8d3e..d7e8c64 100644 --- a/prosemirror/transform/attr_step.py +++ b/prosemirror/transform/attr_step.py @@ -1,16 +1,19 @@ -from prosemirror.model import Fragment, Slice +from typing import cast -from .step import Step, StepMap, StepResult +from prosemirror.model import Fragment, Node, Schema, Slice +from prosemirror.transform.map import Mappable, StepMap +from prosemirror.transform.step import Step, StepResult +from prosemirror.utils import JSON, JSONDict class AttrStep(Step): - def __init__(self, pos, attr, value): + def __init__(self, pos: int, attr: str, value: JSON) -> None: super().__init__() self.pos = pos self.attr = attr self.value = value - def apply(self, doc): + def apply(self, doc: Node) -> StepResult: node = doc.node_at(self.pos) if not node: return StepResult.fail("No node at attribute step's position") @@ -26,32 +29,33 @@ def apply(self, doc): Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), ) - def get_map(self): + def get_map(self) -> StepMap: return StepMap.empty - def invert(self, doc): - return AttrStep(self.pos, self.attr, doc.node_at(self.pos).attrs[self.attr]) + def invert(self, doc: Node) -> "AttrStep": + node_at_pos = doc.node_at(self.pos) + assert node_at_pos is not None + return AttrStep(self.pos, self.attr, node_at_pos.attrs[self.attr]) - def map(self, mapping): + def map(self, mapping: Mappable) -> Step | None: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else AttrStep(pos.pos, self.attr, self.value) - def to_json(self): - json_data = { + def to_json(self) -> JSONDict: + return { "stepType": "attr", "pos": self.pos, "attr": self.attr, "value": self.value, } - return json_data - @staticmethod - def from_json(schema, json_data): + def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "AttrStep": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not isinstance(json_data["pos"], int) or not isinstance( json_data["attr"], str ): @@ -59,4 +63,4 @@ def from_json(schema, json_data): return AttrStep(json_data["pos"], json_data["attr"], json_data["value"]) -Step.json_id("attr", AttrStep) +# Step.json_id("attr", AttrStep) diff --git a/prosemirror/transform/map.py b/prosemirror/transform/map.py index 06d3016..196f58b 100644 --- a/prosemirror/transform/map.py +++ b/prosemirror/transform/map.py @@ -1,18 +1,20 @@ -from typing import ClassVar +import abc +from collections.abc import Callable +from typing import Any, ClassVar, 
Literal, overload lower16 = 0xFFFF factor16 = 2**16 -def make_recover(index, offset): +def make_recover(index: float, offset: int) -> int: return int(index + offset * factor16) -def recover_index(value): +def recover_index(value: int) -> int: return int(value & lower16) -def recover_offset(value): +def recover_offset(value: int) -> int: return int((value - (value & lower16)) / factor16) @@ -23,7 +25,7 @@ def recover_offset(value): class MapResult: - def __init__(self, pos, del_info=0, recover=None): + def __init__(self, pos: int, del_info: int = 0, recover: int | None = None) -> None: self.pos = pos self.del_info = del_info self.recover = recover @@ -37,26 +39,36 @@ def __init__(self, pos, del_info=0, recover=None): # get deletedAcross() { return (this.delInfo & DEL_ACROSS) > 0 } @property - def deleted(self): + def deleted(self) -> bool: return (self.del_info & DEL_SIDE) > 0 @property - def deleted_before(self): + def deleted_before(self) -> bool: return (self.del_info & (DEL_BEFORE | DEL_ACROSS)) > 0 @property - def deleted_after(self): + def deleted_after(self) -> bool: return (self.del_info & (DEL_AFTER | DEL_ACROSS)) > 0 @property - def deleted_across(self): + def deleted_across(self) -> bool: return (self.del_info & DEL_ACROSS) > 0 -class StepMap: +class Mappable(metaclass=abc.ABCMeta): + @abc.abstractmethod + def map(self, pos: int, assoc: int = 1) -> int: + ... + + @abc.abstractmethod + def map_result(self, pos: int, assoc: int = 1) -> MapResult: + ... + + +class StepMap(Mappable): empty: ClassVar["StepMap"] - def __init__(self, ranges, inverted=False): + def __init__(self, ranges: list[int | Any], inverted: bool = False) -> None: # prosemirror-transform overrides the constructor to return the # StepMap.empty singleton when ranges are empty. # It is not easy to do in Python, and the intent of that is to make sure @@ -64,7 +76,7 @@ def __init__(self, ranges, inverted=False): self.ranges = ranges self.inverted = inverted - def recover(self, value): + def recover(self, value: int) -> int: diff = 0 index = recover_index(value) if not self.inverted: @@ -72,13 +84,21 @@ def recover(self, value): diff += self.ranges[i * 3 + 2] - self.ranges[i * 3 + 1] return self.ranges[index * 3] + diff + recover_offset(value) - def map_result(self, pos, assoc=1): + def map(self, pos: int, assoc: int = 1) -> int: + return self._map(pos, assoc, True) + + def map_result(self, pos: int, assoc: int = 1) -> MapResult: return self._map(pos, assoc, False) - def map(self, pos, assoc=1): - return self._map(pos, assoc, True) + @overload + def _map(self, pos: int, assoc: int, simple: Literal[True]) -> int: + ... + + @overload + def _map(self, pos: int, assoc: int, simple: Literal[False]) -> MapResult: + ... 
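The `@overload` declarations above are the typing pattern this patch applies throughout `map.py`: each `Literal` value of the `simple` flag is paired with a distinct return type, so a type checker knows whether `_map` yields a plain `int` or a `MapResult` without casts at the call site. A minimal, self-contained sketch of the same pattern (the `shift_pos` name and its body are illustrative only, not part of this patch):

    from typing import Literal, overload

    @overload
    def shift_pos(pos: int, simple: Literal[True]) -> int: ...
    @overload
    def shift_pos(pos: int, simple: Literal[False]) -> tuple[int, bool]: ...
    def shift_pos(pos: int, simple: bool) -> int | tuple[int, bool]:
        # Stand-in for real mapping logic: shift the position by one.
        mapped = pos + 1
        return mapped if simple else (mapped, False)

    plain = shift_pos(3, True)      # checkers infer int
    detailed = shift_pos(3, False)  # checkers infer tuple[int, bool]
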
- def _map(self, pos, assoc, simple): + def _map(self, pos: int, assoc: int, simple: bool) -> MapResult | int: diff = 0 old_index = 2 if self.inverted else 1 new_index = 1 if self.inverted else 2 @@ -117,7 +137,7 @@ def _map(self, pos, assoc, simple): diff += new_size - old_size return pos + diff if simple else MapResult(pos + diff, 0, None) - def touches(self, pos, recover): + def touches(self, pos: int, recover: int) -> bool: diff = 0 index = recover_index(recover) old_index = 2 if self.inverted else 1 @@ -133,7 +153,7 @@ def touches(self, pos, recover): diff += self.ranges[i + new_index] - old_size return False - def for_each(self, f): + def for_each(self, f: Callable[[int, int, int, int], None]) -> None: old_index = 2 if self.inverted else 1 new_index = 1 if self.inverted else 2 i = 0 @@ -147,40 +167,46 @@ def for_each(self, f): f(old_start, old_start + old_size, new_start, new_start + new_size) i += 3 - def invert(self): + def invert(self) -> "StepMap": return StepMap(self.ranges, not self.inverted) - def __str__(self): + def __str__(self) -> str: return ("-" if self.inverted else "") + str(self.ranges) StepMap.empty = StepMap([]) -class Mapping: - def __init__(self, maps=None, mirror=None, from_=None, to=None): +class Mapping(Mappable): + def __init__( + self, + maps: list[StepMap] | None = None, + mirror: list[int] | None = None, + from_: int | None = None, + to: int | None = None, + ) -> None: self.maps = maps or [] self.from_ = from_ or 0 self.to = len(self.maps) if to is None else to self.mirror = mirror - def slice(self, from_=0, to=None): + def slice(self, from_: int = 0, to: int | None = None) -> "Mapping": if to is None: to = len(self.maps) return Mapping(self.maps, self.mirror, from_, to) - def copy(self): + def copy(self) -> "Mapping": return Mapping( self.maps[:], (self.mirror[:] if self.mirror else None), self.from_, self.to ) - def append_map(self, map, mirrors=None): + def append_map(self, map: StepMap, mirrors: int | None = None) -> None: self.maps.append(map) self.to = len(self.maps) if mirrors is not None: self.set_mirror(len(self.maps) - 1, mirrors) - def append_mapping(self, mapping: "Mapping"): + def append_mapping(self, mapping: "Mapping") -> None: i = 0 start_size = len(self.maps) while i < len(mapping.maps): @@ -191,18 +217,19 @@ def append_mapping(self, mapping: "Mapping"): (start_size + mirr) if (mirr is not None and mirr < i) else None, ) - def get_mirror(self, n): + def get_mirror(self, n: int) -> int | None: if self.mirror: for i in range(len(self.mirror)): if (self.mirror[i]) == n: return self.mirror[i + (-1 if i % 2 else 1)] + return None - def set_mirror(self, n, m): + def set_mirror(self, n: int, m: int) -> None: if not self.mirror: self.mirror = [] self.mirror.extend([n, m]) - def append_mapping_inverted(self, mapping: "Mapping"): + def append_mapping_inverted(self, mapping: "Mapping") -> None: i = len(mapping.maps) - 1 total_size = len(self.maps) + len(mapping.maps) while i >= 0: @@ -213,22 +240,30 @@ def append_mapping_inverted(self, mapping: "Mapping"): ) i -= 1 - def invert(self): + def invert(self) -> "Mapping": inverse = Mapping() inverse.append_mapping_inverted(self) return inverse - def map(self, pos, assoc=1): + def map(self, pos: int, assoc: int = 1) -> int: if self.mirror: return self._map(pos, assoc, True) for i in range(self.from_, self.to): pos = self.maps[i].map(pos, assoc) return pos - def map_result(self, pos, assoc=1): + def map_result(self, pos: int, assoc: int = 1) -> MapResult: return self._map(pos, assoc, False) - def _map(self, 
pos, assoc, simple): + @overload + def _map(self, pos: int, assoc: int, simple: Literal[True]) -> int: + ... + + @overload + def _map(self, pos: int, assoc: int, simple: Literal[False]) -> MapResult: + ... + + def _map(self, pos: int, assoc: int, simple: bool) -> MapResult | int: del_info = 0 i = self.from_ diff --git a/prosemirror/transform/mark_step.py b/prosemirror/transform/mark_step.py index c77afb3..32faabc 100644 --- a/prosemirror/transform/mark_step.py +++ b/prosemirror/transform/mark_step.py @@ -1,9 +1,16 @@ -from prosemirror.model import Fragment, Slice +from typing import Callable, cast -from .step import Step, StepResult +from prosemirror.model import Fragment, Mark, Node, Schema, Slice +from prosemirror.transform.map import Mappable +from prosemirror.transform.step import Step, StepResult, step_json_id +from prosemirror.utils import JSONDict -def map_fragment(fragment: Fragment, f, parent=None): +def map_fragment( + fragment: Fragment, + f: Callable[[Node, Node | None, int], Node], + parent: Node | None = None, +) -> Fragment: mapped = [] for i in range(fragment.child_count): child = fragment.child(i) @@ -16,19 +23,21 @@ def map_fragment(fragment: Fragment, f, parent=None): class AddMarkStep(Step): - def __init__(self, from_, to, mark): + def __init__(self, from_: int, to: int, mark: Mark) -> None: super().__init__() self.from_ = from_ self.to = to self.mark = mark - def apply(self, doc): + def apply(self, doc: Node) -> StepResult: old_slice = doc.slice(self.from_, self.to) from__ = doc.resolve(self.from_) parent = from__.node(from__.shared_depth(self.to)) - def iteratee(node, parent, *args): - if not node.is_atom or not parent.type.allows_mark_type(self.mark.type): + def iteratee(node: Node, parent: Node | None, i: int) -> Node: + if parent and ( + not node.is_atom or not parent.type.allows_mark_type(self.mark.type) + ): return node return node.mark(self.mark.add_to_set(node.marks)) @@ -39,17 +48,17 @@ def iteratee(node, parent, *args): ) return StepResult.from_replace(doc, self.from_, self.to, slice) - def invert(self, doc=None): + def invert(self, doc: Node | None = None) -> "RemoveMarkStep": return RemoveMarkStep(self.from_, self.to, self.mark) - def map(self, mapping): + def map(self, mapping: Mappable) -> "AddMarkStep | None": from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if from_.deleted and to.deleted or from_.pos > to.pos: return None return AddMarkStep(from_.pos, to.pos, self.mark) - def merge(self, other: "AddMarkStep"): + def merge(self, other: "Step") -> "AddMarkStep | None": if ( isinstance(other, AddMarkStep) and other.mark.eq(self.mark) @@ -59,45 +68,48 @@ def merge(self, other: "AddMarkStep"): return AddMarkStep( min(self.from_, other.from_), max(self.to, other.to), self.mark ) + return None - def to_json(self): - json_data = { + def to_json(self) -> JSONDict: + return { "stepType": "addMark", "mark": self.mark.to_json(), "from": self.from_, "to": self.to, } - return json_data @staticmethod - def from_json(schema, json_data): + def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "AddMarkStep": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not isinstance(json_data["from"], int) or not isinstance( json_data["to"], int ): raise ValueError("Invalid input for AddMarkStep.from_json") return AddMarkStep( - json_data["from"], json_data["to"], schema.mark_from_json(json_data["mark"]) + json_data["from"], + json_data["to"], + 
schema.mark_from_json(cast(JSONDict, json_data["mark"])), ) -Step.json_id("addMark", AddMarkStep) +step_json_id("addMark", AddMarkStep) class RemoveMarkStep(Step): - def __init__(self, from_, to, mark): + def __init__(self, from_: int, to: int, mark: Mark) -> None: super().__init__() self.from_ = from_ self.to = to self.mark = mark - def apply(self, doc): + def apply(self, doc: Node) -> StepResult: old_slice = doc.slice(self.from_, self.to) - def iteratee(node, *args): + def iteratee(node: Node, parent: Node | None, i: int) -> Node: return node.mark(self.mark.remove_from_set(node.marks)) slice = Slice( @@ -107,17 +119,17 @@ def iteratee(node, *args): ) return StepResult.from_replace(doc, self.from_, self.to, slice) - def invert(self, doc=None): + def invert(self, doc: Node | None = None) -> AddMarkStep: return AddMarkStep(self.from_, self.to, self.mark) - def map(self, mapping): + def map(self, mapping: Mappable) -> "RemoveMarkStep | None": from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if (from_.deleted and to.deleted) or (from_.pos > to.pos): return None return RemoveMarkStep(from_.pos, to.pos, self.mark) - def merge(self, other: "RemoveMarkStep"): + def merge(self, other: "Step") -> "RemoveMarkStep | None": if ( isinstance(other, RemoveMarkStep) and (other.mark.eq(self.mark)) @@ -127,41 +139,44 @@ def merge(self, other: "RemoveMarkStep"): return RemoveMarkStep( min(self.from_, other.from_), max(self.to, other.to), self.mark ) + return None - def to_json(self): - json_data = { + def to_json(self) -> JSONDict: + return { "stepType": "removeMark", "mark": self.mark.to_json(), "from": self.from_, "to": self.to, } - return json_data @staticmethod - def from_json(schema, json_data): + def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not isinstance(json_data["from"], int) or not isinstance( json_data["to"], int ): raise ValueError("Invalid input for RemoveMarkStep.from_json") return RemoveMarkStep( - json_data["from"], json_data["to"], schema.mark_from_json(json_data["mark"]) + json_data["from"], + json_data["to"], + schema.mark_from_json(cast(JSONDict, json_data["mark"])), ) -Step.json_id("removeMark", RemoveMarkStep) +step_json_id("removeMark", RemoveMarkStep) class AddNodeMarkStep(Step): - def __init__(self, pos, mark): + def __init__(self, pos: int, mark: Mark) -> None: super().__init__() self.pos = pos self.mark = mark - def apply(self, doc): + def apply(self, doc: Node) -> StepResult: node = doc.node_at(self.pos) if not node: return StepResult.fail("No node at mark step's position") @@ -173,7 +188,7 @@ def apply(self, doc): Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), ) - def invert(self, doc): + def invert(self, doc: Node) -> "RemoveNodeMarkStep | AddNodeMarkStep": node = doc.node_at(self.pos) if node: new_set = self.mark.add_to_set(node.marks) @@ -184,11 +199,11 @@ def invert(self, doc): return AddNodeMarkStep(self.pos, self.mark) return RemoveNodeMarkStep(self.pos, self.mark) - def map(self, mapping): + def map(self, mapping: Mappable) -> "AddNodeMarkStep | None": pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else AddNodeMarkStep(pos.pos, self.mark) - def to_json(self): + def to_json(self) -> JSONDict: return { "stepType": "addNodeMark", "pos": self.pos, @@ -196,28 +211,29 @@ def to_json(self): } @staticmethod - def from_json(schema, json_data): + def 
from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not isinstance(json_data["pos"], int): raise ValueError("Invalid input for AddNodeMarkStep.from_json") return AddNodeMarkStep( - json_data["pos"], schema.mark_from_json(json_data["mark"]) + json_data["pos"], schema.mark_from_json(cast(JSONDict, json_data["mark"])) ) -Step.json_id("addNodeMark", AddNodeMarkStep) +step_json_id("addNodeMark", AddNodeMarkStep) class RemoveNodeMarkStep(Step): - def __init__(self, pos, mark): + def __init__(self, pos: int, mark: Mark) -> None: super().__init__() self.pos = pos self.mark = mark - def apply(self, doc): + def apply(self, doc: Node) -> StepResult: node = doc.node_at(self.pos) if not node: return StepResult.fail("No node at mark step's position") @@ -231,17 +247,17 @@ def apply(self, doc): Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), ) - def invert(self, doc): + def invert(self, doc: Node) -> "RemoveNodeMarkStep | AddNodeMarkStep": node = doc.node_at(self.pos) if not node or not self.mark.is_in_set(node.marks): return self return AddNodeMarkStep(self.pos, self.mark) - def map(self, mapping): + def map(self, mapping: Mappable) -> "RemoveNodeMarkStep | None": pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else RemoveNodeMarkStep(pos.pos, self.mark) - def to_json(self): + def to_json(self) -> JSONDict: return { "stepType": "removeNodeMark", "pos": self.pos, @@ -249,16 +265,17 @@ def to_json(self): } @staticmethod - def from_json(schema, json_data): + def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not isinstance(json_data["pos"], int): raise ValueError("Invalid input for RemoveNodeMarkStep.from_json") return RemoveNodeMarkStep( - json_data["pos"], schema.mark_from_json(json_data["mark"]) + json_data["pos"], schema.mark_from_json(cast(JSONDict, json_data["mark"])) ) -Step.json_id("removeNodeMark", RemoveNodeMarkStep) +step_json_id("removeNodeMark", RemoveNodeMarkStep) diff --git a/prosemirror/transform/replace.py b/prosemirror/transform/replace.py index 437958f..06ee53d 100644 --- a/prosemirror/transform/replace.py +++ b/prosemirror/transform/replace.py @@ -1,11 +1,24 @@ -from typing import List, Optional, cast - -from prosemirror.model import Fragment, Node, ResolvedPos, Slice - -from .replace_step import ReplaceAroundStep, ReplaceStep, Step - - -def replace_step(doc, from_, to=None, slice=None): +from typing import cast + +from prosemirror.model import ( + Attrs, + ContentMatch, + Fragment, + Node, + NodeType, + ResolvedPos, + Slice, +) +from prosemirror.transform.replace_step import ReplaceAroundStep, ReplaceStep +from prosemirror.transform.step import Step + + +def replace_step( + doc: Node, + from_: int, + to: int | None = None, + slice: Slice | None = None, +) -> Step | None: if to is None: to = from_ if slice is None: @@ -20,7 +33,11 @@ def replace_step(doc, from_, to=None, slice=None): return Fitter(from__, to_, slice).fit() -def fits_trivially(from__, to_, slice): +def fits_trivially( + from__: ResolvedPos, + to_: ResolvedPos, + slice: Slice, +) -> bool: if not slice.open_start and not slice.open_end and from__.start() == to_.start(): return from__.parent.can_replace(from__.index(), to_.index(), slice.content) return False @@ -29,7 +46,7 
@@ def fits_trivially(from__, to_, slice): class _FrontierItem: __slots__ = ("type", "match") - def __init__(self, type_, match): + def __init__(self, type_: NodeType, match: ContentMatch) -> None: self.type = type_ self.match = match @@ -37,7 +54,14 @@ def __init__(self, type_, match): class _Fittable: __slots__ = ("slice_depth", "frontier_depth", "parent", "inject", "wrap") - def __init__(self, slice_depth, frontier_depth, parent, inject=None, wrap=None): + def __init__( + self, + slice_depth: int, + frontier_depth: int, + parent: Node | None, + inject: Fragment | None = None, + wrap: list[NodeType] | None = None, + ) -> None: self.slice_depth = slice_depth self.frontier_depth = frontier_depth self.parent = parent @@ -48,7 +72,12 @@ def __init__(self, slice_depth, frontier_depth, parent, inject=None, wrap=None): class _CloseLevel: __slots__ = ("depth", "fit", "move") - def __init__(self, depth, fit, move): + def __init__( + self, + depth: int, + fit: Fragment, + move: ResolvedPos, + ) -> None: self.depth = depth self.fit = fit self.move = move @@ -57,12 +86,12 @@ def __init__(self, depth, fit, move): class Fitter: __slots__ = ("to_", "from__", "unplaced", "frontier", "placed") - def __init__(self, from__: ResolvedPos, to_: ResolvedPos, slice: Slice): + def __init__(self, from__: ResolvedPos, to_: ResolvedPos, slice: Slice) -> None: self.to_ = to_ self.from__ = from__ self.unplaced = slice - self.frontier: List[_FrontierItem] = [] + self.frontier: list[_FrontierItem] = [] for i in range(from__.depth + 1): node = from__.node(i) self.frontier.append( @@ -77,7 +106,7 @@ def __init__(self, from__: ResolvedPos, to_: ResolvedPos, slice: Slice): def depth(self) -> int: return len(self.frontier) - 1 - def fit(self) -> Optional[Step]: + def fit(self) -> Step | None: while self.unplaced.size: fit = self.find_fittable() if fit: @@ -118,7 +147,7 @@ def fit(self) -> Optional[Step]: return ReplaceStep(from__.pos, to_.pos, slice) return None - def find_fittable(self) -> Optional[_Fittable]: + def find_fittable(self) -> _Fittable | None: start_depth = self.unplaced.open_start cur = self.unplaced.content open_end = self.unplaced.open_end @@ -153,17 +182,18 @@ def find_fittable(self) -> Optional[_Fittable]: inject = _nothing wrap = _nothing - def _lazy_inject(): + def _lazy_inject() -> Fragment | None: nonlocal inject if inject is _nothing: inject = match.fill_before(Fragment.from_(first), False) - return inject + return cast(Fragment | None, inject) - def _lazy_wrap(): + def _lazy_wrap() -> list[NodeType] | None: nonlocal wrap + assert first is not None if wrap is _nothing: wrap = match.find_wrapping(first.type) - return wrap + return cast(list[NodeType] | None, wrap) if pass_ == 1 and ( (match.match_type(first.type) or _lazy_inject()) @@ -206,7 +236,7 @@ def open_more(self) -> bool: ) return True - def drop_node(self): + def drop_node(self) -> None: content = self.unplaced.content open_start = self.unplaced.open_start open_end = self.unplaced.open_end @@ -225,7 +255,7 @@ def drop_node(self): open_end, ) - def place_nodes(self, fittable: _Fittable): + def place_nodes(self, fittable: _Fittable) -> None: slice_depth = fittable.slice_depth frontier_depth = fittable.frontier_depth parent = fittable.parent @@ -249,7 +279,9 @@ def place_nodes(self, fittable: _Fittable): if inject: for i in range(inject.child_count): add.append(inject.child(i)) - match = match.match_fragment(inject) + matched_fragment = match.match_fragment(inject) + assert matched_fragment is not None + match = matched_fragment open_end_count = 
(fragment.size + slice_depth) - ( slice.content.size - slice.open_end @@ -294,6 +326,7 @@ def place_nodes(self, fittable: _Fittable): cur = fragment for _ in range(open_end_count): node = cur.last_child + assert node is not None self.frontier.append( _FrontierItem(node.type, node.content_match_at(node.child_count)) ) @@ -314,7 +347,7 @@ def place_nodes(self, fittable: _Fittable): slice.open_end if open_end_count < 0 else slice_depth - 1, ) - def must_move_inline(self): + def must_move_inline(self) -> int: if not self.to_.parent.is_text_block: return -1 top = self.frontier[self.depth] @@ -322,11 +355,11 @@ def must_move_inline(self): _nothing = object() level = _nothing - def _lazy_level(): + def _lazy_level() -> _CloseLevel | None: nonlocal level if level is _nothing: level = self.find_close_level(self.to_) - return level + return cast(_CloseLevel | None, level) if ( not top.type.is_text_block @@ -335,8 +368,8 @@ def _lazy_level(): ) or ( self.to_.depth == self.depth - and _lazy_level() - and _lazy_level().depth == self.depth + and (lazy_level := _lazy_level()) + and lazy_level.depth == self.depth ) ): return -1 @@ -350,7 +383,7 @@ def _lazy_level(): after += 1 return after - def find_close_level(self, to_): + def find_close_level(self, to_: ResolvedPos) -> _CloseLevel | None: for i in range(min(self.depth, to_.depth), -1, -1): match = self.frontier[i].match type_ = self.frontier[i].type @@ -373,7 +406,7 @@ def find_close_level(self, to_): ) return None - def close(self, to_): + def close(self, to_: ResolvedPos) -> ResolvedPos | None: close = self.find_close_level(to_) if not close: return None @@ -389,18 +422,25 @@ def close(self, to_): self.open_frontier_node(node.type, node.attrs, add) return to_ - def open_frontier_node(self, type_, attrs=None, content=None): + def open_frontier_node( + self, + type_: NodeType, + attrs: Attrs | None = None, + content: Fragment | None = None, + ) -> None: top = self.frontier[self.depth] - top.match = top.match.match_type(type_) + top_match = top.match.match_type(type_) + assert top_match is not None + top.match = top_match self.placed = add_to_fragment( self.placed, self.depth, Fragment.from_(type_.create(attrs, content)) ) self.frontier.append(_FrontierItem(type_, type_.content_match)) - def close_frontier_node(self): + def close_frontier_node(self) -> None: open_ = self.frontier.pop() add = open_.match.fill_before(Fragment.empty, True) - if add.child_count: + if add and add.child_count: self.placed = add_to_fragment(self.placed, len(self.frontier), add) @@ -432,11 +472,12 @@ def content_at(fragment: Fragment, depth: int) -> Fragment: return fragment -def close_node_start(node, open_start, open_end): +def close_node_start(node: Node, open_start: int, open_end: int) -> Node: if open_start <= 0: return node frag = node.content if open_start > 1: + assert frag.first_child is not None frag = frag.replace_child( 0, close_node_start( @@ -446,17 +487,25 @@ def close_node_start(node, open_start, open_end): ), ) if open_start > 0: - frag = node.type.content_match.fill_before(frag).append(frag) + fill_before_frag = node.type.content_match.fill_before(frag) + assert fill_before_frag is not None + frag = fill_before_frag.append(frag) if open_end <= 0: - frag = frag.append( - node.type.content_match.match_fragment(frag).fill_before( - Fragment.empty, True - ) - ) + matched_fragment = node.type.content_match.match_fragment(frag) + assert matched_fragment is not None + fill_before_frag = matched_fragment.fill_before(Fragment.empty, True) + assert fill_before_frag is 
not None + frag = frag.append(fill_before_frag) return node.copy(frag) -def content_after_fits(to_, depth, type_, match, open_): +def content_after_fits( + to_: ResolvedPos, + depth: int, + type_: NodeType, + match: ContentMatch, + open_: bool, +) -> Fragment | None: node = to_.node(depth) index = to_.index_after(depth) if open_ else to_.index(depth) if index == node.child_count and not type_.compatible_content(node.type): @@ -465,16 +514,23 @@ def content_after_fits(to_, depth, type_, match, open_): return fit if fit and not invalid_marks(type_, node.content, index) else None -def invalid_marks(type_, fragment, start): +def invalid_marks(type_: NodeType, fragment: Fragment, start: int) -> bool: for i in range(start, fragment.child_count): if not type_.allows_marks(fragment.child(i).marks): return True return False -def close_fragment(fragment, depth, old_open, new_open, parent): +def close_fragment( + fragment: Fragment, + depth: int, + old_open: int, + new_open: int, + parent: Node | None, +) -> Fragment: if depth < old_open: first = fragment.first_child + assert first is not None fragment = fragment.replace_child( 0, first.copy( @@ -482,15 +538,26 @@ def close_fragment(fragment, depth, old_open, new_open, parent): ), ) if depth > new_open: + assert parent is not None match = parent.content_match_at(0) - start = match.fill_before(fragment).append(fragment) - fragment = start.append( - match.match_fragment(start).fill_before(Fragment.empty, True) + fill_before_frag = match.fill_before(fragment) + assert fill_before_frag is not None + start = fill_before_frag.append(fragment) + matched_fragment = match.match_fragment(start) + assert matched_fragment is not None + matched_fragment_fill_before = matched_fragment.fill_before( + Fragment.empty, True ) + assert matched_fragment_fill_before is not None + fragment = start.append(matched_fragment_fill_before) + return fragment -def covered_depths(from__, to_): +def covered_depths( + from__: ResolvedPos, + to_: ResolvedPos, +) -> list[int]: result = [] min_depth = min(from__.depth, to_.depth) for d in range(min_depth, -1, -1): diff --git a/prosemirror/transform/replace_step.py b/prosemirror/transform/replace_step.py index c4a0802..e7d92bb 100644 --- a/prosemirror/transform/replace_step.py +++ b/prosemirror/transform/replace_step.py @@ -1,38 +1,42 @@ -from prosemirror.model import Slice +from typing import cast -from .map import StepMap -from .step import Step, StepResult +from prosemirror.model import Node, Schema, Slice +from prosemirror.transform.map import Mappable, StepMap +from prosemirror.transform.step import Step, StepResult, step_json_id +from prosemirror.utils import JSONDict class ReplaceStep(Step): - def __init__(self, from_: int, to: int, slice: Slice, structure=None): + def __init__( + self, from_: int, to: int, slice: Slice, structure: bool | None = None + ) -> None: super().__init__() self.from_ = from_ self.to = to self.slice = slice self.structure = bool(structure) - def apply(self, doc): + def apply(self, doc: Node) -> StepResult: if self.structure and content_between(doc, self.from_, self.to): return StepResult.fail("Structure replace would overrite content") return StepResult.from_replace(doc, self.from_, self.to, self.slice) - def get_map(self): + def get_map(self) -> StepMap: return StepMap([self.from_, self.to - self.from_, self.slice.size]) - def invert(self, doc): + def invert(self, doc: Node) -> "ReplaceStep": return ReplaceStep( self.from_, self.from_ + self.slice.size, doc.slice(self.from_, self.to) ) - def map(self, 
mapping): + def map(self, mapping: Mappable) -> "ReplaceStep | None": from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if from_.deleted and to.deleted: return None return ReplaceStep(from_.pos, max(from_.pos, to.pos), self.slice) - def merge(self, other: "ReplaceStep"): + def merge(self, other: "Step") -> "ReplaceStep | None": if not isinstance(other, ReplaceStep) or other.structure or self.structure: return None if ( @@ -67,20 +71,27 @@ def merge(self, other: "ReplaceStep"): return ReplaceStep(other.from_, self.to, slice, self.structure) return None - def to_json(self): - json_data = {"stepType": "replace", "from": self.from_, "to": self.to} + def to_json(self) -> JSONDict: + json_data: JSONDict = {"stepType": "replace", "from": self.from_, "to": self.to} if self.slice.size: - json_data["slice"] = self.slice.to_json() + json_data = { + **json_data, + "slice": self.slice.to_json(), + } if self.structure: - json_data["structure"] = True + json_data = { + **json_data, + "structure": True, + } return json_data @staticmethod - def from_json(schema, json_data): + def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "ReplaceStep": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not isinstance(json_data["from"], int) or not isinstance( json_data["to"], int ): @@ -88,12 +99,12 @@ def from_json(schema, json_data): return ReplaceStep( json_data["from"], json_data["to"], - Slice.from_json(schema, json_data.get("slice")), + Slice.from_json(schema, cast(JSONDict | None, json_data.get("slice"))), bool(json_data.get("structure")), ) -Step.json_id("replace", ReplaceStep) +step_json_id("replace", ReplaceStep) class ReplaceAroundStep(Step): @@ -105,8 +116,8 @@ def __init__( gap_to: int, slice: Slice, insert: int, - structure=None, - ): + structure: bool | None = None, + ) -> None: super().__init__() self.from_ = from_ self.to = to @@ -116,7 +127,7 @@ def __init__( self.insert = insert self.structure = bool(structure) - def apply(self, doc): + def apply(self, doc: Node) -> StepResult: if self.structure and ( content_between(doc, self.from_, self.gap_from) or content_between(doc, self.gap_to, self.to) @@ -130,7 +141,7 @@ def apply(self, doc): return StepResult.fail("Content does not fit in gap") return StepResult.from_replace(doc, self.from_, self.to, inserted) - def get_map(self): + def get_map(self) -> StepMap: return StepMap( [ self.from_, @@ -142,7 +153,7 @@ def get_map(self): ] ) - def invert(self, doc): + def invert(self, doc: Node) -> "ReplaceAroundStep": gap = self.gap_to - self.gap_from return ReplaceAroundStep( self.from_, @@ -156,7 +167,7 @@ def invert(self, doc): self.structure, ) - def map(self, mapping): + def map(self, mapping: Mappable) -> "ReplaceAroundStep | None": from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) gap_from = mapping.map(self.gap_from, -1) @@ -167,8 +178,8 @@ def map(self, mapping): from_.pos, to.pos, gap_from, gap_to, self.slice, self.insert, self.structure ) - def to_json(self): - json_data = { + def to_json(self) -> JSONDict: + json_data: JSONDict = { "stepType": "replaceAround", "from": self.from_, "to": self.to, @@ -177,17 +188,26 @@ def to_json(self): "insert": self.insert, } if self.slice.size: - json_data["slice"] = self.slice.to_json() + json_data = { + **json_data, + "slice": self.slice.to_json(), + } if self.structure: - json_data["structure"] = True + json_data = { + **json_data, + "structure": True, + 
} return json_data @staticmethod - def from_json(schema, json_data): + def from_json( + schema: Schema[str, str], json_data: JSONDict | str + ) -> "ReplaceAroundStep": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if ( not isinstance(json_data["from"], int) or not isinstance(json_data["to"], int) @@ -201,16 +221,16 @@ def from_json(schema, json_data): json_data["to"], json_data["gapFrom"], json_data["gapTo"], - Slice.from_json(schema, json_data.get("slice")), + Slice.from_json(schema, cast(JSONDict | None, json_data.get("slice"))), json_data["insert"], bool(json_data.get("structure")), ) -Step.json_id("replaceAround", ReplaceAroundStep) +step_json_id("replaceAround", ReplaceAroundStep) -def content_between(doc, from_, to): +def content_between(doc: Node, from_: int, to: int) -> bool: from__ = doc.resolve(from_) dist = to - from_ depth = from__.depth diff --git a/prosemirror/transform/step.py b/prosemirror/transform/step.py index 2789b0b..eebbd45 100644 --- a/prosemirror/transform/step.py +++ b/prosemirror/transform/step.py @@ -1,74 +1,88 @@ import abc -from typing import Dict, Type +from typing import Literal, Type, TypeVar, cast, overload -from prosemirror.model import ReplaceError - -from .map import StepMap +from prosemirror.model import Node, ReplaceError, Schema, Slice +from prosemirror.transform.map import Mappable, StepMap +from prosemirror.utils import JSONDict # like a registry -STEPS_BY_ID: Dict[str, Type["Step"]] = {} +STEPS_BY_ID: dict[str, Type["Step"]] = {} +StepSubclass = TypeVar("StepSubclass", bound="Step") class Step(metaclass=abc.ABCMeta): + json_id: str + @abc.abstractmethod - def apply(self, _doc): - return + def apply(self, _doc: Node) -> "StepResult": + ... - def get_map(self): + def get_map(self) -> StepMap: return StepMap.empty @abc.abstractmethod - def invert(self, _doc): - return + def invert(self, _doc: Node) -> "Step": + ... @abc.abstractmethod - def map(self, _mapping): - return + def map(self, _mapping: Mappable) -> "Step | None": + ... - def merge(self, _other): + def merge(self, _other: "Step") -> "Step | None": return None @abc.abstractmethod - def to_json(self): - return + def to_json(self) -> JSONDict: + ... @staticmethod - def from_json(schema, json_data): + def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not json_data or not json_data.get("stepType"): raise ValueError("Invalid inpit for Step.from_json") - type = STEPS_BY_ID.get(json_data["stepType"]) + type = STEPS_BY_ID.get(cast(str, json_data["stepType"])) if not type: raise ValueError(f'no step type {json_data["stepType"]} defined') return type.from_json(schema, json_data) - @staticmethod - def json_id(id, step_class): - if id in STEPS_BY_ID: - raise ValueError(f"Duplicated JSON ID for step type: {id}") - STEPS_BY_ID[id] = step_class - step_class.json_id = id - return step_class + +def step_json_id(id: str, step_class: type[StepSubclass]) -> type[StepSubclass]: + if id in STEPS_BY_ID: + raise ValueError(f"Duplicated JSON ID for step type: {id}") + + STEPS_BY_ID[id] = step_class + step_class.json_id = id + + return step_class class StepResult: - def __init__(self, doc, failed): + @overload + def __init__(self, doc: Node, failed: Literal[None]) -> None: + ... + + @overload + def __init__(self, doc: None, failed: str) -> None: + ... 
+ + def __init__(self, doc: Node | None, failed: str | None) -> None: self.doc = doc self.failed = failed @classmethod - def ok(cls, doc): + def ok(cls, doc: Node) -> "StepResult": return cls(doc, None) @classmethod - def fail(cls, message): + def fail(cls, message: str) -> "StepResult": return cls(None, message) @classmethod - def from_replace(cls, doc, from_, to, slice): + def from_replace(cls, doc: Node, from_: int, to: int, slice: Slice) -> "StepResult": try: return cls.ok(doc.replace(from_, to, slice)) except ReplaceError as e: diff --git a/prosemirror/transform/structure.py b/prosemirror/transform/structure.py index 7240f5f..65240d5 100644 --- a/prosemirror/transform/structure.py +++ b/prosemirror/transform/structure.py @@ -1,13 +1,15 @@ -from prosemirror.model import Node +from typing import TypedDict +from prosemirror.model import Attrs, ContentMatch, Node, NodeRange, NodeType, Slice -def can_cut(node, start, end): + +def can_cut(node: Node, start: int, end: int) -> bool: if start == 0 or node.can_replace(start, node.child_count): return (end == node.child_count) or node.can_replace(0, end) return False -def lift_target(range_): +def lift_target(range_: NodeRange) -> int | None: parent = range_.parent content = parent.content.cut_by_index(range_.start_index, range_.end_index) depth = range_.depth @@ -25,18 +27,34 @@ def lift_target(range_): break depth -= 1 + return None + + +class NodeTypeWithAttrs(TypedDict): + type: NodeType + attrs: Attrs | None -def find_wrapping(range_, node_type, attrs=None, inner_range=None): + +def find_wrapping( + range_: NodeRange, + node_type: NodeType, + attrs: Attrs | None = None, + inner_range: NodeRange | None = None, +) -> list[NodeTypeWithAttrs] | None: if inner_range is None: inner_range = range_ + around = find_wrapping_outside(range_, node_type) - inner = False + inner = None + if around is not None: inner = find_wrapping_inside(inner_range, node_type) else: return None + if inner is None: return None + return ( [with_attrs(item) for item in around] + [{"type": node_type, "attrs": attrs}] @@ -44,11 +62,11 @@ def find_wrapping(range_, node_type, attrs=None, inner_range=None): ) -def with_attrs(type): - return {"type": type, "attrs": None} +def with_attrs(type: NodeType) -> NodeTypeWithAttrs: + return NodeTypeWithAttrs(type=type, attrs=None) -def find_wrapping_outside(range_, type): +def find_wrapping_outside(range_: NodeRange, type: NodeType) -> list[NodeType] | None: parent = range_.parent start_index = range_.start_index end_index = range_.end_index @@ -59,41 +77,54 @@ def find_wrapping_outside(range_, type): return around if parent.can_replace_with(start_index, end_index, outer) else None -def find_wrapping_inside(range_, type): +def find_wrapping_inside(range_: NodeRange, type: NodeType) -> list[NodeType] | None: parent = range_.parent start_index = range_.start_index end_index = range_.end_index inner = parent.child(start_index) inside = type.content_match.find_wrapping(inner.type) + if inside is None: return None + last_type = inside[-1] if len(inside) else type - inner_match = last_type.content_match + inner_match: ContentMatch | None = last_type.content_match i = start_index + while inner_match and i < end_index: inner_match = inner_match.match_type(parent.child(i).type) i += 1 + if not inner_match or not inner_match.valid_end: return None + return inside -def can_change_type(doc, pos, type): +def can_change_type(doc: Node, pos: int, type: NodeType) -> bool: pos_ = doc.resolve(pos) index = pos_.index() return 
pos_.parent.can_replace_with(index, index + 1, type) -def can_split(doc, pos, depth=None, types_after=None): +def can_split( + doc: Node, + pos: int, + depth: int | None = None, + types_after: list[dict[str, NodeType]] | None = None, +) -> bool: if depth is None: depth = 1 pos_ = doc.resolve(pos) base = pos_.depth - depth - inner_type = None + inner_type: dict[str, NodeType] | Node | None = None + if types_after: inner_type = types_after[-1] + if not inner_type: inner_type = pos_.parent + if isinstance(inner_type, Node): if ( base < 0 @@ -104,6 +135,7 @@ def can_split(doc, pos, depth=None, types_after=None): ) ): return False + elif isinstance(inner_type, dict): if ( base < 0 @@ -114,26 +146,29 @@ def can_split(doc, pos, depth=None, types_after=None): ) ): return False + d = pos_.depth - 1 i = depth - 2 + while d > base: node = pos_.node(d) index = pos_.index(d) if node.type.spec.get("isolating"): return False rest = node.content.cut_by_index(index, node.child_count) - after = None + after: dict[str, NodeType] | Node | None = None + if types_after and len(types_after) > i: after = types_after[i] if not after: after = node + if isinstance(after, dict): - if after != node: - rest = rest.replace_child(0, after["type"].create(after.get("attrs"))) if not node.can_replace(index + 1, node.child_count) or not after[ "type" ].valid_content(rest): return False + if isinstance(after, Node): if after != node: rest = rest.replace_child(0, after.type.create(after.attrs)) @@ -150,7 +185,7 @@ def can_split(doc, pos, depth=None, types_after=None): ) -def can_join(doc, pos): +def can_join(doc: Node, pos: int) -> bool | None: pos_ = doc.resolve(pos) index = pos_.index() return ( @@ -160,13 +195,13 @@ def can_join(doc, pos): ) -def joinable(a, b): +def joinable(a: Node | None, b: Node | None) -> bool: if a and b and not a.is_leaf: return a.can_append(b) return False -def join_point(doc, pos, dir=-1): +def join_point(doc: Node, pos: int, dir: int = -1) -> int | None: pos_ = doc.resolve(pos) for d in range(pos_.depth, -1, -1): before = None @@ -193,8 +228,10 @@ def join_point(doc, pos, dir=-1): break pos = pos_.before(d) if dir < 0 else pos_.after(d) + return None -def insert_point(doc, pos, node_type): + +def insert_point(doc: Node, pos: int, node_type: NodeType) -> int | None: pos_ = doc.resolve(pos) if pos_.parent.can_replace_with(pos_.index(), pos_.index(), node_type): return pos @@ -213,13 +250,16 @@ def insert_point(doc, pos, node_type): if index < pos_.node(d).child_count: return None + return None + -def drop_point(doc, pos, slice): +def drop_point(doc: Node, pos: int, slice: Slice) -> int | None: pos_ = doc.resolve(pos) if not slice.content.size: return pos content = slice.content for i in range(slice.open_start): + assert content.first_child is not None content = content.first_child.content pass_ = 1 while pass_ <= (2 if slice.open_start == 0 and slice.size else 1): @@ -236,10 +276,11 @@ def drop_point(doc, pos, slice): if pass_ == 1: fits = parent.can_replace(insert_pos, insert_pos, content) else: + assert content.first_child is not None wrapping = parent.content_match_at(insert_pos).find_wrapping( content.first_child.type ) - fits = wrapping and parent.can_replace_with( + fits = wrapping is not None and parent.can_replace_with( insert_pos, insert_pos, wrapping[0] ) if fits: diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index 4fc5862..a2b027e 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -1,18 +1,39 @@ -from typing 
import Union - -from prosemirror.model import Fragment, Mark, MarkType, Node, NodeType, Slice - -from . import replace, structure -from .attr_step import AttrStep -from .map import Mapping -from .mark_step import AddMarkStep, AddNodeMarkStep, RemoveMarkStep, RemoveNodeMarkStep -from .replace import close_fragment, covered_depths, fits_trivially, replace_step -from .replace_step import ReplaceAroundStep, ReplaceStep -from .structure import can_change_type, insert_point - - -def defines_content(type: Union[NodeType, MarkType]): - return type.spec.get("defining") or type.spec.get("definingForContent") +from typing import TypedDict + +from prosemirror.model import ( + Attrs, + ContentMatch, + Fragment, + Mark, + MarkType, + Node, + NodeRange, + NodeType, + Slice, +) +from prosemirror.transform import ( + AddMarkStep, + AddNodeMarkStep, + AttrStep, + Mapping, + RemoveMarkStep, + RemoveNodeMarkStep, + ReplaceAroundStep, + ReplaceStep, + Step, + StepResult, + close_fragment, + covered_depths, + fits_trivially, + structure, +) +from prosemirror.transform.replace import replace_step + + +def defines_content(type: NodeType | MarkType) -> bool | None: + if isinstance(type, NodeType): + return type.spec.get("defining") or type.spec.get("definingForContent") + return False class TransformError(ValueError): @@ -28,53 +49,57 @@ class Transform: drop_point = structure.drop_point lift_target = structure.lift_target find_wrapping = structure.find_wrapping - replace_step = replace.replace_step + replace_step = replace_step - def __init__(self, doc: Node): + def __init__(self, doc: Node) -> None: self.doc = doc - self.steps = [] # type: ignore - self.docs = [] # type: ignore + self.steps: list[Step] = [] + self.docs: list[Node] = [] self.mapping = Mapping() @property - def before(self): + def before(self) -> Node: return self.docs[0] if self.docs else self.doc - def step(self, object): + def step(self, object: Step) -> "Transform": result = self.maybe_step(object) if result.failed: raise TransformError(result.failed) return self - def maybe_step(self, step): + def maybe_step(self, step: Step) -> StepResult: result = step.apply(self.doc) - if not result.failed: + if not result.failed and result.doc: self.add_step(step, result.doc) return result - def doc_changed(self): + def doc_changed(self) -> bool: return bool(len(self.steps)) - def add_step(self, step, doc): + def add_step(self, step: Step, doc: Node) -> None: self.docs.append(self.doc) self.steps.append(step) self.mapping.append_map(step.get_map()) self.doc = doc # mark.js - def add_mark(self, from_, to, mark): + def add_mark(self, from_: int, to: int, mark: Mark) -> "Transform": removed = [] added = [] - removing = None - adding = None + removing: RemoveMarkStep | None = None + adding: AddMarkStep | None = None - def iteratee(node, pos, parent, *args): + def iteratee(node: Node, pos: int, parent: Node | None, i: int) -> None: nonlocal removing nonlocal adding if not node.is_inline: return marks = node.marks - if not mark.is_in_set(marks) and parent.type.allows_mark_type(mark.type): + if ( + not mark.is_in_set(marks) + and parent + and parent.type.allows_mark_type(mark.type) + ): start = max(pos, from_) end = min(pos + node.node_size, to) new_set = mark.add_to_set(marks) @@ -96,32 +121,44 @@ def iteratee(node, pos, parent, *args): added.append(adding) self.doc.nodes_between(from_, to, iteratee) + item: Step for item in removed: self.step(item) for item in added: self.step(item) return self - def remove_mark(self, from_, to, mark=None): - matched = [] + 
def remove_mark( + self, + from_: int, + to: int, + mark: Mark | MarkType | None = None, + ) -> "Transform": + class MatchedTypedDict(TypedDict): + style: Mark + from_: int + to: int + step: int + + matched: list[MatchedTypedDict] = [] step = 0 - def iteratee(node, pos, *args): + def iteratee(node: Node, pos: int, parent: Node | None, i: int) -> bool | None: nonlocal step if not node.is_inline: - return + return None step += 1 to_remove = None if isinstance(mark, MarkType): set_ = node.marks while True: - found = mark.is_in_set(set_) - if not found: + found_mark = mark.is_in_set(set_) + if not found_mark: break if to_remove is None: to_remove = [] - to_remove.append(found) - set_ = found.remove_from_set(set_) + to_remove.append(found_mark) + set_ = found_mark.remove_from_set(set_) elif mark: if mark.is_in_set(node.marks): to_remove = [mark] @@ -141,26 +178,35 @@ def iteratee(node, pos, *args): matched.append( { "style": style, - "from": max(pos, from_), + "from_": max(pos, from_), "to": end, "step": step, } ) + return None self.doc.nodes_between(from_, to, iteratee) for item in matched: - self.step(RemoveMarkStep(item["from"], item["to"], item["style"])) + self.step(RemoveMarkStep(item["from_"], item["to"], item["style"])) return self - def clear_incompatible(self, pos, parent_type, match=None): + def clear_incompatible( + self, + pos: int, + parent_type: NodeType, + match: ContentMatch | None = None, + ) -> "Transform": if match is None: match = parent_type.content_match node = self.doc.node_at(pos) + assert match is not None + assert node is not None del_steps = [] cur = pos + 1 for i in range(node.child_count): child = node.child(i) end = cur + child.node_size + assert match is not None allowed = match.match_type(child.type) if not allowed: del_steps.append(ReplaceStep(cur, end, Slice.empty)) @@ -172,13 +218,19 @@ def clear_incompatible(self, pos, parent_type, match=None): cur = end if not match.valid_end: fill = match.fill_before(Fragment.empty, True) + assert fill is not None self.replace(cur, cur, Slice(fill, 0, 0)) for item in reversed(del_steps): self.step(item) return self # replace.js - def replace(self, from_, to=None, slice=None): + def replace( + self, + from_: int, + to: int | None = None, + slice: Slice | None = None, + ) -> "Transform": if to is None: to = from_ if slice is None: @@ -188,16 +240,25 @@ def replace(self, from_, to=None, slice=None): self.step(step) return self - def replace_with(self, from_, to, content): + def replace_with( + self, + from_: int, + to: int, + content: list[Node] | Node, + ) -> "Transform": return self.replace(from_, to, Slice(Fragment.from_(content), 0, 0)) - def delete(self, from_, to): + def delete(self, from_: int, to: int) -> "Transform": return self.replace(from_, to, Slice.empty) - def insert(self, pos, content): + def insert( + self, + pos: int, + content: list[Node] | Node, + ) -> "Transform": return self.replace_with(pos, pos, content) - def replace_range(self, from_, to, slice): + def replace_range(self, from_: int, to: int, slice: Slice) -> "Transform": if not slice.size: return self.delete_range(from_, to) from__ = self.doc.resolve(from_) @@ -232,9 +293,11 @@ def replace_range(self, from_, to, slice): i = 0 while True: node = content.first_child - left_nodes.append(node) - if i == slice.open_start: + + if i == slice.open_start or node is None: break + + left_nodes.append(node) content = node.content i += 1 @@ -289,26 +352,30 @@ def replace_range(self, from_, to, slice): to = to_.after(depth) return self - def 
replace_range_with(self, from_, to, node): + def replace_range_with(self, from_: int, to: int, node: Node) -> "Transform": if ( not node.is_inline and from_ == to and self.doc.resolve(from_).parent.content.size ): - point = insert_point(self.doc, from_, node.type) + point = structure.insert_point(self.doc, from_, node.type) if point is not None: from_ = to = point + return self.replace_range(from_, to, Slice(Fragment.from_(node), 0, 0)) - def delete_range(self, from_, to): + def delete_range(self, from_: int, to: int) -> "Transform": from__ = self.doc.resolve(from_) to_ = self.doc.resolve(to) covered = covered_depths(from__, to_) + for i in range(len(covered)): depth = covered[i] last = len(covered) - 1 == i + if (last and depth == 0) or from__.node(depth).type.content_match.valid_end: return self.delete(from__.start(depth), to_.end(depth)) + if depth > 0 and ( last or from__.node(depth - 1).can_replace( @@ -316,7 +383,9 @@ def delete_range(self, from_, to): ) ): return self.delete(from__.before(depth), to_.after(depth)) + d = 1 + while d <= from__.depth and d <= to_.depth: if ( from_ - from__.start(d) == from__.depth - d @@ -325,10 +394,11 @@ def delete_range(self, from_, to): ): return self.delete(from__.before(d), to) d += 1 + return self.delete(from_, to) # structure.js - def lift(self, range_, target): + def lift(self, range_: NodeRange, target: int) -> "Transform": from__ = range_.from_ to_ = range_.to depth = range_.depth @@ -374,7 +444,9 @@ def lift(self, range_, target): ) ) - def wrap(self, range_, wrappers): + def wrap( + self, range_: NodeRange, wrappers: list[structure.NodeTypeWithAttrs] + ) -> "Transform": content = Fragment.empty i = len(wrappers) - 1 while i >= 0: @@ -397,18 +469,26 @@ def wrap(self, range_, wrappers): ) ) - def set_block_type(self, from_, to, type, attrs): + def set_block_type( + self, + from_: int, + to: int | None, + type: NodeType, + attrs: Attrs | None, + ) -> "Transform": if to is None: to = from_ if not type.is_text_block: raise ValueError("Type given to set_block_type should be a textblock") map_from = len(self.steps) - def iteratee(node: "Node", pos, *args): + def iteratee( + node: "Node", pos: int, parent: "Node | None", i: int + ) -> bool | None: if ( node.is_text_block and not node.has_markup(type, attrs) - and can_change_type( + and structure.can_change_type( self.doc, self.mapping.slice(map_from).map(pos), type ) ): @@ -430,11 +510,18 @@ def iteratee(node: "Node", pos, *args): ) ) return False + return None self.doc.nodes_between(from_, to, iteratee) return self - def set_node_markup(self, pos, type, attrs, marks=None): + def set_node_markup( + self, + pos: int, + type: NodeType, + attrs: Attrs, + marks: None = None, + ) -> "Transform": node = self.doc.node_at(pos) if not node: raise ValueError("No node at given position") @@ -457,23 +544,33 @@ def set_node_markup(self, pos, type, attrs, marks=None): ) ) - def set_node_attribute(self, pos, attr, value): + def set_node_attribute(self, pos: int, attr: str, value: str | int) -> "Transform": return self.step(AttrStep(pos, attr, value)) - def add_node_mark(self, pos, mark): + def add_node_mark(self, pos: int, mark: Mark) -> "Transform": return self.step(AddNodeMarkStep(pos, mark)) - def remove_node_mark(self, pos, mark): - if not isinstance(mark, Mark): + def remove_node_mark(self, pos: int, mark: Mark | MarkType) -> "Transform": + if isinstance(mark, MarkType): node = self.doc.node_at(pos) + if not node: - raise ValueError("No node at position " + pos) - mark = mark.is_in_set(node.marks) - if not 
mark: + raise ValueError(f"No node at position {pos}") + + mark_in_set = mark.is_in_set(node.marks) + + if not mark_in_set: return self + + mark = mark_in_set return self.step(RemoveNodeMarkStep(pos, mark)) - def split(self, pos, depth=None, types_after=None): + def split( + self, + pos: int, + depth: int | None = None, + types_after: list[structure.NodeTypeWithAttrs] | None = None, + ) -> "Transform": if depth is None: depth = 1 pos_ = self.doc.resolve(pos) @@ -498,6 +595,6 @@ def split(self, pos, depth=None, types_after=None): ReplaceStep(pos, pos, Slice(before.append(after), depth, depth), True) ) - def join(self, pos, depth=1): + def join(self, pos: int, depth: int = 1) -> "Transform": step = ReplaceStep(pos - depth, pos + depth, Slice.empty, True) return self.step(step) From 2830ca9c6ea7f28f4a525dd4f67063f6d6a70ded Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Mon, 13 Nov 2023 14:37:42 -0500 Subject: [PATCH 24/40] Remove leftover debugging comment. --- prosemirror/model/schema.py | 1 - 1 file changed, 1 deletion(-) diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index b7622bb..0fb28f9 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -293,7 +293,6 @@ def excludes(self, other: "MarkType") -> bool: return any(other.name == e.name for e in self.excluded) -# XXX I don't get these... Nodes = TypeVar("Nodes", bound=str, covariant=True) Marks = TypeVar("Marks", bound=str, covariant=True) From 5c419513c693cb717fe2b1cf68fbd921e5c79210 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Tue, 14 Nov 2023 08:26:26 -0500 Subject: [PATCH 25/40] Using generic types for most uses of Schema, just like in the original repo. --- prosemirror/model/__init__.py | 3 +-- prosemirror/model/fragment.py | 2 +- prosemirror/model/from_dom.py | 16 ++++++++-------- prosemirror/model/mark.py | 4 +--- prosemirror/model/node.py | 8 ++++---- prosemirror/model/replace.py | 4 ++-- prosemirror/model/schema.py | 4 ++-- prosemirror/model/to_dom.py | 6 +++--- prosemirror/schema/basic/schema_basic.py | 4 +++- prosemirror/test_builder/__init__.py | 4 +++- prosemirror/test_builder/build.py | 3 ++- prosemirror/transform/attr_step.py | 4 ++-- prosemirror/transform/mark_step.py | 10 +++++----- prosemirror/transform/replace.py | 2 +- prosemirror/transform/replace_step.py | 6 +++--- prosemirror/transform/step.py | 4 ++-- prosemirror/transform/structure.py | 3 ++- prosemirror/transform/transform.py | 2 +- prosemirror/utils.py | 2 ++ 19 files changed, 48 insertions(+), 43 deletions(-) diff --git a/prosemirror/model/__init__.py b/prosemirror/model/__init__.py index f255024..4c119e4 100644 --- a/prosemirror/model/__init__.py +++ b/prosemirror/model/__init__.py @@ -5,7 +5,7 @@ from .node import Node from .replace import ReplaceError, Slice from .resolvedpos import NodeRange, ResolvedPos -from .schema import Attrs, MarkType, NodeType, Schema +from .schema import MarkType, NodeType, Schema from .to_dom import DOMSerializer __all__ = [ @@ -16,7 +16,6 @@ "Slice", "ReplaceError", "Mark", - "Attrs", "Schema", "NodeType", "MarkType", diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index 58bf892..a1be381 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -244,7 +244,7 @@ def to_json(self) -> JSONList | None: return None @classmethod - def from_json(cls, schema: "Schema[str, str]", value: Any) -> "Fragment": + def from_json(cls, schema: "Schema[Any, Any]", value: Any) -> "Fragment": if not value: return cls.empty diff --git 
a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index 8162c69..b6b1ac5 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -7,7 +7,7 @@ from lxml.cssselect import CSSSelector from lxml.html import HtmlElement as DOMNode -from prosemirror.utils import JSONDict +from prosemirror.utils import Attrs, JSONDict from .content import ContentMatch from .fragment import Fragment @@ -15,7 +15,7 @@ from .node import Node, TextNode from .replace import Slice from .resolvedpos import ResolvedPos -from .schema import Attrs, MarkType, NodeType, Schema +from .schema import MarkType, NodeType, Schema WSType = bool | Literal["full"] | None @@ -57,7 +57,7 @@ class ParseRule: attrs: Attrs | None get_attrs: Callable[[DOMNode], None | Attrs | Literal[False]] | None content_element: str | DOMNode | Callable[[DOMNode], DOMNode] | None - get_content: Callable[[DOMNode, Schema[str, str]], Fragment] | None + get_content: Callable[[DOMNode, Schema[Any, Any]], Fragment] | None preserve_whitespace: WSType @classmethod @@ -88,10 +88,10 @@ class DOMParser: _styles: list[ParseRule] _normalize_lists: bool - schema: Schema[str, str] + schema: Schema[Any, Any] rules: list[ParseRule] - def __init__(self, schema: Schema[str, str], rules: list[ParseRule]) -> None: + def __init__(self, schema: Schema[Any, Any], rules: list[ParseRule]) -> None: self.schema = schema self.rules = rules self._tags = [rule for rule in rules if rule.tag is not None] @@ -207,7 +207,7 @@ def match_style( return None @classmethod - def schema_rules(cls, schema: Schema[str, str]) -> list[ParseRule]: + def schema_rules(cls, schema: Schema[Any, Any]) -> list[ParseRule]: result: list[ParseRule] = [] def insert(rule: ParseRule) -> None: @@ -251,7 +251,7 @@ def insert(rule: ParseRule) -> None: return result @classmethod - def from_schema(cls, schema: Schema[str, str]) -> "DOMParser": + def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": if "dom_parser" not in schema.cached: schema.cached["dom_parser"] = DOMParser( schema, DOMParser.schema_rules(schema) @@ -1166,7 +1166,7 @@ def get_node_type(element: DOMNode) -> int: return 8 -def from_html(schema: Schema[str, str], html: str) -> JSONDict: +def from_html(schema: Schema[Any, Any], html: str) -> JSONDict: fragment = lxml.html.fragment_fromstring(html, create_parent="document-fragment") prose_doc = DOMParser.from_schema(schema).parse(fragment) diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index 4ae33a1..a39b84f 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -1,9 +1,7 @@ import copy from typing import TYPE_CHECKING, Any, Final, cast -from prosemirror.utils import JSONDict - -from .schema import Attrs +from prosemirror.utils import Attrs, JSONDict if TYPE_CHECKING: from .schema import MarkType, Schema diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index bb48697..6e9703d 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,9 +1,9 @@ import copy -from typing import TYPE_CHECKING, Callable, TypedDict, cast +from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast from typing_extensions import TypeGuard -from prosemirror.utils import JSONDict, text_length +from prosemirror.utils import Attrs, JSONDict, text_length from .comparedeep import compare_deep from .fragment import Fragment @@ -13,7 +13,7 @@ if TYPE_CHECKING: from .content import ContentMatch - from .schema import Attrs, MarkType, NodeType, Schema + from .schema import MarkType, NodeType, Schema empty_attrs: 
JSONDict = {} @@ -325,7 +325,7 @@ def to_json(self) -> JSONDict: return obj @classmethod - def from_json(cls, schema: "Schema[str, str]", json_data: JSONDict | str) -> "Node": + def from_json(cls, schema: "Schema[Any, Any]", json_data: JSONDict | str) -> "Node": if isinstance(json_data, str): import json diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index 9440f78..7aa4d54 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, ClassVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from prosemirror.utils import JSONDict @@ -104,7 +104,7 @@ def to_json(self) -> JSONDict | None: @classmethod def from_json( cls, - schema: "Schema[str, str]", + schema: "Schema[Any, Any]", json_data: JSONDict | None, ) -> "Slice": if not json_data: diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index 0fb28f9..9b38b13 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -251,7 +251,7 @@ class MarkType: instance: Mark | None def __init__( - self, name: str, rank: int, schema: "Schema[str, str]", spec: "MarkSpec" + self, name: str, rank: int, schema: "Schema[Any, Any]", spec: "MarkSpec" ) -> None: self.name = name self.schema = schema @@ -455,7 +455,7 @@ def node_type(self, name: str) -> NodeType: return found -def gather_marks(schema: Schema[str, str], marks: list[str]) -> list[MarkType]: +def gather_marks(schema: Schema[Any, Any], marks: list[str]) -> list[MarkType]: found = [] for name in marks: mark = schema.marks.get(name) diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index 8e0b954..63ff971 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -181,12 +181,12 @@ def render_spec(cls, structure: HTMLOutputSpec) -> tuple[HTMLNode, Element | Non return dom, content_dom @classmethod - def from_schema(cls, schema: Schema[str, str]) -> "DOMSerializer": + def from_schema(cls, schema: Schema[Any, Any]) -> "DOMSerializer": return cls(cls.nodes_from_schema(schema), cls.marks_from_schema(schema)) @classmethod def nodes_from_schema( - cls, schema: Schema[str, str] + cls, schema: Schema[Any, Any] ) -> dict[str, Callable[["Node"], HTMLOutputSpec]]: result = gather_to_dom(schema.nodes) if "text" not in result: @@ -195,7 +195,7 @@ def nodes_from_schema( @classmethod def marks_from_schema( - cls, schema: Schema[str, str] + cls, schema: Schema[Any, Any] ) -> dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: return gather_to_dom(schema.marks) diff --git a/prosemirror/schema/basic/schema_basic.py b/prosemirror/schema/basic/schema_basic.py index 88be8e4..39857af 100644 --- a/prosemirror/schema/basic/schema_basic.py +++ b/prosemirror/schema/basic/schema_basic.py @@ -1,3 +1,5 @@ +from typing import Any + from prosemirror.model import Schema from prosemirror.model.schema import MarkSpec, NodeSpec @@ -111,4 +113,4 @@ } -schema: Schema[str, str] = Schema({"nodes": nodes, "marks": marks}) +schema: Schema[Any, Any] = Schema({"nodes": nodes, "marks": marks}) diff --git a/prosemirror/test_builder/__init__.py b/prosemirror/test_builder/__init__.py index 454b0be..22d766c 100644 --- a/prosemirror/test_builder/__init__.py +++ b/prosemirror/test_builder/__init__.py @@ -1,12 +1,14 @@ # type: ignore +from typing import Any + from prosemirror.model import Node, Schema from prosemirror.schema.basic import schema as _schema from prosemirror.schema.list import add_list_nodes from .build import builders -test_schema: Schema[str, str] = 
Schema( +test_schema: Schema[Any, Any] = Schema( { "nodes": add_list_nodes(_schema.spec["nodes"], "paragraph block*", "block"), "marks": _schema.spec["marks"], diff --git a/prosemirror/test_builder/build.py b/prosemirror/test_builder/build.py index 9e15dca..e1ab6d1 100644 --- a/prosemirror/test_builder/build.py +++ b/prosemirror/test_builder/build.py @@ -2,6 +2,7 @@ import re from collections.abc import Callable +from typing import Any from prosemirror.model import Node, Schema from prosemirror.utils import JSONDict @@ -10,7 +11,7 @@ def flatten( - schema: Schema[str, str], + schema: Schema[Any, Any], children: list[Node | JSONDict | str], f: Callable[[Node], Node], ) -> tuple[list[Node], dict[str, int]]: diff --git a/prosemirror/transform/attr_step.py b/prosemirror/transform/attr_step.py index d7e8c64..364ce85 100644 --- a/prosemirror/transform/attr_step.py +++ b/prosemirror/transform/attr_step.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from prosemirror.model import Fragment, Node, Schema, Slice from prosemirror.transform.map import Mappable, StepMap @@ -50,7 +50,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "AttrStep": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AttrStep": if isinstance(json_data, str): import json diff --git a/prosemirror/transform/mark_step.py b/prosemirror/transform/mark_step.py index 32faabc..b117226 100644 --- a/prosemirror/transform/mark_step.py +++ b/prosemirror/transform/mark_step.py @@ -1,4 +1,4 @@ -from typing import Callable, cast +from typing import Any, Callable, cast from prosemirror.model import Fragment, Mark, Node, Schema, Slice from prosemirror.transform.map import Mappable @@ -79,7 +79,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "AddMarkStep": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AddMarkStep": if isinstance(json_data, str): import json @@ -150,7 +150,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json @@ -211,7 +211,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json @@ -265,7 +265,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json diff --git a/prosemirror/transform/replace.py b/prosemirror/transform/replace.py index 06ee53d..977ef1a 100644 --- a/prosemirror/transform/replace.py +++ b/prosemirror/transform/replace.py @@ -1,7 +1,6 @@ from typing import cast from prosemirror.model import ( - Attrs, ContentMatch, Fragment, Node, @@ -11,6 +10,7 @@ ) from prosemirror.transform.replace_step import ReplaceAroundStep, ReplaceStep from prosemirror.transform.step import Step +from prosemirror.utils import Attrs def replace_step( diff --git a/prosemirror/transform/replace_step.py b/prosemirror/transform/replace_step.py index e7d92bb..30a22a2 100644 --- a/prosemirror/transform/replace_step.py +++ 
b/prosemirror/transform/replace_step.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from prosemirror.model import Node, Schema, Slice from prosemirror.transform.map import Mappable, StepMap @@ -86,7 +86,7 @@ def to_json(self) -> JSONDict: return json_data @staticmethod - def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "ReplaceStep": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "ReplaceStep": if isinstance(json_data, str): import json @@ -201,7 +201,7 @@ def to_json(self) -> JSONDict: @staticmethod def from_json( - schema: Schema[str, str], json_data: JSONDict | str + schema: Schema[Any, Any], json_data: JSONDict | str ) -> "ReplaceAroundStep": if isinstance(json_data, str): import json diff --git a/prosemirror/transform/step.py b/prosemirror/transform/step.py index eebbd45..7c48ef6 100644 --- a/prosemirror/transform/step.py +++ b/prosemirror/transform/step.py @@ -1,5 +1,5 @@ import abc -from typing import Literal, Type, TypeVar, cast, overload +from typing import Any, Literal, Type, TypeVar, cast, overload from prosemirror.model import Node, ReplaceError, Schema, Slice from prosemirror.transform.map import Mappable, StepMap @@ -36,7 +36,7 @@ def to_json(self) -> JSONDict: ... @staticmethod - def from_json(schema: Schema[str, str], json_data: JSONDict | str) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": if isinstance(json_data, str): import json diff --git a/prosemirror/transform/structure.py b/prosemirror/transform/structure.py index 65240d5..24affef 100644 --- a/prosemirror/transform/structure.py +++ b/prosemirror/transform/structure.py @@ -1,6 +1,7 @@ from typing import TypedDict -from prosemirror.model import Attrs, ContentMatch, Node, NodeRange, NodeType, Slice +from prosemirror.model import ContentMatch, Node, NodeRange, NodeType, Slice +from prosemirror.utils import Attrs def can_cut(node: Node, start: int, end: int) -> bool: diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index a2b027e..f105015 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -1,7 +1,6 @@ from typing import TypedDict from prosemirror.model import ( - Attrs, ContentMatch, Fragment, Mark, @@ -28,6 +27,7 @@ structure, ) from prosemirror.transform.replace import replace_step +from prosemirror.utils import Attrs def defines_content(type: NodeType | MarkType) -> bool | None: diff --git a/prosemirror/utils.py b/prosemirror/utils.py index 270b036..99d6739 100644 --- a/prosemirror/utils.py +++ b/prosemirror/utils.py @@ -7,6 +7,8 @@ JSON: TypeAlias = JSONDict | JSONList | str | int | float | bool | None +Attrs: TypeAlias = JSONDict + def text_length(text: str) -> int: return len(text.encode("utf-16-le")) // 2 From 7c2a85a0fc5268e8742f0da0edc62cc91b680e00 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Tue, 14 Nov 2023 08:27:54 -0500 Subject: [PATCH 26/40] Adding missing import of Attrs. 
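schema.py still carried its own Attrs: TypeAlias = JSONDict even though the
previous commit added the shared alias to prosemirror.utils; drop the local
definition and import it instead. Purely as illustration (the helper below is
made up, not part of this change), callers are expected to spell attribute
types like this:

    from prosemirror.utils import Attrs  # alias for JSONDict

    def heading_attrs(level: int) -> Attrs:
        # Any JSON-compatible mapping satisfies the alias.
        return {"level": level}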
--- prosemirror/model/schema.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index 9b38b13..a3fd6d7 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -13,9 +13,7 @@ from prosemirror.model.fragment import Fragment from prosemirror.model.mark import Mark from prosemirror.model.node import Node, TextNode -from prosemirror.utils import JSON, JSONDict - -Attrs: TypeAlias = JSONDict +from prosemirror.utils import JSON, Attrs, JSONDict def default_attrs(attrs: "Attributes") -> Attrs | None: From e9199322392ff787ec975fa08220a441fcfd7b29 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Tue, 14 Nov 2023 08:44:27 -0500 Subject: [PATCH 27/40] Reverting changes to Mark.to_json that broke tests. --- prosemirror/model/mark.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index a39b84f..babf6a2 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -52,13 +52,7 @@ def eq(self, other: "Mark") -> bool: return self.type.name == other.type.name and self.attrs == other.attrs def to_json(self) -> JSONDict: - result: JSONDict = {"type": self.type.name} - if self.attrs: - result = { - **result, - "attrs": copy.deepcopy(self.attrs), - } - return result + return {"type": self.type.name, "attrs": copy.deepcopy(self.attrs)} @classmethod def from_json( From 1ca91b9ab3acb60cbf1c2194fe3776976f93346a Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Tue, 14 Nov 2023 09:10:01 -0500 Subject: [PATCH 28/40] Restoring attr step. --- prosemirror/transform/attr_step.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prosemirror/transform/attr_step.py b/prosemirror/transform/attr_step.py index 364ce85..f4bbf65 100644 --- a/prosemirror/transform/attr_step.py +++ b/prosemirror/transform/attr_step.py @@ -2,7 +2,7 @@ from prosemirror.model import Fragment, Node, Schema, Slice from prosemirror.transform.map import Mappable, StepMap -from prosemirror.transform.step import Step, StepResult +from prosemirror.transform.step import Step, StepResult, step_json_id from prosemirror.utils import JSON, JSONDict @@ -63,4 +63,4 @@ def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AttrStep" return AttrStep(json_data["pos"], json_data["attr"], json_data["value"]) -# Step.json_id("attr", AttrStep) +step_json_id("attr", AttrStep) From ef9feb6f59d5bcadd0a3890e6f86e21e721705a1 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Tue, 14 Nov 2023 09:36:27 -0500 Subject: [PATCH 29/40] Fixing regression on Transform.replace_range. 
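The earlier typing pass reordered the descent loop in replace_range so the
break was checked before left_nodes.append(node); as a result the node
sitting exactly at depth == slice.open_start was never collected. Restore the
append-first order. A stripped-down sketch of the corrected loop (plain
strings instead of real nodes, not code from the library):

    def collect(first_children, open_start):
        # Mirror of the fixed loop: append first, then test the break
        # condition, so the entry at depth == open_start (or a trailing
        # None) is still recorded.
        left_nodes = []
        i = 0
        while True:
            node = first_children[i] if i < len(first_children) else None
            left_nodes.append(node)
            if i == open_start or node is None:
                break
            i += 1
        return left_nodes

    assert collect(["quote", "para"], 1) == ["quote", "para"]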
--- prosemirror/transform/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index f105015..2c51bc9 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -293,11 +293,11 @@ def replace_range(self, from_: int, to: int, slice: Slice) -> "Transform": i = 0 while True: node = content.first_child + left_nodes.append(node) if i == slice.open_start or node is None: break - left_nodes.append(node) content = node.content i += 1 From 3edecdf2a8dc61d06060c628c851932a8420fab7 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Wed, 15 Nov 2023 09:55:15 -0500 Subject: [PATCH 30/40] Add py.typed --- prosemirror/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 prosemirror/py.typed diff --git a/prosemirror/py.typed b/prosemirror/py.typed new file mode 100644 index 0000000..e69de29 From e45c7eb4cfabb852577228f69b4b7c454d0c7ed5 Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Wed, 15 Nov 2023 10:31:31 -0500 Subject: [PATCH 31/40] Typing fixes after merging main --- prosemirror/transform/doc_attr_step.py | 10 +++++----- prosemirror/transform/structure.py | 6 +++--- prosemirror/transform/transform.py | 9 ++++++--- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/prosemirror/transform/doc_attr_step.py b/prosemirror/transform/doc_attr_step.py index e6dbb07..c00d990 100644 --- a/prosemirror/transform/doc_attr_step.py +++ b/prosemirror/transform/doc_attr_step.py @@ -1,11 +1,10 @@ -from typing import Any, Optional +from typing import Any, Optional, cast from prosemirror.model import Node, Schema -from prosemirror.transform.map import Mappable +from prosemirror.transform.map import Mappable, StepMap +from prosemirror.transform.step import Step, StepResult, step_json_id from prosemirror.utils import JSON, JSONDict -from .step import Step, StepMap, StepResult, step_json_id - class DocAttrStep(Step): def __init__(self, attr: str, value: JSON): @@ -44,7 +43,8 @@ def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "DocAttrSt if isinstance(json_data, str): import json - json_data = json.loads(json_data) + json_data = cast(JSONDict, json.loads(json_data)) + if not isinstance(json_data["attr"], str): raise ValueError("Invalid input for DocAttrStep.from_json") return DocAttrStep(json_data["attr"], json_data["value"]) diff --git a/prosemirror/transform/structure.py b/prosemirror/transform/structure.py index 1248f41..3307da6 100644 --- a/prosemirror/transform/structure.py +++ b/prosemirror/transform/structure.py @@ -112,13 +112,13 @@ def can_split( doc: Node, pos: int, depth: int | None = None, - types_after: list[dict[str, NodeType]] | None = None, + types_after: list[NodeTypeWithAttrs] | None = None, ) -> bool: if depth is None: depth = 1 pos_ = doc.resolve(pos) base = pos_.depth - depth - inner_type: dict[str, NodeType] | Node | None = None + inner_type: NodeTypeWithAttrs | Node | None = None if types_after: inner_type = types_after[-1] @@ -163,7 +163,7 @@ def can_split( rest = rest.replace_child( 0, override_child["type"].create(override_child.get("attrs")) ) - after: dict[str, NodeType] | Node | None = None + after: NodeTypeWithAttrs | Node | None = None if types_after and len(types_after) > i: after = types_after[i] if not after: diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index f2e5344..89ac2be 100644 --- a/prosemirror/transform/transform.py +++ 
b/prosemirror/transform/transform.py @@ -11,6 +11,7 @@ NodeType, Slice, ) +from prosemirror.model.node import TextNode from prosemirror.transform import ( AddMarkStep, AddNodeMarkStep, @@ -28,7 +29,7 @@ structure, ) from prosemirror.transform.replace import replace_step -from prosemirror.utils import Attrs +from prosemirror.utils import JSON, Attrs from .doc_attr_step import DocAttrStep @@ -219,6 +220,7 @@ def clear_incompatible( if not parent_type.allows_mark_type(child.marks[j].type): self.step(RemoveMarkStep(cur, end, child.marks[j])) if child.is_text and not parent_type.spec.get("code"): + assert isinstance(child, TextNode) newline = re.compile(r"\r?\n|\r") slice = None m = newline.search(child.text) @@ -315,9 +317,10 @@ def replace_range(self, from_: int, to: int, slice: Slice) -> "Transform": i = 0 while True: node = content.first_child + assert node is not None left_nodes.append(node) - if i == slice.open_start or node is None: + if i == slice.open_start: break content = node.content @@ -571,7 +574,7 @@ def set_node_markup( def set_node_attribute(self, pos: int, attr: str, value: str | int) -> "Transform": return self.step(AttrStep(pos, attr, value)) - def set_doc_attribute(self, attr: str, value): + def set_doc_attribute(self, attr: str, value: JSON) -> "Transform": return self.step(DocAttrStep(attr, value)) def add_node_mark(self, pos: int, mark: Mark) -> "Transform": From 81f82325c44fb5da2b3a76ce554009c33f774c5b Mon Sep 17 00:00:00 2001 From: Samuel Cormier-Iijima Date: Wed, 15 Nov 2023 10:38:04 -0500 Subject: [PATCH 32/40] Fix tests + typing in transform replace_range --- prosemirror/transform/transform.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index 89ac2be..eab3afd 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -317,10 +317,9 @@ def replace_range(self, from_: int, to: int, slice: Slice) -> "Transform": i = 0 while True: node = content.first_child - assert node is not None left_nodes.append(node) - if i == slice.open_start: + if i == slice.open_start or node is None: break content = node.content @@ -329,6 +328,7 @@ def replace_range(self, from_: int, to: int, slice: Slice) -> "Transform": d = preferred_depth - 1 while d >= 0: left_node = left_nodes[d] + assert left_node is not None def_ = defines_content(left_node.type) if def_ and not left_node.same_markup( from__.node(abs(preferred_target) - 1) @@ -340,9 +340,8 @@ def replace_range(self, from_: int, to: int, slice: Slice) -> "Transform": for j in range(slice.open_start, -1, -1): open_depth = (j + preferred_depth + 1) % (slice.open_start + 1) - if len(left_nodes) > open_depth: - insert = left_nodes[open_depth] - else: + insert = left_nodes[open_depth] if open_depth < len(left_nodes) else None + if insert is None: continue for i in range(len(target_depths)): target_depth = target_depths[ From 15397832ab26885473df92b242699d82e84ce18b Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Wed, 15 Nov 2023 15:15:06 -0500 Subject: [PATCH 33/40] Updating some type annotations to match the original repos. 
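Fragment.from_ now takes any Sequence[Node] (or None), replace_with/insert
also accept a ready-made Fragment, and set_node_markup allows None for
type/attrs plus an optional marks list, matching the upstream annotations.
A rough usage sketch, assuming the bundled test schema and not part of the
patch itself:

    from prosemirror.model import Fragment, Node
    from prosemirror.test_builder import test_schema as schema

    doc = Node.from_json(schema, {
        "type": "doc",
        "content": [{"type": "paragraph",
                     "content": [{"type": "text", "text": "ab"}]}],
    })

    assert Fragment.from_(None) is Fragment.empty  # None -> empty fragment
    frag = Fragment.from_([doc.child(0)])          # any Sequence[Node] works
    same = Fragment.from_(frag)                    # a Fragment passes through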
--- prosemirror/model/fragment.py | 3 ++- prosemirror/transform/transform.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index a1be381..bdbf87d 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -5,6 +5,7 @@ ClassVar, Iterable, cast, + Sequence, ) from prosemirror.utils import JSONList, text_length @@ -278,7 +279,7 @@ def from_array(cls, array: list["Node"]) -> "Fragment": return cls(joined or array, size) @classmethod - def from_(cls, nodes: "Fragment | Node | list[Node] | None") -> "Fragment": + def from_(cls, nodes: "Fragment | Node | Sequence[Node] | None") -> "Fragment": if not nodes: return cls.empty if isinstance(nodes, Fragment): diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index eab3afd..2781ea4 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -268,7 +268,7 @@ def replace_with( self, from_: int, to: int, - content: list[Node] | Node, + content: list[Node] | Node | Fragment, ) -> "Transform": return self.replace(from_, to, Slice(Fragment.from_(content), 0, 0)) @@ -278,7 +278,7 @@ def delete(self, from_: int, to: int) -> "Transform": def insert( self, pos: int, - content: list[Node] | Node, + content: list[Node] | Node | Fragment, ) -> "Transform": return self.replace_with(pos, pos, content) @@ -544,9 +544,9 @@ def iteratee( def set_node_markup( self, pos: int, - type: NodeType, - attrs: Attrs, - marks: None = None, + type: NodeType | None, + attrs: Attrs | None, + marks: list[Mark] | None = None, ) -> "Transform": node = self.doc.node_at(pos) if not node: From caa206d5f4b1ffeafaf52ae865c4fbfa97c0e2a6 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Thu, 16 Nov 2023 09:03:12 -0500 Subject: [PATCH 34/40] Reverting the type annotation syntax to be valid for Python 3.9. 
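The X | Y union syntax used so far is PEP 604 and needs Python 3.10; on 3.9
the expression raises TypeError wherever it is actually evaluated at runtime
(module-level aliases, cast() targets, annotations without the __future__
import). Switch back to typing.Optional/typing.Union so the annotations stay
valid on 3.9. Quick illustration, not taken from the code base:

    from typing import Optional, Union

    MaybeInt = Optional[int]      # fine on Python 3.9
    IntOrStr = Union[int, str]    # fine on Python 3.9

    # MaybeInt = int | None       # TypeError on 3.9: unsupported operand
    #                             # type(s) for |: 'type' and 'NoneType'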
--- prosemirror/model/content.py | 48 +++++---- prosemirror/model/diff.py | 8 +- prosemirror/model/fragment.py | 44 +++++---- prosemirror/model/from_dom.py | 130 +++++++++++++------------ prosemirror/model/mark.py | 8 +- prosemirror/model/node.py | 52 +++++----- prosemirror/model/replace.py | 16 +-- prosemirror/model/resolvedpos.py | 28 +++--- prosemirror/model/schema.py | 58 +++++------ prosemirror/model/to_dom.py | 19 ++-- prosemirror/transform/attr_step.py | 8 +- prosemirror/transform/doc_attr_step.py | 6 +- prosemirror/transform/map.py | 26 ++--- prosemirror/transform/mark_step.py | 40 ++++---- prosemirror/transform/replace.py | 42 ++++---- prosemirror/transform/replace_step.py | 22 +++-- prosemirror/transform/step.py | 10 +- prosemirror/transform/structure.py | 38 ++++---- prosemirror/transform/transform.py | 50 +++++----- prosemirror/utils.py | 4 +- 20 files changed, 350 insertions(+), 307 deletions(-) diff --git a/prosemirror/model/content.py b/prosemirror/model/content.py index 93dc67b..4b3a7df 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -6,7 +6,9 @@ Literal, NamedTuple, NoReturn, + Optional, TypedDict, + Union, cast, ) @@ -28,17 +30,19 @@ def __init__(self, type: "NodeType", next: "ContentMatch") -> None: class WrapCacheEntry: target: "NodeType" - computed: list["NodeType"] | None + computed: Optional[list["NodeType"]] - def __init__(self, target: "NodeType", computed: list["NodeType"] | None) -> None: + def __init__( + self, target: "NodeType", computed: Optional[list["NodeType"]] + ) -> None: self.target = target self.computed = computed class Active(TypedDict): match: "ContentMatch" - type: "NodeType | None" - via: "Active | None" + type: Optional["NodeType"] + via: Optional["Active"] class ContentMatch: @@ -71,18 +75,18 @@ def parse(cls, string: str, node_types: dict[str, "NodeType"]) -> "ContentMatch" check_for_dead_ends(match, stream) return match - def match_type(self, type: "NodeType") -> "ContentMatch | None": + def match_type(self, type: "NodeType") -> Optional["ContentMatch"]: for next in self.next: if next.type.name == type.name: return next.next return None def match_fragment( - self, frag: Fragment, start: int = 0, end: int | None = None - ) -> "ContentMatch | None": + self, frag: Fragment, start: int = 0, end: Optional[int] = None + ) -> Optional["ContentMatch"]: if end is None: end = frag.child_count - cur: "ContentMatch | None" = self + cur: Optional["ContentMatch"] = self i = start while cur and i < end: cur = cur.match_type(frag.child(i).type) @@ -94,7 +98,7 @@ def inline_content(self) -> bool: return bool(self.next) and self.next[0].type.is_inline @property - def default_type(self) -> "NodeType | None": + def default_type(self) -> Optional["NodeType"]: for next in self.next: type = next.type if not (type.is_text or type.has_required_attrs()): @@ -110,10 +114,10 @@ def compatible(self, other: "ContentMatch") -> bool: def fill_before( self, after: Fragment, to_end: bool = False, start_index: int = 0 - ) -> Fragment | None: + ) -> Optional[Fragment]: seen = [self] - def search(match: ContentMatch, types: list["NodeType"]) -> Fragment | None: + def search(match: ContentMatch, types: list["NodeType"]) -> Optional[Fragment]: nonlocal seen finished = match.match_fragment(after, start_index) if finished and (not to_end or finished.valid_end): @@ -132,7 +136,7 @@ def search(match: ContentMatch, types: list["NodeType"]) -> Fragment | None: return search(self, []) - def find_wrapping(self, target: "NodeType") -> list["NodeType"] | None: 
+ def find_wrapping(self, target: "NodeType") -> Optional[list["NodeType"]]: for entry in self.wrap_cache: if entry.target.name == target.name: return entry.computed @@ -140,7 +144,7 @@ def find_wrapping(self, target: "NodeType") -> list["NodeType"] | None: self.wrap_cache.append(WrapCacheEntry(target, computed)) return computed - def compute_wrapping(self, target: "NodeType") -> list["NodeType"] | None: + def compute_wrapping(self, target: "NodeType") -> Optional[list["NodeType"]]: seen = {} active: list[Active] = [{"match": self, "type": None, "via": None}] while len(active): @@ -213,7 +217,7 @@ def iteratee(m: "ContentMatch", i: int) -> str: class TokenStream: - inline: bool | None + inline: Optional[bool] tokens: list[str] def __init__(self, string: str, node_types: dict[str, "NodeType"]) -> None: @@ -223,13 +227,13 @@ def __init__(self, string: str, node_types: dict[str, "NodeType"]) -> None: self.pos = 0 self.tokens = [i for i in TOKEN_REGEX.findall(string) if i.strip()] - def next(self) -> str | None: + def next(self) -> Optional[str]: try: return self.tokens[self.pos] except IndexError: return None - def eat(self, tok: str) -> int | bool: + def eat(self, tok: str) -> Union[int, bool]: if self.next() == tok: pos = self.pos self.pos += 1 @@ -278,7 +282,7 @@ class NameExpr(TypedDict): value: "NodeType" -Expr = ChoiceExpr | SeqExpr | PlusExpr | StarExpr | OptExpr | RangeExpr | NameExpr +Expr = Union[ChoiceExpr, SeqExpr, PlusExpr, StarExpr, OptExpr, RangeExpr, NameExpr] def parse_expr(stream: TokenStream) -> Expr: @@ -390,8 +394,8 @@ def iteratee(type: "NodeType") -> Expr: class Edge(TypedDict): - term: "NodeType | None" - to: int | None + term: Optional["NodeType"] + to: Optional[int] def nfa( @@ -404,7 +408,9 @@ def node() -> int: nfa_.append([]) return len(nfa_) - 1 - def edge(from_: int, to: int | None = None, term: "NodeType | None" = None) -> Edge: + def edge( + from_: int, to: Optional[int] = None, term: Optional["NodeType"] = None + ) -> Edge: nonlocal nfa_ edge: Edge = {"term": term, "to": to} nfa_[from_].append(edge) @@ -507,7 +513,7 @@ def explore(states: list[int]) -> ContentMatch: term, to = item.get("term"), item.get("to") if not term: continue - set: list[int] | None = None + set: Optional[list[int]] = None for t in out: if t[0] == term: set = t[1] diff --git a/prosemirror/model/diff.py b/prosemirror/model/diff.py index 098988f..e4f7c1a 100644 --- a/prosemirror/model/diff.py +++ b/prosemirror/model/diff.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, Optional, TypedDict from prosemirror.utils import text_length @@ -13,7 +13,7 @@ class Diff(TypedDict): b: int -def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> int | None: +def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> Optional[int]: i = 0 while True: if a.child_count == i or b.child_count == i: @@ -52,7 +52,9 @@ def find_diff_start(a: "Fragment", b: "Fragment", pos: int) -> int | None: i += 1 -def find_diff_end(a: "Fragment", b: "Fragment", pos_a: int, pos_b: int) -> Diff | None: +def find_diff_end( + a: "Fragment", b: "Fragment", pos_a: int, pos_b: int +) -> Optional[Diff]: i_a, i_b = a.child_count, b.child_count while True: if i_a == 0 or i_b == 0: diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index bdbf87d..43e8f21 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -4,8 +4,10 @@ Callable, ClassVar, Iterable, + Optional, + Sequence, + Union, cast, - Sequence, ) from prosemirror.utils 
import JSONList, text_length @@ -26,7 +28,7 @@ class Fragment: content: list["Node"] size: int - def __init__(self, content: list["Node"], size: int | None = None) -> None: + def __init__(self, content: list["Node"], size: Optional[int] = None) -> None: self.content = content self.size = size if size is not None else sum(c.node_size for c in content) @@ -34,9 +36,9 @@ def nodes_between( self, from_: int, to: int, - f: Callable[["Node", int, "Node | None", int], bool | None], + f: Callable[["Node", int, Optional["Node"], int], Optional[bool]], node_start: int = 0, - parent: "Node | None" = None, + parent: Optional["Node"] = None, ) -> None: i = 0 pos = 0 @@ -59,7 +61,7 @@ def nodes_between( i += 1 def descendants( - self, f: Callable[["Node", int, "Node | None", int], bool | None] + self, f: Callable[["Node", int, Optional["Node"], int], Optional[bool]] ) -> None: self.nodes_between(0, self.size, f) @@ -68,12 +70,14 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Callable[["Node"], str] | str = "", + leaf_text: Union[Callable[["Node"], str], str] = "", ) -> str: text = [] separated = True - def iteratee(node: "Node", pos: int, _parent: "Node | None", _to: int) -> None: + def iteratee( + node: "Node", pos: int, _parent: Optional["Node"], _to: int + ) -> None: nonlocal text nonlocal separated if node.is_text: @@ -114,7 +118,7 @@ def append(self, other: "Fragment") -> "Fragment": i += 1 return Fragment(content, self.size + other.size) - def cut(self, from_: int, to: int | None = None) -> "Fragment": + def cut(self, from_: int, to: Optional[int] = None) -> "Fragment": if to is None: to = self.size if from_ == 0 and to == self.size: @@ -144,7 +148,7 @@ def cut(self, from_: int, to: int | None = None) -> "Fragment": i += 1 return Fragment(result, size) - def cut_by_index(self, from_: int, to: int | None = None) -> "Fragment": + def cut_by_index(self, from_: int, to: Optional[int] = None) -> "Fragment": if from_ == to: return Fragment.empty if from_ == 0 and to == len(self.content): @@ -172,11 +176,11 @@ def eq(self, other: "Fragment") -> bool: return all(a.eq(b) for (a, b) in zip(self.content, other.content)) @property - def first_child(self) -> "Node | None": + def first_child(self) -> Optional["Node"]: return self.content[0] if self.content else None @property - def last_child(self) -> "Node | None": + def last_child(self) -> Optional["Node"]: return self.content[-1] if self.content else None @property @@ -186,7 +190,7 @@ def child_count(self) -> int: def child(self, index: int) -> "Node": return self.content[index] - def maybe_child(self, index: int) -> "Node | None": + def maybe_child(self, index: int) -> Optional["Node"]: try: return self.content[index] except IndexError: @@ -201,7 +205,7 @@ def for_each(self, f: Callable[["Node", int, int], Any]) -> None: p += child.node_size i += 1 - def find_diff_start(self, other: "Fragment", pos: int = 0) -> int | None: + def find_diff_start(self, other: "Fragment", pos: int = 0) -> Optional[int]: from .diff import find_diff_start return find_diff_start(self, other, pos) @@ -209,9 +213,9 @@ def find_diff_start(self, other: "Fragment", pos: int = 0) -> int | None: def find_diff_end( self, other: "Fragment", - pos: int | None = None, - other_pos: int | None = None, - ) -> "Diff | None": + pos: Optional[int] = None, + other_pos: Optional[int] = None, + ) -> Optional["Diff"]: from .diff import find_diff_end if pos is None: @@ -239,7 +243,7 @@ def find_index(self, pos: int, round: int = -1) -> dict[str, int]: i += 1 cur_pos 
= end - def to_json(self) -> JSONList | None: + def to_json(self) -> Optional[JSONList]: if self.content: return [item.to_json() for item in self.content] return None @@ -263,7 +267,7 @@ def from_json(cls, schema: "Schema[Any, Any]", value: Any) -> "Fragment": def from_array(cls, array: list["Node"]) -> "Fragment": if not array: return cls.empty - joined: list["Node"] | None = None + joined: Optional[list["Node"]] = None size = 0 for i in range(len(array)): node = array[i] @@ -279,7 +283,9 @@ def from_array(cls, array: list["Node"]) -> "Fragment": return cls(joined or array, size) @classmethod - def from_(cls, nodes: "Fragment | Node | Sequence[Node] | None") -> "Fragment": + def from_( + cls, nodes: Union["Fragment", "Node", Sequence["Node"], None] + ) -> "Fragment": if not nodes: return cls.empty if isinstance(nodes, Fragment): diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index b6b1ac5..86a22d0 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -1,7 +1,7 @@ import itertools import re from dataclasses import dataclass -from typing import Any, Callable, Literal, cast +from typing import Any, Callable, Literal, Optional, Union, cast import lxml from lxml.cssselect import CSSSelector @@ -17,47 +17,47 @@ from .resolvedpos import ResolvedPos from .schema import MarkType, NodeType, Schema -WSType = bool | Literal["full"] | None +WSType = Union[bool, Literal["full"], None] @dataclass class DOMPosition: node: DOMNode offset: int - pos: int | None = None + pos: Optional[int] = None @dataclass(frozen=True) class ParseOptions: preserve_whitespace: WSType = None - find_positions: list[DOMPosition] | None = None - from_: int | None = None - to_: int | None = None - top_node: Node | None = None - top_match: ContentMatch | None = None - context: ResolvedPos | None = None - rule_from_node: Callable[[DOMNode], "ParseRule"] | None = None - top_open: bool | None = None + find_positions: Optional[list[DOMPosition]] = None + from_: Optional[int] = None + to_: Optional[int] = None + top_node: Optional[Node] = None + top_match: Optional[ContentMatch] = None + context: Optional[ResolvedPos] = None + rule_from_node: Optional[Callable[[DOMNode], "ParseRule"]] = None + top_open: Optional[bool] = None @dataclass class ParseRule: - tag: str | None - namespace: str | None - style: str | None - priority: int | None - consuming: bool | None - context: str | None - node: str | None - mark: str | None - clear_mark: Callable[[Mark], bool] | None - ignore: bool | None - close_parent: bool | None - skip: bool | None - attrs: Attrs | None - get_attrs: Callable[[DOMNode], None | Attrs | Literal[False]] | None - content_element: str | DOMNode | Callable[[DOMNode], DOMNode] | None - get_content: Callable[[DOMNode, Schema[Any, Any]], Fragment] | None + tag: Optional[str] + namespace: Optional[str] + style: Optional[str] + priority: Optional[int] + consuming: Optional[bool] + context: Optional[str] + node: Optional[str] + mark: Optional[str] + clear_mark: Optional[Callable[[Mark], bool]] + ignore: Optional[bool] + close_parent: Optional[bool] + skip: Optional[bool] + attrs: Optional[Attrs] + get_attrs: Optional[Callable[[DOMNode], Union[Attrs, Literal[False], None]]] + content_element: Union[str, DOMNode, Callable[[DOMNode], DOMNode], None] + get_content: Optional[Callable[[DOMNode, Schema[Any, Any]], Fragment]] preserve_whitespace: WSType @classmethod @@ -109,7 +109,7 @@ def __init__(self, schema: Schema[Any, Any], rules: list[ParseRule]) -> None: ) def parse( - self, 
dom_: lxml.html.HtmlElement, options: ParseOptions | None = None + self, dom_: lxml.html.HtmlElement, options: Optional[ParseOptions] = None ) -> Node: if options is None: options = ParseOptions() @@ -134,7 +134,9 @@ def parse( return cast(Node, context.finish()) - def parse_slice(self, dom_: DOMNode, options: ParseOptions | None = None) -> Slice: + def parse_slice( + self, dom_: DOMNode, options: Optional[ParseOptions] = None + ) -> Slice: if options is None: options = ParseOptions(preserve_whitespace=True) @@ -145,8 +147,8 @@ def parse_slice(self, dom_: DOMNode, options: ParseOptions | None = None) -> Sli return Slice.max_open(cast(Fragment, context.finish())) def match_tag( - self, dom_: DOMNode, context: "ParseContext", after: ParseRule | None = None - ) -> ParseRule | None: + self, dom_: DOMNode, context: "ParseContext", after: Optional[ParseRule] = None + ) -> Optional[ParseRule]: try: i = self._tags.index(after) + 1 if after is not None else 0 except ValueError: @@ -177,8 +179,8 @@ def match_style( prop: str, value: str, context: "ParseContext", - after: ParseRule | None = None, - ) -> ParseRule | None: + after: Optional[ParseRule] = None, + ) -> Optional[ParseRule]: i = self._styles.index(after) + 1 if after is not None else 0 for rule in self._styles[i:]: @@ -313,7 +315,7 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": def ws_options_for( - _type: NodeType | None, preserve_whitespace: WSType, base: int + _type: Optional[NodeType], preserve_whitespace: WSType, base: int ) -> int: if preserve_whitespace is not None: return (OPT_PRESERVE_WS if preserve_whitespace else 0) | ( @@ -328,16 +330,16 @@ def ws_options_for( class NodeContext: - match: ContentMatch | None + match: Optional[ContentMatch] content: list[Node] active_marks: list[Mark] stash_marks: list[Mark] - type: NodeType | None + type: Optional[NodeType] options: int - attrs: Attrs | None + attrs: Optional[Attrs] marks: list[Mark] pending_marks: list[Mark] @@ -345,12 +347,12 @@ class NodeContext: def __init__( self, - _type: NodeType | None, - attrs: Attrs | None, + _type: Optional[NodeType], + attrs: Optional[Attrs], marks: list[Mark], pending_marks: list[Mark], solid: bool, - match: ContentMatch | None, + match: Optional[ContentMatch], options: int, ) -> None: self.type = _type @@ -374,7 +376,7 @@ def __init__( self.active_marks = Mark.none self.stash_marks = [] - def find_wrapping(self, node: Node) -> list[NodeType] | None: + def find_wrapping(self, node: Node) -> Optional[list[NodeType]]: if not self.match: if not self.type: return [] @@ -398,10 +400,10 @@ def find_wrapping(self, node: Node) -> list[NodeType] | None: return self.match.find_wrapping(node.type) - def finish(self, open_end: bool) -> Node | Fragment: + def finish(self, open_end: bool) -> Union[Node, Fragment]: if not self.options & OPT_PRESERVE_WS: try: - last: Node | None = self.content[-1] + last: Optional[Node] = self.content[-1] except IndexError: last = None @@ -425,8 +427,8 @@ def finish(self, open_end: bool) -> Node | Fragment: self.type.create(self.attrs, content, self.marks) if self.type else content ) - def pop_from_stash_mark(self, mark: Mark) -> Mark | None: - found_mark: Mark | None = None + def pop_from_stash_mark(self, mark: Mark) -> Optional[Mark]: + found_mark: Optional[Mark] = None for stash_mark in self.stash_marks[::-1]: if mark.eq(stash_mark): found_mark = stash_mark @@ -457,7 +459,7 @@ def inline_context(self, node: DOMNode) -> bool: class ParseContext: open: int = 0 - find: list[DOMPosition] | None + find: 
Optional[list[DOMPosition]] needs_block: bool nodes: list[NodeContext] options: ParseOptions @@ -584,7 +586,9 @@ def add_text_node(self, dom_: DOMNode) -> None: else: self.find_inside(dom_) - def add_element(self, dom_: DOMNode, match_after: ParseRule | None = None) -> None: + def add_element( + self, dom_: DOMNode, match_after: Optional[ParseRule] = None + ) -> None: name = dom_.tag.lower() if name in LIST_TAGS and self.parser.normalize_lists: @@ -647,12 +651,12 @@ def ignore_fallback(self, dom_: DOMNode) -> None: ): self.find_place(self.parser.schema.text("-")) - def read_styles(self, styles: list[str]) -> tuple[list[Mark], list[Mark]] | None: + def read_styles(self, styles: list[str]) -> Optional[tuple[list[Mark], list[Mark]]]: add: list[Mark] = Mark.none remove: list[Mark] = Mark.none for i in range(0, len(styles), 2): - after: ParseRule | None = None + after: Optional[ParseRule] = None while True: rule = self.parser.match_style(styles[i], styles[i + 1], self, after) if not rule: @@ -678,11 +682,11 @@ def read_styles(self, styles: list[str]) -> tuple[list[Mark], list[Mark]] | None return add, remove def add_element_by_rule( - self, dom_: DOMNode, rule: ParseRule, continue_after: ParseRule | None = None + self, dom_: DOMNode, rule: ParseRule, continue_after: Optional[ParseRule] = None ) -> None: sync: bool = False - mark: Mark | None = None - node_type: NodeType | None = None + mark: Optional[Mark] = None + node_type: Optional[NodeType] = None if rule.node is not None: node_type = self.parser.schema.nodes[rule.node] @@ -728,8 +732,8 @@ def add_element_by_rule( def add_all( self, parent: DOMNode, - start_index: int | None = None, - end_index: int | None = None, + start_index: Optional[int] = None, + end_index: Optional[int] = None, ) -> None: index = start_index if start_index is not None else 0 @@ -750,8 +754,8 @@ def add_all( self.find_at_point(parent, index) def find_place(self, node: Node) -> bool: - route: list[NodeType] | None = None - sync: NodeContext | None = None + route: Optional[list[NodeType]] = None + sync: Optional[NodeContext] = None depth = self.open while depth >= 0: @@ -807,7 +811,7 @@ def insert_node(self, node: Node) -> bool: return False def enter( - self, type_: NodeType, attrs: Attrs | None = None, preserve_ws: WSType = None + self, type_: NodeType, attrs: Optional[Attrs] = None, preserve_ws: WSType = None ) -> bool: ok = self.find_place(type_.create(attrs)) if ok: @@ -818,7 +822,7 @@ def enter( def enter_inner( self, type_: NodeType, - attrs: Attrs | None = None, + attrs: Optional[Attrs] = None, solid: bool = False, preserve_ws: WSType = None, ) -> None: @@ -855,7 +859,7 @@ def close_extra(self, open_end: bool = False) -> None: self.nodes = self.nodes[: self.open + 1] - def finish(self) -> Node | Fragment: + def finish(self) -> Union[Node, Fragment]: self.open = 0 self.close_extra(self.is_open) return self.nodes[0].finish(self.is_open or bool(self.options.top_open)) @@ -950,7 +954,7 @@ def match(i: int, depth: int) -> bool: return False else: if depth > 0 or (depth == 0 and use_root): - next: NodeType | None = self.nodes[depth].type + next: Optional[NodeType] = self.nodes[depth].type elif option is not None and depth >= min_depth: next = option.node(depth - min_depth).type else: @@ -971,7 +975,7 @@ def match(i: int, depth: int) -> bool: return match(len(parts) - 1, self.open) - def textblock_from_context(self) -> NodeType | None: + def textblock_from_context(self) -> Optional[NodeType]: context = self.options.context if context: @@ -1033,7 +1037,7 @@ def 
remove_pending_mark(self, mark: Mark, upto: NodeContext) -> None: def normalize_list(dom_: DOMNode) -> None: child = next(iter(dom_)) - prev_item: DOMNode | None = None + prev_item: Optional[DOMNode] = None while child is not None: name = child.tag.lower() if get_node_type(child) == 1 else None @@ -1097,7 +1101,7 @@ def scan(match: ContentMatch) -> bool: return False -def find_same_mark_in_set(mark: Mark, mark_set: list[Mark]) -> Mark | None: +def find_same_mark_in_set(mark: Mark, mark_set: list[Mark]) -> Optional[Mark]: for comp in mark_set: if mark.eq(comp): return comp diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index babf6a2..5455197 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -1,5 +1,5 @@ import copy -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, Final, Optional, Union, cast from prosemirror.utils import Attrs, JSONDict @@ -15,7 +15,7 @@ def __init__(self, type: "MarkType", attrs: Attrs) -> None: self.attrs = attrs def add_to_set(self, set: list["Mark"]) -> list["Mark"]: - copy: list["Mark"] | None | None = None + copy: Optional[list["Mark"]] = None placed = False for i in range(len(set)): other = set[i] @@ -66,7 +66,7 @@ def from_json( type = schema.marks.get(name) if not type: raise ValueError(f"There is no mark type {name} in this schema") - return type.create(cast(JSONDict | None, json_data.get("attrs"))) + return type.create(cast(Optional[JSONDict], json_data.get("attrs"))) @classmethod def same_set(cls, a: list["Mark"], b: list["Mark"]) -> bool: @@ -77,7 +77,7 @@ def same_set(cls, a: list["Mark"], b: list["Mark"]) -> bool: return all(item_a.eq(item_b) for (item_a, item_b) in zip(a, b)) @classmethod - def set_from(cls, marks: "list[Mark] | Mark | None") -> list["Mark"]: + def set_from(cls, marks: Union[list["Mark"], "Mark", None]) -> list["Mark"]: if not marks: return cls.none if isinstance(marks, Mark): diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index 6e9703d..8635c03 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,5 +1,5 @@ import copy -from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union, cast from typing_extensions import TypeGuard @@ -20,7 +20,7 @@ class ChildInfo(TypedDict): - node: "Node | None" + node: Optional["Node"] index: int offset: int @@ -30,7 +30,7 @@ def __init__( self, type: "NodeType", attrs: "Attrs", - content: Fragment | None, + content: Optional[Fragment], marks: list[Mark], ) -> None: self.type = type @@ -49,7 +49,7 @@ def child_count(self) -> int: def child(self, index: int) -> "Node": return self.content.child(index) - def maybe_child(self, index: int) -> "Node | None": + def maybe_child(self, index: int) -> Optional["Node"]: return self.content.maybe_child(index) def for_each(self, f: Callable[["Node", int, int], None]) -> None: @@ -59,13 +59,13 @@ def nodes_between( self, from_: int, to: int, - f: Callable[["Node", int, "Node | None", int], bool | None], + f: Callable[["Node", int, Optional["Node"], int], Optional[bool]], start_pos: int = 0, ) -> None: self.content.nodes_between(from_, to, f, start_pos, self) def descendants( - self, f: Callable[["Node", int, "Node | None", int], bool | None] + self, f: Callable[["Node", int, Optional["Node"], int], Optional[bool]] ) -> None: self.nodes_between(0, self.content.size, f) @@ -80,16 +80,16 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: 
Callable[["Node"], str] | str = "", + leaf_text: Union[Callable[["Node"], str], str] = "", ) -> str: return self.content.text_between(from_, to, block_separator, leaf_text) @property - def first_child(self) -> "Node | None": + def first_child(self) -> Optional["Node"]: return self.content.first_child @property - def last_child(self) -> "Node | None": + def last_child(self) -> Optional["Node"]: return self.content.last_child def eq(self, other: "Node") -> bool: @@ -103,8 +103,8 @@ def same_markup(self, other: "Node") -> bool: def has_markup( self, type: "NodeType", - attrs: "Attrs | None" = None, - marks: list[Mark] | None = None, + attrs: Optional["Attrs"] = None, + marks: Optional[list[Mark]] = None, ) -> bool: return ( self.type.name == type.name @@ -112,7 +112,7 @@ def has_markup( and (Mark.same_set(self.marks, marks or Mark.none)) ) - def copy(self, content: Fragment | None = None) -> "Node": + def copy(self, content: Optional[Fragment] = None) -> "Node": if content == self.content: return self return self.__class__(self.type, self.attrs, content, self.marks) @@ -122,13 +122,13 @@ def mark(self, marks: list[Mark]) -> "Node": return self return self.__class__(self.type, self.attrs, self.content, marks) - def cut(self, from_: int, to: int | None = None) -> "Node": + def cut(self, from_: int, to: Optional[int] = None) -> "Node": if from_ == 0 and to == self.content.size: return self return self.copy(self.content.cut(from_, to)) def slice( - self, from_: int, to: int | None = None, include_parents: bool = False + self, from_: int, to: Optional[int] = None, include_parents: bool = False ) -> Slice: if to is None: to = self.content.size @@ -145,7 +145,7 @@ def slice( def replace(self, from_: int, to: int, slice: Slice) -> "Node": return replace(self.resolve(from_), self.resolve(to), slice) - def node_at(self, pos: int) -> "Node | None": + def node_at(self, pos: int) -> Optional["Node"]: node = self while True: index_info = node.content.find_index(pos) @@ -183,12 +183,14 @@ def resolve(self, pos: int) -> ResolvedPos: def resolve_no_cache(self, pos: int) -> ResolvedPos: return ResolvedPos.resolve(self, pos) - def range_has_mark(self, from_: int, to: int, type: "Mark | MarkType") -> bool: + def range_has_mark( + self, from_: int, to: int, type: Union["Mark", "MarkType"] + ) -> bool: found = False if to > from_: def iteratee( - node: "Node", pos: int, parent: "Node | None", index: int + node: "Node", pos: int, parent: Optional["Node"], index: int ) -> bool: nonlocal found if type.is_in_set(node.marks): @@ -254,12 +256,12 @@ def can_replace( to: int, replacement: Fragment = Fragment.empty, start: int = 0, - end: int | None = None, + end: Optional[int] = None, ) -> bool: if end is None: end = replacement.child_count one = self.content_match_at(from_).match_fragment(replacement, start, end) - two: "ContentMatch | None" = None + two: Optional["ContentMatch"] = None if one: two = one.match_fragment(self.content, to) if not two or not two.valid_end: @@ -270,12 +272,12 @@ def can_replace( return True def can_replace_with( - self, from_: int, to: int, type: "NodeType", marks: list[Mark] | None = None + self, from_: int, to: int, type: "NodeType", marks: Optional[list[Mark]] = None ) -> bool: if marks and not self.type.allows_marks(marks): return False start = self.content_match_at(from_).match_type(type) - end: "ContentMatch | None" = None + end: Optional["ContentMatch"] = None if start: end = start.match_fragment(self.content, to) return end.valid_end if end else False @@ -325,7 +327,9 @@ def 
to_json(self) -> JSONDict: return obj @classmethod - def from_json(cls, schema: "Schema[Any, Any]", json_data: JSONDict | str) -> "Node": + def from_json( + cls, schema: "Schema[Any, Any]", json_data: Union[JSONDict, str] + ) -> "Node": if isinstance(json_data, str): import json @@ -380,7 +384,7 @@ def text_between( from_: int, to: int, block_separator: str = "", - leaf_text: Callable[["Node"], str] | str = "", + leaf_text: Union[Callable[["Node"], str], str] = "", ) -> str: return self.text[from_:to] @@ -400,7 +404,7 @@ def with_text(self, text: str) -> "TextNode": return self return TextNode(self.type, self.attrs, text, self.marks) - def cut(self, from_: int = 0, to: int | None = None) -> "TextNode": + def cut(self, from_: int = 0, to: Optional[int] = None) -> "TextNode": if to is None: to = text_length(self.text) if from_ == 0 and to == text_length(self.text): diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index 7aa4d54..70445c3 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast from prosemirror.utils import JSONDict @@ -34,8 +34,8 @@ def remove_range(content: Fragment, from_: int, to: int) -> Fragment: def insert_into( - content: Fragment, dist: int, insert: Fragment, parent: "Node | None" -) -> Fragment | None: + content: Fragment, dist: int, insert: Fragment, parent: Optional["Node"] +) -> Optional[Fragment]: a = content.find_index(dist) index, offset = a["index"], a["offset"] child = content.maybe_child(index) @@ -62,7 +62,7 @@ def __init__(self, content: Fragment, open_start: int, open_end: int) -> None: def size(self) -> int: return self.content.size - self.open_start - self.open_end - def insert_at(self, pos: int, fragment: Fragment) -> "Slice | None": + def insert_at(self, pos: int, fragment: Fragment) -> Optional["Slice"]: content = insert_into(self.content, pos + self.open_start, fragment, None) if content: return Slice(content, self.open_start, self.open_end) @@ -85,7 +85,7 @@ def eq(self, other: "Slice") -> bool: def __str__(self) -> str: return f"{self.content}({self.open_start},{self.open_end})" - def to_json(self) -> JSONDict | None: + def to_json(self) -> Optional[JSONDict]: if not self.content.size: return None json: JSONDict = {"content": self.content.to_json()} @@ -105,7 +105,7 @@ def to_json(self) -> JSONDict | None: def from_json( cls, schema: "Schema[Any, Any]", - json_data: JSONDict | None, + json_data: Optional[JSONDict], ) -> "Slice": if not json_data: return cls.empty @@ -195,8 +195,8 @@ def add_node(child: "Node", target: list["Node"]) -> None: def add_range( - start: "ResolvedPos | None", - end: "ResolvedPos | None", + start: Optional["ResolvedPos"], + end: Optional["ResolvedPos"], depth: int, target: list["Node"], ) -> None: diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index b3cc4ff..c0ddd4d 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, cast +from typing import TYPE_CHECKING, Callable, Optional, cast from .mark import Mark @@ -13,7 +13,7 @@ def __init__(self, pos: int, path: list["Node | int"], parent_offset: int) -> No self.depth = int(len(path) / 3 - 1) self.parent_offset = parent_offset - def resolve_depth(self, val: int | None = None) -> int: + def resolve_depth(self, val: Optional[int] = None) -> int: if val is None: return self.depth 
return self.depth + val if val < 0 else val @@ -29,7 +29,7 @@ def doc(self) -> "Node": def node(self, depth: int) -> "Node": return cast("Node", self.path[self.resolve_depth(depth) * 3]) - def index(self, depth: int | None = None) -> int: + def index(self, depth: Optional[int] = None) -> int: return cast(int, self.path[self.resolve_depth(depth) * 3 + 1]) def index_after(self, depth: int) -> int: @@ -38,15 +38,15 @@ def index_after(self, depth: int) -> int: 0 if depth == self.depth and not self.text_offset else 1 ) - def start(self, depth: int | None = None) -> int: + def start(self, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) return 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 - def end(self, depth: int | None = None) -> int: + def end(self, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) return self.start(depth) + self.node(depth).content.size - def before(self, depth: int | None = None) -> int: + def before(self, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) if not depth: raise ValueError("There is no position before the top level node") @@ -54,7 +54,7 @@ def before(self, depth: int | None = None) -> int: self.pos if depth == self.depth + 1 else cast(int, self.path[depth * 3 - 1]) ) - def after(self, depth: int | None = None) -> int: + def after(self, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) if not depth: raise ValueError("There is no position after the top level node") @@ -70,7 +70,7 @@ def text_offset(self) -> int: return self.pos - cast(int, self.path[-1]) @property - def node_after(self) -> "Node | None": + def node_after(self) -> Optional["Node"]: parent = self.parent index = self.index(self.depth) if index == parent.child_count: @@ -80,14 +80,14 @@ def node_after(self) -> "Node | None": return parent.child(index).cut(d_off) if d_off else child @property - def node_before(self) -> "Node | None": + def node_before(self) -> Optional["Node"]: index = self.index(self.depth) d_off = self.pos - cast(int, self.path[-1]) if d_off: return self.parent.child(index).cut(0, d_off) return None if index == 0 else self.parent.child(index - 1) - def pos_at_index(self, index: int, depth: int | None = None) -> int: + def pos_at_index(self, index: int, depth: Optional[int] = None) -> int: depth = self.resolve_depth(depth) node = cast("Node", self.path[depth * 3]) pos = 0 if depth == 0 else cast(int, self.path[depth * 3 - 1]) + 1 @@ -117,7 +117,7 @@ def marks(self) -> list["Mark"]: i += 1 return marks - def marks_across(self, end: "ResolvedPos") -> list["Mark"] | None: + def marks_across(self, end: "ResolvedPos") -> Optional[list["Mark"]]: after = self.parent.maybe_child(self.index()) if not after or not after.is_inline: return None @@ -143,9 +143,9 @@ def shared_depth(self, pos: int) -> int: def block_range( self, - other: "ResolvedPos | None" = None, - pred: Callable[["Node"], bool] | None = None, - ) -> "NodeRange | None": + other: Optional["ResolvedPos"] = None, + pred: Optional[Callable[["Node"], bool]] = None, + ) -> Optional["NodeRange"]: if other is None: other = self if other.pos < self.pos: diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index a3fd6d7..de41d4a 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -3,7 +3,9 @@ Callable, Generic, Literal, + Optional, TypeVar, + Union, cast, ) @@ -16,7 +18,7 @@ from prosemirror.utils import JSON, Attrs, JSONDict -def default_attrs(attrs: "Attributes") -> Attrs | None: +def 
default_attrs(attrs: "Attributes") -> Optional[Attrs]: defaults = {} for attr_name, attr in attrs.items(): if not attr.has_default: @@ -25,7 +27,7 @@ def default_attrs(attrs: "Attributes") -> Attrs | None: return defaults -def compute_attrs(attrs: "Attributes", value: Attrs | None) -> Attrs: +def compute_attrs(attrs: "Attributes", value: Optional[Attrs]) -> Attrs: built = {} for name in attrs: given = None @@ -41,7 +43,7 @@ def compute_attrs(attrs: "Attributes", value: Attrs | None) -> Attrs: return built -def init_attrs(attrs: "AttributeSpecs | None") -> "Attributes": +def init_attrs(attrs: Optional["AttributeSpecs"]) -> "Attributes": result = {} if attrs: for name in attrs: @@ -65,7 +67,7 @@ class NodeType: inline_content: bool - mark_set: list["MarkType"] | None + mark_set: Optional[list["MarkType"]] def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> None: self.name = name @@ -74,7 +76,7 @@ def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> N self.groups = spec["group"].split(" ") if "group" in spec else [] self.attrs = init_attrs(spec.get("attrs")) self.default_attrs = default_attrs(self.attrs) - self._content_match: ContentMatch | None = None + self._content_match: Optional[ContentMatch] = None self.mark_set = None self.inline_content = False self.is_block = not (spec.get("inline") or name == "text") @@ -120,16 +122,16 @@ def has_required_attrs(self) -> bool: def compatible_content(self, other: "NodeType") -> bool: return self == other or (self.content_match.compatible(other.content_match)) - def compute_attrs(self, attrs: Attrs | None) -> Attrs: + def compute_attrs(self, attrs: Optional[Attrs]) -> Attrs: if attrs is None and self.default_attrs is not None: return self.default_attrs return compute_attrs(self.attrs, attrs) def create( self, - attrs: Attrs | None = None, - content: Fragment | Node | list[Node] | None = None, - marks: list[Mark] | None = None, + attrs: Optional[Attrs] = None, + content: Union[Fragment, Node, list[Node], None] = None, + marks: Optional[list[Mark]] = None, ) -> Node: if self.is_text: raise ValueError("NodeType.create cannot construct text nodes") @@ -142,9 +144,9 @@ def create( def create_checked( self, - attrs: Attrs | None = None, - content: Fragment | Node | list[Node] | None = None, - marks: list[Mark] | None = None, + attrs: Optional[Attrs] = None, + content: Union[Fragment, Node, list[Node], None] = None, + marks: Optional[list[Mark]] = None, ) -> Node: content = Fragment.from_(content) if not self.valid_content(content): @@ -153,10 +155,10 @@ def create_checked( def create_and_fill( self, - attrs: Attrs | None = None, - content: Fragment | Node | list[Node] | None = None, - marks: list[Mark] | None = None, - ) -> Node | None: + attrs: Optional[Attrs] = None, + content: Union[Fragment, Node, list[Node], None] = None, + marks: Optional[list[Mark]] = None, + ) -> Optional[Node]: attrs = self.compute_attrs(attrs) frag = Fragment.from_(content) if frag.size: @@ -192,7 +194,7 @@ def allows_marks(self, marks: list[Mark]) -> bool: def allowed_marks(self, marks: list[Mark]) -> list[Mark]: if self.mark_set is None: return marks - copy: list[Mark] | None = None + copy: Optional[list[Mark]] = None for i, mark in enumerate(marks): if not self.allows_mark_type(mark.type): if not copy: @@ -246,7 +248,7 @@ def is_required(self) -> bool: class MarkType: excluded: list["MarkType"] - instance: Mark | None + instance: Optional[Mark] def __init__( self, name: str, rank: int, schema: "Schema[Any, Any]", spec: 
"MarkSpec" @@ -264,7 +266,7 @@ def __init__( def create( self, - attrs: Attrs | None = None, + attrs: Optional[Attrs] = None, ) -> Mark: if not attrs and self.instance: return self.instance @@ -284,7 +286,7 @@ def compile( def remove_from_set(self, set_: list["Mark"]) -> list["Mark"]: return [item for item in set_ if item.type != self] - def is_in_set(self, set: list[Mark]) -> Mark | None: + def is_in_set(self, set: list[Mark]) -> Optional[Mark]: return next((item for item in set if item.type == self), None) def excludes(self, other: "MarkType") -> bool: @@ -409,10 +411,10 @@ def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: def node( self, - type: str | NodeType, - attrs: Attrs | None = None, - content: Fragment | Node | list[Node] | None = None, - marks: list[Mark] | None = None, + type: Union[str, NodeType], + attrs: Optional[Attrs] = None, + content: Union[Fragment, Node, list[Node], None] = None, + marks: Optional[list[Mark]] = None, ) -> Node: if isinstance(type, str): type = self.node_type(type) @@ -422,7 +424,7 @@ def node( raise ValueError(f"Node type from different schema used ({type.name})") return type.create_checked(attrs, content, marks) - def text(self, text: str, marks: list[Mark] | None = None) -> TextNode: + def text(self, text: str, marks: Optional[list[Mark]] = None) -> TextNode: type = self.nodes[cast(Nodes, "text")] return TextNode( type, cast(Attrs, type.default_attrs), text, Mark.set_from(marks) @@ -430,14 +432,14 @@ def text(self, text: str, marks: list[Mark] | None = None) -> TextNode: def mark( self, - type: str | MarkType, - attrs: Attrs | None = None, + type: Union[str, MarkType], + attrs: Optional[Attrs] = None, ) -> Mark: if isinstance(type, str): type = self.marks[cast(Marks, type)] return type.create(attrs) - def node_from_json(self, json_data: JSONDict) -> Node | TextNode: + def node_from_json(self, json_data: JSONDict) -> Union[Node, TextNode]: return Node.from_json(self, json_data) def mark_from_json( diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index 63ff971..8e33679 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -2,8 +2,9 @@ from typing import ( Any, Callable, + Optional, Sequence, - TypeAlias, + Union, cast, ) @@ -12,7 +13,7 @@ from .node import Node from .schema import Schema -HTMLNode: TypeAlias = "Element | str" +HTMLNode = Union["Element", "str"] class DocumentFragment: @@ -62,7 +63,7 @@ def __str__(self) -> str: return f"<{open_tag_str}>{children_str}" -HTMLOutputSpec = str | Sequence[Any] | Element +HTMLOutputSpec = Union[str, Sequence[Any], Element] class DOMSerializer: @@ -75,12 +76,12 @@ def __init__( self.marks = marks def serialize_fragment( - self, fragment: Fragment, target: Element | None = None + self, fragment: Fragment, target: Optional[Element] = None ) -> DocumentFragment: tgt: DocumentFragment = target or DocumentFragment(children=[]) top = tgt - active: list[tuple[Mark, DocumentFragment]] | None = None + active: Optional[list[tuple[Mark, DocumentFragment]]] = None def each(node: Node, offset: int, index: int) -> None: nonlocal top, active @@ -137,14 +138,16 @@ def serialize_node(self, node: Node) -> HTMLNode: def serialize_mark( self, mark: Mark, inline: bool - ) -> tuple[HTMLNode, Element | None] | None: + ) -> Optional[tuple[HTMLNode, Optional[Element]]]: to_dom = self.marks.get(mark.type.name) if to_dom: return type(self).render_spec(to_dom(mark, inline)) return None @classmethod - def render_spec(cls, structure: HTMLOutputSpec) -> tuple[HTMLNode, Element | 
None]: + def render_spec( + cls, structure: HTMLOutputSpec + ) -> tuple[HTMLNode, Optional[Element]]: if isinstance(structure, str): return html.escape(structure), None if isinstance(structure, Element): @@ -152,7 +155,7 @@ def render_spec(cls, structure: HTMLOutputSpec) -> tuple[HTMLNode, Element | Non tag_name = structure[0] if " " in tag_name[1:]: raise NotImplementedError("XML namespaces are not supported") - content_dom: Element | None = None + content_dom: Optional[Element] = None dom = Element(name=tag_name, attrs={}, children=[]) attrs = structure[1] if len(structure) > 1 else None start = 1 diff --git a/prosemirror/transform/attr_step.py b/prosemirror/transform/attr_step.py index f4bbf65..88eeca6 100644 --- a/prosemirror/transform/attr_step.py +++ b/prosemirror/transform/attr_step.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, Optional, Union, cast from prosemirror.model import Fragment, Node, Schema, Slice from prosemirror.transform.map import Mappable, StepMap @@ -37,7 +37,7 @@ def invert(self, doc: Node) -> "AttrStep": assert node_at_pos is not None return AttrStep(self.pos, self.attr, node_at_pos.attrs[self.attr]) - def map(self, mapping: Mappable) -> Step | None: + def map(self, mapping: Mappable) -> Optional[Step]: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else AttrStep(pos.pos, self.attr, self.value) @@ -50,7 +50,9 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AttrStep": + def from_json( + schema: Schema[Any, Any], json_data: Union[JSONDict, str] + ) -> "AttrStep": if isinstance(json_data, str): import json diff --git a/prosemirror/transform/doc_attr_step.py b/prosemirror/transform/doc_attr_step.py index c00d990..0d84ce8 100644 --- a/prosemirror/transform/doc_attr_step.py +++ b/prosemirror/transform/doc_attr_step.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, cast +from typing import Any, Optional, Union, cast from prosemirror.model import Node, Schema from prosemirror.transform.map import Mappable, StepMap @@ -39,7 +39,9 @@ def to_json(self) -> JSONDict: return json_data @staticmethod - def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "DocAttrStep": + def from_json( + schema: Schema[Any, Any], json_data: Union[JSONDict, str] + ) -> "DocAttrStep": if isinstance(json_data, str): import json diff --git a/prosemirror/transform/map.py b/prosemirror/transform/map.py index 196f58b..639c375 100644 --- a/prosemirror/transform/map.py +++ b/prosemirror/transform/map.py @@ -1,6 +1,6 @@ import abc from collections.abc import Callable -from typing import Any, ClassVar, Literal, overload +from typing import ClassVar, Literal, Optional, Union, overload lower16 = 0xFFFF factor16 = 2**16 @@ -25,7 +25,9 @@ def recover_offset(value: int) -> int: class MapResult: - def __init__(self, pos: int, del_info: int = 0, recover: int | None = None) -> None: + def __init__( + self, pos: int, del_info: int = 0, recover: Optional[int] = None + ) -> None: self.pos = pos self.del_info = del_info self.recover = recover @@ -68,7 +70,7 @@ def map_result(self, pos: int, assoc: int = 1) -> MapResult: class StepMap(Mappable): empty: ClassVar["StepMap"] - def __init__(self, ranges: list[int | Any], inverted: bool = False) -> None: + def __init__(self, ranges: list[int], inverted: bool = False) -> None: # prosemirror-transform overrides the constructor to return the # StepMap.empty singleton when ranges are empty. 
# It is not easy to do in Python, and the intent of that is to make sure @@ -98,7 +100,7 @@ def _map(self, pos: int, assoc: int, simple: Literal[True]) -> int: def _map(self, pos: int, assoc: int, simple: Literal[False]) -> MapResult: ... - def _map(self, pos: int, assoc: int, simple: bool) -> MapResult | int: + def _map(self, pos: int, assoc: int, simple: bool) -> Union[MapResult, int]: diff = 0 old_index = 2 if self.inverted else 1 new_index = 1 if self.inverted else 2 @@ -180,17 +182,17 @@ def __str__(self) -> str: class Mapping(Mappable): def __init__( self, - maps: list[StepMap] | None = None, - mirror: list[int] | None = None, - from_: int | None = None, - to: int | None = None, + maps: Optional[list[StepMap]] = None, + mirror: Optional[list[int]] = None, + from_: Optional[int] = None, + to: Optional[int] = None, ) -> None: self.maps = maps or [] self.from_ = from_ or 0 self.to = len(self.maps) if to is None else to self.mirror = mirror - def slice(self, from_: int = 0, to: int | None = None) -> "Mapping": + def slice(self, from_: int = 0, to: Optional[int] = None) -> "Mapping": if to is None: to = len(self.maps) return Mapping(self.maps, self.mirror, from_, to) @@ -200,7 +202,7 @@ def copy(self) -> "Mapping": self.maps[:], (self.mirror[:] if self.mirror else None), self.from_, self.to ) - def append_map(self, map: StepMap, mirrors: int | None = None) -> None: + def append_map(self, map: StepMap, mirrors: Optional[int] = None) -> None: self.maps.append(map) self.to = len(self.maps) if mirrors is not None: @@ -217,7 +219,7 @@ def append_mapping(self, mapping: "Mapping") -> None: (start_size + mirr) if (mirr is not None and mirr < i) else None, ) - def get_mirror(self, n: int) -> int | None: + def get_mirror(self, n: int) -> Optional[int]: if self.mirror: for i in range(len(self.mirror)): if (self.mirror[i]) == n: @@ -263,7 +265,7 @@ def _map(self, pos: int, assoc: int, simple: Literal[True]) -> int: def _map(self, pos: int, assoc: int, simple: Literal[False]) -> MapResult: ... 
- def _map(self, pos: int, assoc: int, simple: bool) -> MapResult | int: + def _map(self, pos: int, assoc: int, simple: bool) -> Union[MapResult, int]: del_info = 0 i = self.from_ diff --git a/prosemirror/transform/mark_step.py b/prosemirror/transform/mark_step.py index b117226..5150c73 100644 --- a/prosemirror/transform/mark_step.py +++ b/prosemirror/transform/mark_step.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, cast +from typing import Any, Callable, Optional, Union, cast from prosemirror.model import Fragment, Mark, Node, Schema, Slice from prosemirror.transform.map import Mappable @@ -8,8 +8,8 @@ def map_fragment( fragment: Fragment, - f: Callable[[Node, Node | None, int], Node], - parent: Node | None = None, + f: Callable[[Node, Optional[Node], int], Node], + parent: Optional[Node] = None, ) -> Fragment: mapped = [] for i in range(fragment.child_count): @@ -34,7 +34,7 @@ def apply(self, doc: Node) -> StepResult: from__ = doc.resolve(self.from_) parent = from__.node(from__.shared_depth(self.to)) - def iteratee(node: Node, parent: Node | None, i: int) -> Node: + def iteratee(node: Node, parent: Optional[Node], i: int) -> Node: if parent and ( not node.is_atom or not parent.type.allows_mark_type(self.mark.type) ): @@ -48,17 +48,17 @@ def iteratee(node: Node, parent: Node | None, i: int) -> Node: ) return StepResult.from_replace(doc, self.from_, self.to, slice) - def invert(self, doc: Node | None = None) -> "RemoveMarkStep": + def invert(self, doc: Optional[Node] = None) -> "RemoveMarkStep": return RemoveMarkStep(self.from_, self.to, self.mark) - def map(self, mapping: Mappable) -> "AddMarkStep | None": + def map(self, mapping: Mappable) -> Optional["AddMarkStep"]: from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if from_.deleted and to.deleted or from_.pos > to.pos: return None return AddMarkStep(from_.pos, to.pos, self.mark) - def merge(self, other: "Step") -> "AddMarkStep | None": + def merge(self, other: "Step") -> Optional["AddMarkStep"]: if ( isinstance(other, AddMarkStep) and other.mark.eq(self.mark) @@ -79,7 +79,9 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "AddMarkStep": + def from_json( + schema: Schema[Any, Any], json_data: Union[JSONDict, str] + ) -> "AddMarkStep": if isinstance(json_data, str): import json @@ -109,7 +111,7 @@ def __init__(self, from_: int, to: int, mark: Mark) -> None: def apply(self, doc: Node) -> StepResult: old_slice = doc.slice(self.from_, self.to) - def iteratee(node: Node, parent: Node | None, i: int) -> Node: + def iteratee(node: Node, parent: Optional[Node], i: int) -> Node: return node.mark(self.mark.remove_from_set(node.marks)) slice = Slice( @@ -119,17 +121,17 @@ def iteratee(node: Node, parent: Node | None, i: int) -> Node: ) return StepResult.from_replace(doc, self.from_, self.to, slice) - def invert(self, doc: Node | None = None) -> AddMarkStep: + def invert(self, doc: Optional[Node] = None) -> AddMarkStep: return AddMarkStep(self.from_, self.to, self.mark) - def map(self, mapping: Mappable) -> "RemoveMarkStep | None": + def map(self, mapping: Mappable) -> Optional["RemoveMarkStep"]: from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if (from_.deleted and to.deleted) or (from_.pos > to.pos): return None return RemoveMarkStep(from_.pos, to.pos, self.mark) - def merge(self, other: "Step") -> "RemoveMarkStep | None": + def merge(self, other: "Step") -> Optional["RemoveMarkStep"]: if ( isinstance(other, 
RemoveMarkStep) and (other.mark.eq(self.mark)) @@ -150,7 +152,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Step": if isinstance(json_data, str): import json @@ -188,7 +190,7 @@ def apply(self, doc: Node) -> StepResult: Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), ) - def invert(self, doc: Node) -> "RemoveNodeMarkStep | AddNodeMarkStep": + def invert(self, doc: Node) -> Union["RemoveNodeMarkStep", "AddNodeMarkStep"]: node = doc.node_at(self.pos) if node: new_set = self.mark.add_to_set(node.marks) @@ -199,7 +201,7 @@ def invert(self, doc: Node) -> "RemoveNodeMarkStep | AddNodeMarkStep": return AddNodeMarkStep(self.pos, self.mark) return RemoveNodeMarkStep(self.pos, self.mark) - def map(self, mapping: Mappable) -> "AddNodeMarkStep | None": + def map(self, mapping: Mappable) -> Optional["AddNodeMarkStep"]: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else AddNodeMarkStep(pos.pos, self.mark) @@ -211,7 +213,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Step": if isinstance(json_data, str): import json @@ -247,13 +249,13 @@ def apply(self, doc: Node) -> StepResult: Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), ) - def invert(self, doc: Node) -> "RemoveNodeMarkStep | AddNodeMarkStep": + def invert(self, doc: Node) -> Union["RemoveNodeMarkStep", "AddNodeMarkStep"]: node = doc.node_at(self.pos) if not node or not self.mark.is_in_set(node.marks): return self return AddNodeMarkStep(self.pos, self.mark) - def map(self, mapping: Mappable) -> "RemoveNodeMarkStep | None": + def map(self, mapping: Mappable) -> Optional["RemoveNodeMarkStep"]: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else RemoveNodeMarkStep(pos.pos, self.mark) @@ -265,7 +267,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Step": if isinstance(json_data, str): import json diff --git a/prosemirror/transform/replace.py b/prosemirror/transform/replace.py index 977ef1a..bc2b9f3 100644 --- a/prosemirror/transform/replace.py +++ b/prosemirror/transform/replace.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Optional, cast from prosemirror.model import ( ContentMatch, @@ -16,9 +16,9 @@ def replace_step( doc: Node, from_: int, - to: int | None = None, - slice: Slice | None = None, -) -> Step | None: + to: Optional[int] = None, + slice: Optional[Slice] = None, +) -> Optional[Step]: if to is None: to = from_ if slice is None: @@ -58,9 +58,9 @@ def __init__( self, slice_depth: int, frontier_depth: int, - parent: Node | None, - inject: Fragment | None = None, - wrap: list[NodeType] | None = None, + parent: Optional[Node], + inject: Optional[Fragment] = None, + wrap: Optional[list[NodeType]] = None, ) -> None: self.slice_depth = slice_depth self.frontier_depth = frontier_depth @@ -106,7 +106,7 @@ def __init__(self, from__: ResolvedPos, to_: ResolvedPos, slice: Slice) -> None: def depth(self) -> int: return len(self.frontier) - 1 - def fit(self) -> Step | None: + def fit(self) -> Optional[Step]: while self.unplaced.size: fit = self.find_fittable() if fit: @@ -147,7 
+147,7 @@ def fit(self) -> Step | None: return ReplaceStep(from__.pos, to_.pos, slice) return None - def find_fittable(self) -> _Fittable | None: + def find_fittable(self) -> Optional[_Fittable]: start_depth = self.unplaced.open_start cur = self.unplaced.content open_end = self.unplaced.open_end @@ -182,18 +182,18 @@ def find_fittable(self) -> _Fittable | None: inject = _nothing wrap = _nothing - def _lazy_inject() -> Fragment | None: + def _lazy_inject() -> Optional[Fragment]: nonlocal inject if inject is _nothing: inject = match.fill_before(Fragment.from_(first), False) - return cast(Fragment | None, inject) + return cast(Optional[Fragment], inject) - def _lazy_wrap() -> list[NodeType] | None: + def _lazy_wrap() -> Optional[list[NodeType]]: nonlocal wrap assert first is not None if wrap is _nothing: wrap = match.find_wrapping(first.type) - return cast(list[NodeType] | None, wrap) + return cast(Optional[list[NodeType]], wrap) if pass_ == 1 and ( (match.match_type(first.type) or _lazy_inject()) @@ -355,11 +355,11 @@ def must_move_inline(self) -> int: _nothing = object() level = _nothing - def _lazy_level() -> _CloseLevel | None: + def _lazy_level() -> Optional[_CloseLevel]: nonlocal level if level is _nothing: level = self.find_close_level(self.to_) - return cast(_CloseLevel | None, level) + return cast(Optional[_CloseLevel], level) if ( not top.type.is_text_block @@ -383,7 +383,7 @@ def _lazy_level() -> _CloseLevel | None: after += 1 return after - def find_close_level(self, to_: ResolvedPos) -> _CloseLevel | None: + def find_close_level(self, to_: ResolvedPos) -> Optional[_CloseLevel]: for i in range(min(self.depth, to_.depth), -1, -1): match = self.frontier[i].match type_ = self.frontier[i].type @@ -406,7 +406,7 @@ def find_close_level(self, to_: ResolvedPos) -> _CloseLevel | None: ) return None - def close(self, to_: ResolvedPos) -> ResolvedPos | None: + def close(self, to_: ResolvedPos) -> Optional[ResolvedPos]: close = self.find_close_level(to_) if not close: return None @@ -425,8 +425,8 @@ def close(self, to_: ResolvedPos) -> ResolvedPos | None: def open_frontier_node( self, type_: NodeType, - attrs: Attrs | None = None, - content: Fragment | None = None, + attrs: Optional[Attrs] = None, + content: Optional[Fragment] = None, ) -> None: top = self.frontier[self.depth] top_match = top.match.match_type(type_) @@ -505,7 +505,7 @@ def content_after_fits( type_: NodeType, match: ContentMatch, open_: bool, -) -> Fragment | None: +) -> Optional[Fragment]: node = to_.node(depth) index = to_.index_after(depth) if open_ else to_.index(depth) if index == node.child_count and not type_.compatible_content(node.type): @@ -526,7 +526,7 @@ def close_fragment( depth: int, old_open: int, new_open: int, - parent: Node | None, + parent: Optional[Node], ) -> Fragment: if depth < old_open: first = fragment.first_child diff --git a/prosemirror/transform/replace_step.py b/prosemirror/transform/replace_step.py index 30a22a2..7ab32d2 100644 --- a/prosemirror/transform/replace_step.py +++ b/prosemirror/transform/replace_step.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, Optional, Union, cast from prosemirror.model import Node, Schema, Slice from prosemirror.transform.map import Mappable, StepMap @@ -8,7 +8,7 @@ class ReplaceStep(Step): def __init__( - self, from_: int, to: int, slice: Slice, structure: bool | None = None + self, from_: int, to: int, slice: Slice, structure: Optional[bool] = None ) -> None: super().__init__() self.from_ = from_ @@ -29,14 +29,14 @@ def invert(self, 
doc: Node) -> "ReplaceStep": self.from_, self.from_ + self.slice.size, doc.slice(self.from_, self.to) ) - def map(self, mapping: Mappable) -> "ReplaceStep | None": + def map(self, mapping: Mappable) -> Optional["ReplaceStep"]: from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if from_.deleted and to.deleted: return None return ReplaceStep(from_.pos, max(from_.pos, to.pos), self.slice) - def merge(self, other: "Step") -> "ReplaceStep | None": + def merge(self, other: "Step") -> Optional["ReplaceStep"]: if not isinstance(other, ReplaceStep) or other.structure or self.structure: return None if ( @@ -86,7 +86,9 @@ def to_json(self) -> JSONDict: return json_data @staticmethod - def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "ReplaceStep": + def from_json( + schema: Schema[Any, Any], json_data: Union[JSONDict, str] + ) -> "ReplaceStep": if isinstance(json_data, str): import json @@ -99,7 +101,7 @@ def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "ReplaceSt return ReplaceStep( json_data["from"], json_data["to"], - Slice.from_json(schema, cast(JSONDict | None, json_data.get("slice"))), + Slice.from_json(schema, cast(Optional[JSONDict], json_data.get("slice"))), bool(json_data.get("structure")), ) @@ -116,7 +118,7 @@ def __init__( gap_to: int, slice: Slice, insert: int, - structure: bool | None = None, + structure: Optional[bool] = None, ) -> None: super().__init__() self.from_ = from_ @@ -167,7 +169,7 @@ def invert(self, doc: Node) -> "ReplaceAroundStep": self.structure, ) - def map(self, mapping: Mappable) -> "ReplaceAroundStep | None": + def map(self, mapping: Mappable) -> Optional["ReplaceAroundStep"]: from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) gap_from = mapping.map(self.gap_from, -1) @@ -201,7 +203,7 @@ def to_json(self) -> JSONDict: @staticmethod def from_json( - schema: Schema[Any, Any], json_data: JSONDict | str + schema: Schema[Any, Any], json_data: Union[JSONDict, str] ) -> "ReplaceAroundStep": if isinstance(json_data, str): import json @@ -221,7 +223,7 @@ def from_json( json_data["to"], json_data["gapFrom"], json_data["gapTo"], - Slice.from_json(schema, cast(JSONDict | None, json_data.get("slice"))), + Slice.from_json(schema, cast(Optional[JSONDict], json_data.get("slice"))), json_data["insert"], bool(json_data.get("structure")), ) diff --git a/prosemirror/transform/step.py b/prosemirror/transform/step.py index 7c48ef6..bf647b3 100644 --- a/prosemirror/transform/step.py +++ b/prosemirror/transform/step.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Literal, Type, TypeVar, cast, overload +from typing import Any, Literal, Optional, Type, TypeVar, Union, cast, overload from prosemirror.model import Node, ReplaceError, Schema, Slice from prosemirror.transform.map import Mappable, StepMap @@ -25,10 +25,10 @@ def invert(self, _doc: Node) -> "Step": ... @abc.abstractmethod - def map(self, _mapping: Mappable) -> "Step | None": + def map(self, _mapping: Mappable) -> Optional["Step"]: ... - def merge(self, _other: "Step") -> "Step | None": + def merge(self, _other: "Step") -> Optional["Step"]: return None @abc.abstractmethod @@ -36,7 +36,7 @@ def to_json(self) -> JSONDict: ... 
@staticmethod - def from_json(schema: Schema[Any, Any], json_data: JSONDict | str) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Step": if isinstance(json_data, str): import json @@ -69,7 +69,7 @@ def __init__(self, doc: Node, failed: Literal[None]) -> None: def __init__(self, doc: None, failed: str) -> None: ... - def __init__(self, doc: Node | None, failed: str | None) -> None: + def __init__(self, doc: Optional[Node], failed: Optional[str]) -> None: self.doc = doc self.failed = failed diff --git a/prosemirror/transform/structure.py b/prosemirror/transform/structure.py index 3307da6..fb52b7d 100644 --- a/prosemirror/transform/structure.py +++ b/prosemirror/transform/structure.py @@ -1,4 +1,4 @@ -from typing import TypedDict +from typing import Optional, TypedDict, Union from prosemirror.model import ContentMatch, Node, NodeRange, NodeType, Slice from prosemirror.utils import Attrs @@ -10,7 +10,7 @@ def can_cut(node: Node, start: int, end: int) -> bool: return False -def lift_target(range_: NodeRange) -> int | None: +def lift_target(range_: NodeRange) -> Optional[int]: parent = range_.parent content = parent.content.cut_by_index(range_.start_index, range_.end_index) depth = range_.depth @@ -33,15 +33,15 @@ def lift_target(range_: NodeRange) -> int | None: class NodeTypeWithAttrs(TypedDict): type: NodeType - attrs: Attrs | None + attrs: Optional[Attrs] def find_wrapping( range_: NodeRange, node_type: NodeType, - attrs: Attrs | None = None, - inner_range: NodeRange | None = None, -) -> list[NodeTypeWithAttrs] | None: + attrs: Optional[Attrs] = None, + inner_range: Optional[NodeRange] = None, +) -> Optional[list[NodeTypeWithAttrs]]: if inner_range is None: inner_range = range_ @@ -67,7 +67,9 @@ def with_attrs(type: NodeType) -> NodeTypeWithAttrs: return NodeTypeWithAttrs(type=type, attrs=None) -def find_wrapping_outside(range_: NodeRange, type: NodeType) -> list[NodeType] | None: +def find_wrapping_outside( + range_: NodeRange, type: NodeType +) -> Optional[list[NodeType]]: parent = range_.parent start_index = range_.start_index end_index = range_.end_index @@ -78,7 +80,7 @@ def find_wrapping_outside(range_: NodeRange, type: NodeType) -> list[NodeType] | return around if parent.can_replace_with(start_index, end_index, outer) else None -def find_wrapping_inside(range_: NodeRange, type: NodeType) -> list[NodeType] | None: +def find_wrapping_inside(range_: NodeRange, type: NodeType) -> Optional[list[NodeType]]: parent = range_.parent start_index = range_.start_index end_index = range_.end_index @@ -89,7 +91,7 @@ def find_wrapping_inside(range_: NodeRange, type: NodeType) -> list[NodeType] | return None last_type = inside[-1] if len(inside) else type - inner_match: ContentMatch | None = last_type.content_match + inner_match: Optional[ContentMatch] = last_type.content_match i = start_index while inner_match and i < end_index: @@ -111,14 +113,14 @@ def can_change_type(doc: Node, pos: int, type: NodeType) -> bool: def can_split( doc: Node, pos: int, - depth: int | None = None, - types_after: list[NodeTypeWithAttrs] | None = None, + depth: Optional[int] = None, + types_after: Optional[list[NodeTypeWithAttrs]] = None, ) -> bool: if depth is None: depth = 1 pos_ = doc.resolve(pos) base = pos_.depth - depth - inner_type: NodeTypeWithAttrs | Node | None = None + inner_type: Union[NodeTypeWithAttrs, Node, None] = None if types_after: inner_type = types_after[-1] @@ -163,7 +165,7 @@ def can_split( rest = rest.replace_child( 0, 
override_child["type"].create(override_child.get("attrs")) ) - after: NodeTypeWithAttrs | Node | None = None + after: Union[NodeTypeWithAttrs, Node, None] = None if types_after and len(types_after) > i: after = types_after[i] if not after: @@ -191,7 +193,7 @@ def can_split( ) -def can_join(doc: Node, pos: int) -> bool | None: +def can_join(doc: Node, pos: int) -> Optional[bool]: pos_ = doc.resolve(pos) index = pos_.index() return ( @@ -201,13 +203,13 @@ def can_join(doc: Node, pos: int) -> bool | None: ) -def joinable(a: Node | None, b: Node | None) -> bool: +def joinable(a: Optional[Node], b: Optional[Node]) -> bool: if a and b and not a.is_leaf: return a.can_append(b) return False -def join_point(doc: Node, pos: int, dir: int = -1) -> int | None: +def join_point(doc: Node, pos: int, dir: int = -1) -> Optional[int]: pos_ = doc.resolve(pos) for d in range(pos_.depth, -1, -1): before = None @@ -237,7 +239,7 @@ def join_point(doc: Node, pos: int, dir: int = -1) -> int | None: return None -def insert_point(doc: Node, pos: int, node_type: NodeType) -> int | None: +def insert_point(doc: Node, pos: int, node_type: NodeType) -> Optional[int]: pos_ = doc.resolve(pos) if pos_.parent.can_replace_with(pos_.index(), pos_.index(), node_type): return pos @@ -259,7 +261,7 @@ def insert_point(doc: Node, pos: int, node_type: NodeType) -> int | None: return None -def drop_point(doc: Node, pos: int, slice: Slice) -> int | None: +def drop_point(doc: Node, pos: int, slice: Slice) -> Optional[int]: pos_ = doc.resolve(pos) if not slice.content.size: return pos diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index 2781ea4..5b273d4 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -1,5 +1,5 @@ import re -from typing import TypedDict +from typing import Optional, TypedDict, Union from prosemirror.model import ( ContentMatch, @@ -34,7 +34,7 @@ from .doc_attr_step import DocAttrStep -def defines_content(type: NodeType | MarkType) -> bool | None: +def defines_content(type: Union[NodeType, MarkType]) -> Optional[bool]: if isinstance(type, NodeType): return type.spec.get("defining") or type.spec.get("definingForContent") return False @@ -90,10 +90,10 @@ def add_step(self, step: Step, doc: Node) -> None: def add_mark(self, from_: int, to: int, mark: Mark) -> "Transform": removed = [] added = [] - removing: RemoveMarkStep | None = None - adding: AddMarkStep | None = None + removing: Optional[RemoveMarkStep] = None + adding: Optional[AddMarkStep] = None - def iteratee(node: Node, pos: int, parent: Node | None, i: int) -> None: + def iteratee(node: Node, pos: int, parent: Optional[Node], i: int) -> None: nonlocal removing nonlocal adding if not node.is_inline: @@ -136,7 +136,7 @@ def remove_mark( self, from_: int, to: int, - mark: Mark | MarkType | None = None, + mark: Union[Mark, MarkType, None] = None, ) -> "Transform": class MatchedTypedDict(TypedDict): style: Mark @@ -147,7 +147,9 @@ class MatchedTypedDict(TypedDict): matched: list[MatchedTypedDict] = [] step = 0 - def iteratee(node: Node, pos: int, parent: Node | None, i: int) -> bool | None: + def iteratee( + node: Node, pos: int, parent: Optional[Node], i: int + ) -> Optional[bool]: nonlocal step if not node.is_inline: return None @@ -198,7 +200,7 @@ def clear_incompatible( self, pos: int, parent_type: NodeType, - match: ContentMatch | None = None, + match: Optional[ContentMatch] = None, ) -> "Transform": if match is None: match = parent_type.content_match @@ -252,8 +254,8 @@ def 
clear_incompatible( def replace( self, from_: int, - to: int | None = None, - slice: Slice | None = None, + to: Optional[int] = None, + slice: Optional[Slice] = None, ) -> "Transform": if to is None: to = from_ @@ -268,7 +270,7 @@ def replace_with( self, from_: int, to: int, - content: list[Node] | Node | Fragment, + content: Union[list[Node], Node, Fragment], ) -> "Transform": return self.replace(from_, to, Slice(Fragment.from_(content), 0, 0)) @@ -278,7 +280,7 @@ def delete(self, from_: int, to: int) -> "Transform": def insert( self, pos: int, - content: list[Node] | Node | Fragment, + content: Union[list[Node], Node, Fragment], ) -> "Transform": return self.replace_with(pos, pos, content) @@ -498,9 +500,9 @@ def wrap( def set_block_type( self, from_: int, - to: int | None, + to: Optional[int], type: NodeType, - attrs: Attrs | None, + attrs: Optional[Attrs], ) -> "Transform": if to is None: to = from_ @@ -509,8 +511,8 @@ def set_block_type( map_from = len(self.steps) def iteratee( - node: "Node", pos: int, parent: "Node | None", i: int - ) -> bool | None: + node: "Node", pos: int, parent: Optional["Node"], i: int + ) -> Optional[bool]: if ( node.is_text_block and not node.has_markup(type, attrs) @@ -544,9 +546,9 @@ def iteratee( def set_node_markup( self, pos: int, - type: NodeType | None, - attrs: Attrs | None, - marks: list[Mark] | None = None, + type: Optional[NodeType], + attrs: Optional[Attrs], + marks: Optional[list[Mark]] = None, ) -> "Transform": node = self.doc.node_at(pos) if not node: @@ -570,7 +572,9 @@ def set_node_markup( ) ) - def set_node_attribute(self, pos: int, attr: str, value: str | int) -> "Transform": + def set_node_attribute( + self, pos: int, attr: str, value: Union[str, int] + ) -> "Transform": return self.step(AttrStep(pos, attr, value)) def set_doc_attribute(self, attr: str, value: JSON) -> "Transform": @@ -579,7 +583,7 @@ def set_doc_attribute(self, attr: str, value: JSON) -> "Transform": def add_node_mark(self, pos: int, mark: Mark) -> "Transform": return self.step(AddNodeMarkStep(pos, mark)) - def remove_node_mark(self, pos: int, mark: Mark | MarkType) -> "Transform": + def remove_node_mark(self, pos: int, mark: Union[Mark, MarkType]) -> "Transform": if isinstance(mark, MarkType): node = self.doc.node_at(pos) @@ -597,8 +601,8 @@ def remove_node_mark(self, pos: int, mark: Mark | MarkType) -> "Transform": def split( self, pos: int, - depth: int | None = None, - types_after: list[structure.NodeTypeWithAttrs] | None = None, + depth: Optional[int] = None, + types_after: Optional[list[structure.NodeTypeWithAttrs]] = None, ) -> "Transform": if depth is None: depth = 1 diff --git a/prosemirror/utils.py b/prosemirror/utils.py index 99d6739..8dd0c35 100644 --- a/prosemirror/utils.py +++ b/prosemirror/utils.py @@ -1,11 +1,11 @@ -from typing import Mapping, Sequence +from typing import Mapping, Sequence, Union from typing_extensions import TypeAlias JSONDict: TypeAlias = Mapping[str, "JSON"] JSONList: TypeAlias = Sequence["JSON"] -JSON: TypeAlias = JSONDict | JSONList | str | int | float | bool | None +JSON: TypeAlias = Union[JSONDict, JSONList, str, int, float, bool, None] Attrs: TypeAlias = JSONDict From 0e6c2e3e080e3d15ba6cd24469f0ad681acf10fe Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Thu, 16 Nov 2023 09:15:48 -0500 Subject: [PATCH 35/40] Removing the mypy check from "tests" since there seems to be no way to override the strict option for some modules (https://github.com/python/mypy/issues/11401). 
--- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3c13753..0634fa7 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -34,4 +34,4 @@ jobs: - run: poetry install --no-interaction - run: poetry run black --check prosemirror tests - run: poetry run ruff prosemirror tests - - run: poetry run mypy prosemirror tests + - run: poetry run mypy prosemirror From f467c25856759910e4d9f71610acf9a226c00b17 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Thu, 16 Nov 2023 09:56:44 -0500 Subject: [PATCH 36/40] Minor changes to types are reviewing the original TS repos. --- prosemirror/model/to_dom.py | 11 +++++---- prosemirror/transform/attr_step.py | 2 +- prosemirror/transform/doc_attr_step.py | 2 +- prosemirror/transform/mark_step.py | 32 +++++++++++++------------- prosemirror/transform/transform.py | 8 +++---- 5 files changed, 28 insertions(+), 27 deletions(-) diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index 8e33679..d723a7c 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -2,6 +2,7 @@ from typing import ( Any, Callable, + Mapping, Optional, Sequence, Union, @@ -11,7 +12,7 @@ from .fragment import Fragment from .mark import Mark from .node import Node -from .schema import Schema +from .schema import MarkType, NodeType, Schema HTMLNode = Union["Element", "str"] @@ -76,7 +77,7 @@ def __init__( self.marks = marks def serialize_fragment( - self, fragment: Fragment, target: Optional[Element] = None + self, fragment: Fragment, target: Union[Element, DocumentFragment, None] = None ) -> DocumentFragment: tgt: DocumentFragment = target or DocumentFragment(children=[]) @@ -189,7 +190,7 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMSerializer": @classmethod def nodes_from_schema( - cls, schema: Schema[Any, Any] + cls, schema: Schema[str, Any] ) -> dict[str, Callable[["Node"], HTMLOutputSpec]]: result = gather_to_dom(schema.nodes) if "text" not in result: @@ -203,7 +204,9 @@ def marks_from_schema( return gather_to_dom(schema.marks) -def gather_to_dom(obj: dict[str, Any]) -> dict[str, Callable[..., Any]]: +def gather_to_dom( + obj: Mapping[str, Union[NodeType, MarkType]] +) -> dict[str, Callable[..., Any]]: result = {} for name in obj: to_dom = obj[name].spec.get("toDOM") diff --git a/prosemirror/transform/attr_step.py b/prosemirror/transform/attr_step.py index 88eeca6..cb6ef01 100644 --- a/prosemirror/transform/attr_step.py +++ b/prosemirror/transform/attr_step.py @@ -32,7 +32,7 @@ def apply(self, doc: Node) -> StepResult: def get_map(self) -> StepMap: return StepMap.empty - def invert(self, doc: Node) -> "AttrStep": + def invert(self, doc: Node) -> Step: node_at_pos = doc.node_at(self.pos) assert node_at_pos is not None return AttrStep(self.pos, self.attr, node_at_pos.attrs[self.attr]) diff --git a/prosemirror/transform/doc_attr_step.py b/prosemirror/transform/doc_attr_step.py index 0d84ce8..b7fb63b 100644 --- a/prosemirror/transform/doc_attr_step.py +++ b/prosemirror/transform/doc_attr_step.py @@ -23,7 +23,7 @@ def apply(self, doc: Node) -> StepResult: def get_map(self) -> StepMap: return StepMap.empty - def invert(self, doc: Node) -> "DocAttrStep": + def invert(self, doc: Node) -> Step: return DocAttrStep(self.attr, doc.attrs[self.attr]) def map(self, mapping: Mappable) -> Optional[Step]: diff --git a/prosemirror/transform/mark_step.py b/prosemirror/transform/mark_step.py index 5150c73..5571475 100644 --- 
a/prosemirror/transform/mark_step.py +++ b/prosemirror/transform/mark_step.py @@ -8,8 +8,8 @@ def map_fragment( fragment: Fragment, - f: Callable[[Node, Optional[Node], int], Node], - parent: Optional[Node] = None, + f: Callable[[Node, Node, int], Node], + parent: Node, ) -> Fragment: mapped = [] for i in range(fragment.child_count): @@ -48,17 +48,17 @@ def iteratee(node: Node, parent: Optional[Node], i: int) -> Node: ) return StepResult.from_replace(doc, self.from_, self.to, slice) - def invert(self, doc: Optional[Node] = None) -> "RemoveMarkStep": + def invert(self, doc: Optional[Node] = None) -> Step: return RemoveMarkStep(self.from_, self.to, self.mark) - def map(self, mapping: Mappable) -> Optional["AddMarkStep"]: + def map(self, mapping: Mappable) -> Optional[Step]: from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if from_.deleted and to.deleted or from_.pos > to.pos: return None return AddMarkStep(from_.pos, to.pos, self.mark) - def merge(self, other: "Step") -> Optional["AddMarkStep"]: + def merge(self, other: Step) -> Optional[Step]: if ( isinstance(other, AddMarkStep) and other.mark.eq(self.mark) @@ -115,23 +115,23 @@ def iteratee(node: Node, parent: Optional[Node], i: int) -> Node: return node.mark(self.mark.remove_from_set(node.marks)) slice = Slice( - map_fragment(old_slice.content, iteratee), + map_fragment(old_slice.content, iteratee, doc), old_slice.open_start, old_slice.open_end, ) return StepResult.from_replace(doc, self.from_, self.to, slice) - def invert(self, doc: Optional[Node] = None) -> AddMarkStep: + def invert(self, doc: Optional[Node] = None) -> Step: return AddMarkStep(self.from_, self.to, self.mark) - def map(self, mapping: Mappable) -> Optional["RemoveMarkStep"]: + def map(self, mapping: Mappable) -> Optional[Step]: from_ = mapping.map_result(self.from_, 1) to = mapping.map_result(self.to, -1) if (from_.deleted and to.deleted) or (from_.pos > to.pos): return None return RemoveMarkStep(from_.pos, to.pos, self.mark) - def merge(self, other: "Step") -> Optional["RemoveMarkStep"]: + def merge(self, other: Step) -> Optional[Step]: if ( isinstance(other, RemoveMarkStep) and (other.mark.eq(self.mark)) @@ -152,7 +152,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> Step: if isinstance(json_data, str): import json @@ -190,7 +190,7 @@ def apply(self, doc: Node) -> StepResult: Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), ) - def invert(self, doc: Node) -> Union["RemoveNodeMarkStep", "AddNodeMarkStep"]: + def invert(self, doc: Node) -> Step: node = doc.node_at(self.pos) if node: new_set = self.mark.add_to_set(node.marks) @@ -201,7 +201,7 @@ def invert(self, doc: Node) -> Union["RemoveNodeMarkStep", "AddNodeMarkStep"]: return AddNodeMarkStep(self.pos, self.mark) return RemoveNodeMarkStep(self.pos, self.mark) - def map(self, mapping: Mappable) -> Optional["AddNodeMarkStep"]: + def map(self, mapping: Mappable) -> Optional[Step]: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else AddNodeMarkStep(pos.pos, self.mark) @@ -213,7 +213,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> Step: if isinstance(json_data, str): import json @@ -249,13 +249,13 @@ def apply(self, doc: Node) -> 
StepResult: Slice(Fragment.from_(updated), 0, 0 if node.is_leaf else 1), ) - def invert(self, doc: Node) -> Union["RemoveNodeMarkStep", "AddNodeMarkStep"]: + def invert(self, doc: Node) -> Step: node = doc.node_at(self.pos) if not node or not self.mark.is_in_set(node.marks): return self return AddNodeMarkStep(self.pos, self.mark) - def map(self, mapping: Mappable) -> Optional["RemoveNodeMarkStep"]: + def map(self, mapping: Mappable) -> Optional[Step]: pos = mapping.map_result(self.pos, 1) return None if pos.deleted_after else RemoveNodeMarkStep(pos.pos, self.mark) @@ -267,7 +267,7 @@ def to_json(self) -> JSONDict: } @staticmethod - def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Step": + def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> Step: if isinstance(json_data, str): import json diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index 5b273d4..f965ea6 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -270,7 +270,7 @@ def replace_with( self, from_: int, to: int, - content: Union[list[Node], Node, Fragment], + content: Union[Fragment, Node, list[Node]], ) -> "Transform": return self.replace(from_, to, Slice(Fragment.from_(content), 0, 0)) @@ -280,7 +280,7 @@ def delete(self, from_: int, to: int) -> "Transform": def insert( self, pos: int, - content: Union[list[Node], Node, Fragment], + content: Union[Fragment, Node, list[Node]], ) -> "Transform": return self.replace_with(pos, pos, content) @@ -572,9 +572,7 @@ def set_node_markup( ) ) - def set_node_attribute( - self, pos: int, attr: str, value: Union[str, int] - ) -> "Transform": + def set_node_attribute(self, pos: int, attr: str, value: JSON) -> "Transform": return self.step(AttrStep(pos, attr, value)) def set_doc_attribute(self, attr: str, value: JSON) -> "Transform": From be372f1b615fed0f9a33dd34937ea466f38498aa Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Thu, 16 Nov 2023 10:13:50 -0500 Subject: [PATCH 37/40] More minor fixes related to type annotations syntax and Python 3.9. 
--- prosemirror/model/resolvedpos.py | 8 +++++--- prosemirror/test_builder/build.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index c0ddd4d..d381de9 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Optional, cast +from typing import TYPE_CHECKING, Callable, Optional, Union, cast from .mark import Mark @@ -7,7 +7,9 @@ class ResolvedPos: - def __init__(self, pos: int, path: list["Node | int"], parent_offset: int) -> None: + def __init__( + self, pos: int, path: list[Union["Node", int]], parent_offset: int + ) -> None: self.pos = pos self.path = path self.depth = int(len(path) / 3 - 1) @@ -181,7 +183,7 @@ def __str__(self) -> str: def resolve(cls, doc: "Node", pos: int) -> "ResolvedPos": if not (pos >= 0 and pos <= doc.content.size): raise ValueError(f"Position {pos} out of range") - path: list["Node | int"] = [] + path: list[Union["Node", int]] = [] start = 0 parent_offset = pos node = doc diff --git a/prosemirror/test_builder/build.py b/prosemirror/test_builder/build.py index e1ab6d1..1096992 100644 --- a/prosemirror/test_builder/build.py +++ b/prosemirror/test_builder/build.py @@ -2,7 +2,7 @@ import re from collections.abc import Callable -from typing import Any +from typing import Any, Union from prosemirror.model import Node, Schema from prosemirror.utils import JSONDict @@ -12,7 +12,7 @@ def flatten( schema: Schema[Any, Any], - children: list[Node | JSONDict | str], + children: list[Union[Node, JSONDict, str]], f: Callable[[Node], Node], ) -> tuple[list[Node], dict[str, int]]: result, pos, tag = [], 0, NO_TAG From f1fad36635f917e7263cea0b8946e9ae9370f58f Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Thu, 16 Nov 2023 10:27:45 -0500 Subject: [PATCH 38/40] Making type annotations compatible with Python 3.8. 
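For context, this is the standard pre-PEP 585 spelling of the annotations: on Python 3.8 the builtin list and dict types are not subscriptable, and the X | Y union syntax (PEP 604) only arrives in 3.10, so the annotations fall back to typing.List, typing.Dict, Optional and Union. A minimal, self-contained sketch of the convention being applied here (illustrative only, not code from this repository):

    from typing import Dict, List, Optional

    def word_counts(lines: List[str], limit: Optional[int] = None) -> Dict[str, int]:
        # Optional[int] is the 3.8-friendly spelling of "int | None" (PEP 604, 3.10+);
        # List[...] and Dict[...] replace the builtin generics list[...] / dict[...],
        # which only became subscriptable in 3.9 (PEP 585).
        counts: Dict[str, int] = {}
        for line in lines[:limit]:
            for word in line.split():
                counts[word] = counts.get(word, 0) + 1
        return counts

The same substitution accounts for the Optional[...] and Union[...] rewrites in the earlier patches of this series.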
--- prosemirror/model/content.py | 54 +++++++++---------- prosemirror/model/fragment.py | 16 +++--- prosemirror/model/from_dom.py | 60 ++++++++++----------- prosemirror/model/mark.py | 16 +++--- prosemirror/model/node.py | 16 +++--- prosemirror/model/replace.py | 12 ++--- prosemirror/model/resolvedpos.py | 10 ++-- prosemirror/model/schema.py | 66 ++++++++++++------------ prosemirror/model/to_dom.py | 23 +++++---- prosemirror/schema/basic/schema_basic.py | 6 +-- prosemirror/schema/list/schema_list.py | 6 +-- prosemirror/transform/map.py | 8 +-- prosemirror/transform/replace.py | 12 ++--- prosemirror/transform/step.py | 6 +-- prosemirror/transform/structure.py | 10 ++-- prosemirror/transform/transform.py | 18 +++---- 16 files changed, 174 insertions(+), 165 deletions(-) diff --git a/prosemirror/model/content.py b/prosemirror/model/content.py index 4b3a7df..95442d4 100644 --- a/prosemirror/model/content.py +++ b/prosemirror/model/content.py @@ -3,6 +3,8 @@ from typing import ( TYPE_CHECKING, ClassVar, + Dict, + List, Literal, NamedTuple, NoReturn, @@ -30,10 +32,10 @@ def __init__(self, type: "NodeType", next: "ContentMatch") -> None: class WrapCacheEntry: target: "NodeType" - computed: Optional[list["NodeType"]] + computed: Optional[List["NodeType"]] def __init__( - self, target: "NodeType", computed: Optional[list["NodeType"]] + self, target: "NodeType", computed: Optional[List["NodeType"]] ) -> None: self.target = target self.computed = computed @@ -55,8 +57,8 @@ class ContentMatch: empty: ClassVar["ContentMatch"] valid_end: bool - next: list[MatchEdge] - wrap_cache: list[WrapCacheEntry] + next: List[MatchEdge] + wrap_cache: List[WrapCacheEntry] def __init__(self, valid_end: bool) -> None: self.valid_end = valid_end @@ -64,7 +66,7 @@ def __init__(self, valid_end: bool) -> None: self.wrap_cache = [] @classmethod - def parse(cls, string: str, node_types: dict[str, "NodeType"]) -> "ContentMatch": + def parse(cls, string: str, node_types: Dict[str, "NodeType"]) -> "ContentMatch": stream = TokenStream(string, node_types) if stream.next() is None: return ContentMatch.empty @@ -117,7 +119,7 @@ def fill_before( ) -> Optional[Fragment]: seen = [self] - def search(match: ContentMatch, types: list["NodeType"]) -> Optional[Fragment]: + def search(match: ContentMatch, types: List["NodeType"]) -> Optional[Fragment]: nonlocal seen finished = match.match_fragment(after, start_index) if finished and (not to_end or finished.valid_end): @@ -136,7 +138,7 @@ def search(match: ContentMatch, types: list["NodeType"]) -> Optional[Fragment]: return search(self, []) - def find_wrapping(self, target: "NodeType") -> Optional[list["NodeType"]]: + def find_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: for entry in self.wrap_cache: if entry.target.name == target.name: return entry.computed @@ -144,9 +146,9 @@ def find_wrapping(self, target: "NodeType") -> Optional[list["NodeType"]]: self.wrap_cache.append(WrapCacheEntry(target, computed)) return computed - def compute_wrapping(self, target: "NodeType") -> Optional[list["NodeType"]]: + def compute_wrapping(self, target: "NodeType") -> Optional[List["NodeType"]]: seen = {} - active: list[Active] = [{"match": self, "type": None, "via": None}] + active: List[Active] = [{"match": self, "type": None, "via": None}] while len(active): current = active.pop(0) match = current["match"] @@ -218,9 +220,9 @@ def iteratee(m: "ContentMatch", i: int) -> str: class TokenStream: inline: Optional[bool] - tokens: list[str] + tokens: List[str] - def __init__(self, string: 
str, node_types: dict[str, "NodeType"]) -> None: + def __init__(self, string: str, node_types: Dict[str, "NodeType"]) -> None: self.string = string self.node_types = node_types self.inline = None @@ -247,12 +249,12 @@ def err(self, str: str) -> NoReturn: class ChoiceExpr(TypedDict): type: Literal["choice"] - exprs: list["Expr"] + exprs: List["Expr"] class SeqExpr(TypedDict): type: Literal["seq"] - exprs: list["Expr"] + exprs: List["Expr"] class PlusExpr(TypedDict): @@ -350,7 +352,7 @@ def parse_expr_range(stream: TokenStream, expr: Expr) -> Expr: return {"type": "range", "min": min_, "max": max_, "expr": expr} -def resolve_name(stream: TokenStream, name: str) -> list["NodeType"]: +def resolve_name(stream: TokenStream, name: str) -> List["NodeType"]: types = stream.node_types type = types.get(name) if type: @@ -400,8 +402,8 @@ class Edge(TypedDict): def nfa( expr: Expr, -) -> list[list[Edge]]: - nfa_: list[list[Edge]] = [[]] +) -> List[List[Edge]]: + nfa_: List[List[Edge]] = [[]] def node() -> int: nonlocal nfa_ @@ -416,17 +418,17 @@ def edge( nfa_[from_].append(edge) return edge - def connect(edges: list[Edge], to: int) -> None: + def connect(edges: List[Edge], to: int) -> None: for edge in edges: edge["to"] = to - def compile(expr: Expr, from_: int) -> list[Edge]: + def compile(expr: Expr, from_: int) -> List[Edge]: if expr["type"] == "choice": return list( reduce( lambda out, expr: [*out, *compile(expr, from_)], expr["exprs"], - cast(list[Edge], []), + cast(List[Edge], []), ) ) elif expr["type"] == "seq": @@ -477,9 +479,9 @@ def cmp(a: int, b: int) -> int: def null_from( - nfa: list[list[Edge]], + nfa: List[List[Edge]], node: int, -) -> list[int]: +) -> List[int]: result = [] def scan(n: int) -> None: @@ -499,21 +501,21 @@ def scan(n: int) -> None: class DFAState(NamedTuple): state: "NodeType" - next: list[int] + next: List[int] -def dfa(nfa: list[list[Edge]]) -> ContentMatch: +def dfa(nfa: List[List[Edge]]) -> ContentMatch: labeled = {} - def explore(states: list[int]) -> ContentMatch: + def explore(states: List[int]) -> ContentMatch: nonlocal labeled - out: list[DFAState] = [] + out: List[DFAState] = [] for node in states: for item in nfa[node]: term, to = item.get("term"), item.get("to") if not term: continue - set: Optional[list[int]] = None + set: Optional[List[int]] = None for t in out: if t[0] == term: set = t[1] diff --git a/prosemirror/model/fragment.py b/prosemirror/model/fragment.py index 43e8f21..3e2b72e 100644 --- a/prosemirror/model/fragment.py +++ b/prosemirror/model/fragment.py @@ -3,7 +3,9 @@ Any, Callable, ClassVar, + Dict, Iterable, + List, Optional, Sequence, Union, @@ -19,16 +21,16 @@ from .node import Node, TextNode -def retIndex(index: int, offset: int) -> dict[str, int]: +def retIndex(index: int, offset: int) -> Dict[str, int]: return {"index": index, "offset": offset} class Fragment: empty: ClassVar["Fragment"] - content: list["Node"] + content: List["Node"] size: int - def __init__(self, content: list["Node"], size: Optional[int] = None) -> None: + def __init__(self, content: List["Node"], size: Optional[int] = None) -> None: self.content = content self.size = size if size is not None else sum(c.node_size for c in content) @@ -123,7 +125,7 @@ def cut(self, from_: int, to: Optional[int] = None) -> "Fragment": to = self.size if from_ == 0 and to == self.size: return self - result: list["Node"] = [] + result: List["Node"] = [] size = 0 if to <= from_: return Fragment(result, size) @@ -224,7 +226,7 @@ def find_diff_end( other_pos = other.size return 
find_diff_end(self, other, pos, other_pos) - def find_index(self, pos: int, round: int = -1) -> dict[str, int]: + def find_index(self, pos: int, round: int = -1) -> Dict[str, int]: if pos == 0: return retIndex(0, pos) if pos == self.size: @@ -264,10 +266,10 @@ def from_json(cls, schema: "Schema[Any, Any]", value: Any) -> "Fragment": return cls([schema.node_from_json(item) for item in value]) @classmethod - def from_array(cls, array: list["Node"]) -> "Fragment": + def from_array(cls, array: List["Node"]) -> "Fragment": if not array: return cls.empty - joined: Optional[list["Node"]] = None + joined: Optional[List["Node"]] = None size = 0 for i in range(len(array)): node = array[i] diff --git a/prosemirror/model/from_dom.py b/prosemirror/model/from_dom.py index 86a22d0..1b7bb4a 100644 --- a/prosemirror/model/from_dom.py +++ b/prosemirror/model/from_dom.py @@ -1,7 +1,7 @@ import itertools import re from dataclasses import dataclass -from typing import Any, Callable, Literal, Optional, Union, cast +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast import lxml from lxml.cssselect import CSSSelector @@ -30,7 +30,7 @@ class DOMPosition: @dataclass(frozen=True) class ParseOptions: preserve_whitespace: WSType = None - find_positions: Optional[list[DOMPosition]] = None + find_positions: Optional[List[DOMPosition]] = None from_: Optional[int] = None to_: Optional[int] = None top_node: Optional[Node] = None @@ -61,7 +61,7 @@ class ParseRule: preserve_whitespace: WSType @classmethod - def from_json(cls, data: dict[str, Any]) -> "ParseRule": + def from_json(cls, data: Dict[str, Any]) -> "ParseRule": return ParseRule( data.get("tag"), data.get("namespace"), @@ -84,14 +84,14 @@ def from_json(cls, data: dict[str, Any]) -> "ParseRule": class DOMParser: - _tags: list[ParseRule] - _styles: list[ParseRule] + _tags: List[ParseRule] + _styles: List[ParseRule] _normalize_lists: bool schema: Schema[Any, Any] - rules: list[ParseRule] + rules: List[ParseRule] - def __init__(self, schema: Schema[Any, Any], rules: list[ParseRule]) -> None: + def __init__(self, schema: Schema[Any, Any], rules: List[ParseRule]) -> None: self.schema = schema self.rules = rules self._tags = [rule for rule in rules if rule.tag is not None] @@ -209,8 +209,8 @@ def match_style( return None @classmethod - def schema_rules(cls, schema: Schema[Any, Any]) -> list[ParseRule]: - result: list[ParseRule] = [] + def schema_rules(cls, schema: Schema[Any, Any]) -> List[ParseRule]: + result: List[ParseRule] = [] def insert(rule: ParseRule) -> None: priority = rule.priority if rule.priority is not None else 50 @@ -262,7 +262,7 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": return cast("DOMParser", schema.cached["dom_parser"]) -BLOCK_TAGS: dict[str, bool] = { +BLOCK_TAGS: Dict[str, bool] = { "address": True, "article": True, "aside": True, @@ -297,7 +297,7 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": "ul": True, } -IGNORE_TAGS: dict[str, bool] = { +IGNORE_TAGS: Dict[str, bool] = { "head": True, "noscript": True, "object": True, @@ -306,7 +306,7 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMParser": "title": True, } -LIST_TAGS: dict[str, bool] = {"ol": True, "ul": True} +LIST_TAGS: Dict[str, bool] = {"ol": True, "ul": True} OPT_PRESERVE_WS = 1 @@ -331,17 +331,17 @@ def ws_options_for( class NodeContext: match: Optional[ContentMatch] - content: list[Node] + content: List[Node] - active_marks: list[Mark] - stash_marks: list[Mark] + active_marks: List[Mark] + stash_marks: 
List[Mark] type: Optional[NodeType] options: int attrs: Optional[Attrs] - marks: list[Mark] - pending_marks: list[Mark] + marks: List[Mark] + pending_marks: List[Mark] solid: bool @@ -349,8 +349,8 @@ def __init__( self, _type: Optional[NodeType], attrs: Optional[Attrs], - marks: list[Mark], - pending_marks: list[Mark], + marks: List[Mark], + pending_marks: List[Mark], solid: bool, match: Optional[ContentMatch], options: int, @@ -376,7 +376,7 @@ def __init__( self.active_marks = Mark.none self.stash_marks = [] - def find_wrapping(self, node: Node) -> Optional[list[NodeType]]: + def find_wrapping(self, node: Node) -> Optional[List[NodeType]]: if not self.match: if not self.type: return [] @@ -459,9 +459,9 @@ def inline_context(self, node: DOMNode) -> bool: class ParseContext: open: int = 0 - find: Optional[list[DOMPosition]] + find: Optional[List[DOMPosition]] needs_block: bool - nodes: list[NodeContext] + nodes: List[NodeContext] options: ParseOptions is_open: bool parser: DOMParser @@ -651,9 +651,9 @@ def ignore_fallback(self, dom_: DOMNode) -> None: ): self.find_place(self.parser.schema.text("-")) - def read_styles(self, styles: list[str]) -> Optional[tuple[list[Mark], list[Mark]]]: - add: list[Mark] = Mark.none - remove: list[Mark] = Mark.none + def read_styles(self, styles: List[str]) -> Optional[Tuple[List[Mark], List[Mark]]]: + add: List[Mark] = Mark.none + remove: List[Mark] = Mark.none for i in range(0, len(styles), 2): after: Optional[ParseRule] = None @@ -754,7 +754,7 @@ def add_all( self.find_at_point(parent, index) def find_place(self, node: Node) -> bool: - route: Optional[list[NodeType]] = None + route: Optional[List[NodeType]] = None sync: Optional[NodeContext] = None depth = self.open @@ -1059,9 +1059,9 @@ def matches(dom_: DOMNode, selector_str: str) -> bool: return bool(dom_ in selector(dom_)) # type: ignore[operator] -def parse_styles(style: str) -> list[str]: +def parse_styles(style: str) -> List[str]: regex = r"\s*([\w-]+)\s*:\s*([^;]+)" - result: list[str] = [] + result: List[str] = [] for m in re.findall(regex, style): result.append(m[0]) @@ -1077,7 +1077,7 @@ def mark_may_apply(mark_type: MarkType, node_type: NodeType) -> bool: if not parent.allows_mark_type(mark_type): continue - seen: list[ContentMatch] = [] + seen: List[ContentMatch] = [] def scan(match: ContentMatch) -> bool: seen.append(match) @@ -1101,7 +1101,7 @@ def scan(match: ContentMatch) -> bool: return False -def find_same_mark_in_set(mark: Mark, mark_set: list[Mark]) -> Optional[Mark]: +def find_same_mark_in_set(mark: Mark, mark_set: List[Mark]) -> Optional[Mark]: for comp in mark_set: if mark.eq(comp): return comp diff --git a/prosemirror/model/mark.py b/prosemirror/model/mark.py index 5455197..fabc247 100644 --- a/prosemirror/model/mark.py +++ b/prosemirror/model/mark.py @@ -1,5 +1,5 @@ import copy -from typing import TYPE_CHECKING, Any, Final, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Final, List, Optional, Union, cast from prosemirror.utils import Attrs, JSONDict @@ -8,14 +8,14 @@ class Mark: - none: Final[list["Mark"]] = [] + none: Final[List["Mark"]] = [] def __init__(self, type: "MarkType", attrs: Attrs) -> None: self.type = type self.attrs = attrs - def add_to_set(self, set: list["Mark"]) -> list["Mark"]: - copy: Optional[list["Mark"]] = None + def add_to_set(self, set: List["Mark"]) -> List["Mark"]: + copy: Optional[List["Mark"]] = None placed = False for i in range(len(set)): other = set[i] @@ -40,10 +40,10 @@ def add_to_set(self, set: list["Mark"]) -> list["Mark"]: 
copy.append(self) return copy - def remove_from_set(self, set: list["Mark"]) -> list["Mark"]: + def remove_from_set(self, set: List["Mark"]) -> List["Mark"]: return [item for item in set if not item.eq(self)] - def is_in_set(self, set: list["Mark"]) -> bool: + def is_in_set(self, set: List["Mark"]) -> bool: return any(item.eq(self) for item in set) def eq(self, other: "Mark") -> bool: @@ -69,7 +69,7 @@ def from_json( return type.create(cast(Optional[JSONDict], json_data.get("attrs"))) @classmethod - def same_set(cls, a: list["Mark"], b: list["Mark"]) -> bool: + def same_set(cls, a: List["Mark"], b: List["Mark"]) -> bool: if a == b: return True if len(a) != len(b): @@ -77,7 +77,7 @@ def same_set(cls, a: list["Mark"], b: list["Mark"]) -> bool: return all(item_a.eq(item_b) for (item_a, item_b) in zip(a, b)) @classmethod - def set_from(cls, marks: Union[list["Mark"], "Mark", None]) -> list["Mark"]: + def set_from(cls, marks: Union[List["Mark"], "Mark", None]) -> List["Mark"]: if not marks: return cls.none if isinstance(marks, Mark): diff --git a/prosemirror/model/node.py b/prosemirror/model/node.py index 8635c03..55d17bc 100644 --- a/prosemirror/model/node.py +++ b/prosemirror/model/node.py @@ -1,5 +1,5 @@ import copy -from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypedDict, Union, cast from typing_extensions import TypeGuard @@ -31,7 +31,7 @@ def __init__( type: "NodeType", attrs: "Attrs", content: Optional[Fragment], - marks: list[Mark], + marks: List[Mark], ) -> None: self.type = type self.attrs = attrs @@ -104,7 +104,7 @@ def has_markup( self, type: "NodeType", attrs: Optional["Attrs"] = None, - marks: Optional[list[Mark]] = None, + marks: Optional[List[Mark]] = None, ) -> bool: return ( self.type.name == type.name @@ -117,7 +117,7 @@ def copy(self, content: Optional[Fragment] = None) -> "Node": return self return self.__class__(self.type, self.attrs, content, self.marks) - def mark(self, marks: list[Mark]) -> "Node": + def mark(self, marks: List[Mark]) -> "Node": if marks == self.marks: return self return self.__class__(self.type, self.attrs, self.content, marks) @@ -272,7 +272,7 @@ def can_replace( return True def can_replace_with( - self, from_: int, to: int, type: "NodeType", marks: Optional[list[Mark]] = None + self, from_: int, to: int, type: "NodeType", marks: Optional[List[Mark]] = None ) -> bool: if marks and not self.type.allows_marks(marks): return False @@ -356,7 +356,7 @@ def __init__( type: "NodeType", attrs: "Attrs", content: str, - marks: list[Mark], + marks: List[Mark], ) -> None: super().__init__(type, attrs, None, marks) if not content: @@ -392,7 +392,7 @@ def text_between( def node_size(self) -> int: return text_length(self.text) - def mark(self, marks: list[Mark]) -> "TextNode": + def mark(self, marks: List[Mark]) -> "TextNode": return ( self if marks == self.marks @@ -423,7 +423,7 @@ def to_json( return {**super().to_json(), "text": self.text} -def wrap_marks(marks: list[Mark], str: str) -> str: +def wrap_marks(marks: List[Mark], str: str) -> str: i = len(marks) - 1 while i >= 0: str = marks[i].type.name + "(" + str + ")" diff --git a/prosemirror/model/replace.py b/prosemirror/model/replace.py index 70445c3..a6cc4dd 100644 --- a/prosemirror/model/replace.py +++ b/prosemirror/model/replace.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, cast from 
prosemirror.utils import JSONDict @@ -186,7 +186,7 @@ def joinable(before: "ResolvedPos", after: "ResolvedPos", depth: int) -> "Node": return node -def add_node(child: "Node", target: list["Node"]) -> None: +def add_node(child: "Node", target: List["Node"]) -> None: last = len(target) - 1 if last >= 0 and pm_node.is_text(child) and child.same_markup(target[last]): target[last] = child.with_text(cast("TextNode", target[last]).text + child.text) @@ -198,7 +198,7 @@ def add_range( start: Optional["ResolvedPos"], end: Optional["ResolvedPos"], depth: int, - target: list["Node"], + target: List["Node"], ) -> None: node = cast("ResolvedPos", end or start).node(depth) start_index = 0 @@ -233,7 +233,7 @@ def replace_three_way( ) -> Fragment: open_start = joinable(from_, start, depth + 1) if from_.depth > depth else None open_end = joinable(end, to, depth + 1) if to.depth > depth else None - content: list["Node"] = [] + content: List["Node"] = [] add_range(None, from_, depth, content) if open_start and open_end and start.index(depth) == end.index(depth): check_join(open_start, open_end) @@ -254,7 +254,7 @@ def replace_three_way( def replace_two_way(from_: "ResolvedPos", to: "ResolvedPos", depth: int) -> Fragment: - content: list["Node"] = [] + content: List["Node"] = [] add_range(None, from_, depth, content) if from_.depth > depth: type = joinable(from_, to, depth + 1) @@ -265,7 +265,7 @@ def replace_two_way(from_: "ResolvedPos", to: "ResolvedPos", depth: int) -> Frag def prepare_slice_for_replace( slice: Slice, along: "ResolvedPos" -) -> dict[str, "ResolvedPos"]: +) -> Dict[str, "ResolvedPos"]: extra = along.depth - slice.open_start parent = along.node(extra) node = parent.copy(slice.content) diff --git a/prosemirror/model/resolvedpos.py b/prosemirror/model/resolvedpos.py index d381de9..2d95f10 100644 --- a/prosemirror/model/resolvedpos.py +++ b/prosemirror/model/resolvedpos.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Optional, Union, cast +from typing import TYPE_CHECKING, Callable, List, Optional, Union, cast from .mark import Mark @@ -8,7 +8,7 @@ class ResolvedPos: def __init__( - self, pos: int, path: list[Union["Node", int]], parent_offset: int + self, pos: int, path: List[Union["Node", int]], parent_offset: int ) -> None: self.pos = pos self.path = path @@ -97,7 +97,7 @@ def pos_at_index(self, index: int, depth: Optional[int] = None) -> int: pos += node.child(i).node_size return pos - def marks(self) -> list["Mark"]: + def marks(self) -> List["Mark"]: parent = self.parent index = self.index() if parent.content.size == 0: @@ -119,7 +119,7 @@ def marks(self) -> list["Mark"]: i += 1 return marks - def marks_across(self, end: "ResolvedPos") -> Optional[list["Mark"]]: + def marks_across(self, end: "ResolvedPos") -> Optional[List["Mark"]]: after = self.parent.maybe_child(self.index()) if not after or not after.is_inline: return None @@ -183,7 +183,7 @@ def __str__(self) -> str: def resolve(cls, doc: "Node", pos: int) -> "ResolvedPos": if not (pos >= 0 and pos <= doc.content.size): raise ValueError(f"Position {pos} out of range") - path: list[Union["Node", int]] = [] + path: List[Union["Node", int]] = [] start = 0 parent_offset = pos node = doc diff --git a/prosemirror/model/schema.py b/prosemirror/model/schema.py index de41d4a..f21914f 100644 --- a/prosemirror/model/schema.py +++ b/prosemirror/model/schema.py @@ -1,7 +1,9 @@ from typing import ( Any, Callable, + Dict, Generic, + List, Literal, Optional, TypeVar, @@ -67,7 +69,7 @@ class NodeType: inline_content: bool - mark_set: 
Optional[list["MarkType"]] + mark_set: Optional[List["MarkType"]] def __init__(self, name: str, schema: "Schema[Any, Any]", spec: "NodeSpec") -> None: self.name = name @@ -130,8 +132,8 @@ def compute_attrs(self, attrs: Optional[Attrs]) -> Attrs: def create( self, attrs: Optional[Attrs] = None, - content: Union[Fragment, Node, list[Node], None] = None, - marks: Optional[list[Mark]] = None, + content: Union[Fragment, Node, List[Node], None] = None, + marks: Optional[List[Mark]] = None, ) -> Node: if self.is_text: raise ValueError("NodeType.create cannot construct text nodes") @@ -145,8 +147,8 @@ def create( def create_checked( self, attrs: Optional[Attrs] = None, - content: Union[Fragment, Node, list[Node], None] = None, - marks: Optional[list[Mark]] = None, + content: Union[Fragment, Node, List[Node], None] = None, + marks: Optional[List[Mark]] = None, ) -> Node: content = Fragment.from_(content) if not self.valid_content(content): @@ -156,8 +158,8 @@ def create_checked( def create_and_fill( self, attrs: Optional[Attrs] = None, - content: Union[Fragment, Node, list[Node], None] = None, - marks: Optional[list[Mark]] = None, + content: Union[Fragment, Node, List[Node], None] = None, + marks: Optional[List[Mark]] = None, ) -> Optional[Node]: attrs = self.compute_attrs(attrs) frag = Fragment.from_(content) @@ -186,15 +188,15 @@ def valid_content(self, content: Fragment) -> bool: def allows_mark_type(self, mark_type: "MarkType") -> bool: return self.mark_set is None or mark_type in self.mark_set - def allows_marks(self, marks: list[Mark]) -> bool: + def allows_marks(self, marks: List[Mark]) -> bool: if self.mark_set is None: return True return all(self.allows_mark_type(mark.type) for mark in marks) - def allowed_marks(self, marks: list[Mark]) -> list[Mark]: + def allowed_marks(self, marks: List[Mark]) -> List[Mark]: if self.mark_set is None: return marks - copy: Optional[list[Mark]] = None + copy: Optional[List[Mark]] = None for i, mark in enumerate(marks): if not self.allows_mark_type(mark.type): if not copy: @@ -210,9 +212,9 @@ def allowed_marks(self, marks: list[Mark]) -> list[Mark]: @classmethod def compile( - cls, nodes: dict["Nodes", "NodeSpec"], schema: "Schema[Nodes, Marks]" - ) -> dict["Nodes", "NodeType"]: - result: dict["Nodes", "NodeType"] = {} + cls, nodes: Dict["Nodes", "NodeSpec"], schema: "Schema[Nodes, Marks]" + ) -> Dict["Nodes", "NodeType"]: + result: Dict["Nodes", "NodeType"] = {} for name, spec in nodes.items(): result[name] = NodeType(name, schema, spec) @@ -233,7 +235,7 @@ def __repr__(self) -> str: return self.__str__() -Attributes: TypeAlias = dict[str, "Attribute"] +Attributes: TypeAlias = Dict[str, "Attribute"] class Attribute: @@ -247,7 +249,7 @@ def is_required(self) -> bool: class MarkType: - excluded: list["MarkType"] + excluded: List["MarkType"] instance: Optional[Mark] def __init__( @@ -274,8 +276,8 @@ def create( @classmethod def compile( - cls, marks: dict["Marks", "MarkSpec"], schema: "Schema[Nodes, Marks]" - ) -> dict["Marks", "MarkType"]: + cls, marks: Dict["Marks", "MarkSpec"], schema: "Schema[Nodes, Marks]" + ) -> Dict["Marks", "MarkType"]: result = {} rank = 0 for name, spec in marks.items(): @@ -283,10 +285,10 @@ def compile( rank += 1 return result - def remove_from_set(self, set_: list["Mark"]) -> list["Mark"]: + def remove_from_set(self, set_: List["Mark"]) -> List["Mark"]: return [item for item in set_ if item.type != self] - def is_in_set(self, set: list[Mark]) -> Optional[Mark]: + def is_in_set(self, set: List[Mark]) -> Optional[Mark]: return 
next((item for item in set if item.type == self), None) def excludes(self, other: "MarkType") -> bool: @@ -309,13 +311,13 @@ class SchemaSpec(TypedDict, Generic[Nodes, Marks]): # determines which [parse rules](#model.NodeSpec.parseDOM) take # precedence by default, and which nodes come first in a given # [group](#model.NodeSpec.group). - nodes: dict[Nodes, "NodeSpec"] + nodes: Dict[Nodes, "NodeSpec"] # The mark types that exist in this schema. The order in which they # are provided determines the order in which [mark # sets](#model.Mark.addToSet) are sorted and in which [parse # rules](#model.MarkSpec.parseDOM) are tried. - marks: NotRequired[dict[Marks, "MarkSpec"]] + marks: NotRequired[Dict[Marks, "MarkSpec"]] # The name of the default top-level node for the schema. Defaults # to `"doc"`. @@ -342,12 +344,12 @@ class NodeSpec(TypedDict, total=False): defining: bool isolating: bool toDOM: Callable[[Node], Any] # FIXME: add types - parseDOM: list[dict[str, Any]] # FIXME: add types + parseDOM: List[Dict[str, Any]] # FIXME: add types toDebugString: Callable[[Node], str] leafText: Callable[[Node], str] -AttributeSpecs: TypeAlias = dict[str, "AttributeSpec"] +AttributeSpecs: TypeAlias = Dict[str, "AttributeSpec"] class MarkSpec(TypedDict, total=False): @@ -357,7 +359,7 @@ class MarkSpec(TypedDict, total=False): group: str spanning: bool toDOM: Callable[[Mark, bool], Any] # FIXME: add types - parseDOM: list[dict[str, Any]] # FIXME: add types + parseDOM: List[Dict[str, Any]] # FIXME: add types class AttributeSpec(TypedDict, total=False): @@ -367,9 +369,9 @@ class AttributeSpec(TypedDict, total=False): class Schema(Generic[Nodes, Marks]): spec: SchemaSpec[Nodes, Marks] - nodes: dict[Nodes, "NodeType"] + nodes: Dict[Nodes, "NodeType"] - marks: dict[Marks, "MarkType"] + marks: Dict[Marks, "MarkType"] def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: self.spec = spec @@ -384,7 +386,7 @@ def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: mark_expr = type.spec.get("marks") if content_expr not in content_expr_cache: content_expr_cache[content_expr] = ContentMatch.parse( - content_expr, cast(dict[str, "NodeType"], self.nodes) + content_expr, cast(Dict[str, "NodeType"], self.nodes) ) type.content_match = content_expr_cache[content_expr] @@ -406,15 +408,15 @@ def __init__(self, spec: SchemaSpec[Nodes, Marks]) -> None: ) self.top_node_type = self.nodes[cast(Nodes, self.spec.get("topNode") or "doc")] - self.cached: dict[str, Any] = {} + self.cached: Dict[str, Any] = {} self.cached["wrappings"] = {} def node( self, type: Union[str, NodeType], attrs: Optional[Attrs] = None, - content: Union[Fragment, Node, list[Node], None] = None, - marks: Optional[list[Mark]] = None, + content: Union[Fragment, Node, List[Node], None] = None, + marks: Optional[List[Mark]] = None, ) -> Node: if isinstance(type, str): type = self.node_type(type) @@ -424,7 +426,7 @@ def node( raise ValueError(f"Node type from different schema used ({type.name})") return type.create_checked(attrs, content, marks) - def text(self, text: str, marks: Optional[list[Mark]] = None) -> TextNode: + def text(self, text: str, marks: Optional[List[Mark]] = None) -> TextNode: type = self.nodes[cast(Nodes, "text")] return TextNode( type, cast(Attrs, type.default_attrs), text, Mark.set_from(marks) @@ -455,7 +457,7 @@ def node_type(self, name: str) -> NodeType: return found -def gather_marks(schema: Schema[Any, Any], marks: list[str]) -> list[MarkType]: +def gather_marks(schema: Schema[Any, Any], marks: List[str]) -> List[MarkType]: found = 
[] for name in marks: mark = schema.marks.get(name) diff --git a/prosemirror/model/to_dom.py b/prosemirror/model/to_dom.py index d723a7c..2e782cb 100644 --- a/prosemirror/model/to_dom.py +++ b/prosemirror/model/to_dom.py @@ -2,9 +2,12 @@ from typing import ( Any, Callable, + Dict, + List, Mapping, Optional, Sequence, + Tuple, Union, cast, ) @@ -18,7 +21,7 @@ class DocumentFragment: - def __init__(self, children: list[HTMLNode]) -> None: + def __init__(self, children: List[HTMLNode]) -> None: self.children = children def __str__(self) -> str: @@ -48,7 +51,7 @@ def __str__(self) -> str: class Element(DocumentFragment): def __init__( - self, name: str, attrs: dict[str, str], children: list[HTMLNode] + self, name: str, attrs: Dict[str, str], children: List[HTMLNode] ) -> None: self.name = name self.attrs = attrs @@ -70,8 +73,8 @@ def __str__(self) -> str: class DOMSerializer: def __init__( self, - nodes: dict[str, Callable[[Node], HTMLOutputSpec]], - marks: dict[str, Callable[[Mark, bool], HTMLOutputSpec]], + nodes: Dict[str, Callable[[Node], HTMLOutputSpec]], + marks: Dict[str, Callable[[Mark, bool], HTMLOutputSpec]], ) -> None: self.nodes = nodes self.marks = marks @@ -82,7 +85,7 @@ def serialize_fragment( tgt: DocumentFragment = target or DocumentFragment(children=[]) top = tgt - active: Optional[list[tuple[Mark, DocumentFragment]]] = None + active: Optional[List[Tuple[Mark, DocumentFragment]]] = None def each(node: Node, offset: int, index: int) -> None: nonlocal top, active @@ -139,7 +142,7 @@ def serialize_node(self, node: Node) -> HTMLNode: def serialize_mark( self, mark: Mark, inline: bool - ) -> Optional[tuple[HTMLNode, Optional[Element]]]: + ) -> Optional[Tuple[HTMLNode, Optional[Element]]]: to_dom = self.marks.get(mark.type.name) if to_dom: return type(self).render_spec(to_dom(mark, inline)) @@ -148,7 +151,7 @@ def serialize_mark( @classmethod def render_spec( cls, structure: HTMLOutputSpec - ) -> tuple[HTMLNode, Optional[Element]]: + ) -> Tuple[HTMLNode, Optional[Element]]: if isinstance(structure, str): return html.escape(structure), None if isinstance(structure, Element): @@ -191,7 +194,7 @@ def from_schema(cls, schema: Schema[Any, Any]) -> "DOMSerializer": @classmethod def nodes_from_schema( cls, schema: Schema[str, Any] - ) -> dict[str, Callable[["Node"], HTMLOutputSpec]]: + ) -> Dict[str, Callable[["Node"], HTMLOutputSpec]]: result = gather_to_dom(schema.nodes) if "text" not in result: result["text"] = lambda node: node.text @@ -200,13 +203,13 @@ def nodes_from_schema( @classmethod def marks_from_schema( cls, schema: Schema[Any, Any] - ) -> dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: + ) -> Dict[str, Callable[["Mark", bool], HTMLOutputSpec]]: return gather_to_dom(schema.marks) def gather_to_dom( obj: Mapping[str, Union[NodeType, MarkType]] -) -> dict[str, Callable[..., Any]]: +) -> Dict[str, Callable[..., Any]]: result = {} for name in obj: to_dom = obj[name].spec.get("toDOM") diff --git a/prosemirror/schema/basic/schema_basic.py b/prosemirror/schema/basic/schema_basic.py index 39857af..2c022a3 100644 --- a/prosemirror/schema/basic/schema_basic.py +++ b/prosemirror/schema/basic/schema_basic.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict from prosemirror.model import Schema from prosemirror.model.schema import MarkSpec, NodeSpec @@ -9,7 +9,7 @@ pre_dom = ["pre", ["code", 0]] br_dom = ["br"] -nodes: dict[str, NodeSpec] = { +nodes: Dict[str, NodeSpec] = { "doc": {"content": "block+"}, "paragraph": { "content": "inline*", @@ -90,7 +90,7 @@ 
strong_dom = ["strong", 0] code_dom = ["code", 0] -marks: dict[str, MarkSpec] = { +marks: Dict[str, MarkSpec] = { "link": { "attrs": {"href": {}, "title": {"default": None}}, "inclusive": False, diff --git a/prosemirror/schema/list/schema_list.py b/prosemirror/schema/list/schema_list.py index a0146b8..8b984c0 100644 --- a/prosemirror/schema/list/schema_list.py +++ b/prosemirror/schema/list/schema_list.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Dict, cast from prosemirror.model.schema import Nodes, NodeSpec @@ -27,8 +27,8 @@ def add(obj: "NodeSpec", props: "NodeSpec") -> "NodeSpec": def add_list_nodes( - nodes: dict["Nodes", "NodeSpec"], item_content: str, list_group: str -) -> dict["Nodes", "NodeSpec"]: + nodes: Dict["Nodes", "NodeSpec"], item_content: str, list_group: str +) -> Dict["Nodes", "NodeSpec"]: copy = nodes.copy() copy.update( { diff --git a/prosemirror/transform/map.py b/prosemirror/transform/map.py index 639c375..d0778a8 100644 --- a/prosemirror/transform/map.py +++ b/prosemirror/transform/map.py @@ -1,6 +1,6 @@ import abc from collections.abc import Callable -from typing import ClassVar, Literal, Optional, Union, overload +from typing import ClassVar, List, Literal, Optional, Union, overload lower16 = 0xFFFF factor16 = 2**16 @@ -70,7 +70,7 @@ def map_result(self, pos: int, assoc: int = 1) -> MapResult: class StepMap(Mappable): empty: ClassVar["StepMap"] - def __init__(self, ranges: list[int], inverted: bool = False) -> None: + def __init__(self, ranges: List[int], inverted: bool = False) -> None: # prosemirror-transform overrides the constructor to return the # StepMap.empty singleton when ranges are empty. # It is not easy to do in Python, and the intent of that is to make sure @@ -182,8 +182,8 @@ def __str__(self) -> str: class Mapping(Mappable): def __init__( self, - maps: Optional[list[StepMap]] = None, - mirror: Optional[list[int]] = None, + maps: Optional[List[StepMap]] = None, + mirror: Optional[List[int]] = None, from_: Optional[int] = None, to: Optional[int] = None, ) -> None: diff --git a/prosemirror/transform/replace.py b/prosemirror/transform/replace.py index bc2b9f3..37945f0 100644 --- a/prosemirror/transform/replace.py +++ b/prosemirror/transform/replace.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import List, Optional, cast from prosemirror.model import ( ContentMatch, @@ -60,7 +60,7 @@ def __init__( frontier_depth: int, parent: Optional[Node], inject: Optional[Fragment] = None, - wrap: Optional[list[NodeType]] = None, + wrap: Optional[List[NodeType]] = None, ) -> None: self.slice_depth = slice_depth self.frontier_depth = frontier_depth @@ -91,7 +91,7 @@ def __init__(self, from__: ResolvedPos, to_: ResolvedPos, slice: Slice) -> None: self.from__ = from__ self.unplaced = slice - self.frontier: list[_FrontierItem] = [] + self.frontier: List[_FrontierItem] = [] for i in range(from__.depth + 1): node = from__.node(i) self.frontier.append( @@ -188,12 +188,12 @@ def _lazy_inject() -> Optional[Fragment]: inject = match.fill_before(Fragment.from_(first), False) return cast(Optional[Fragment], inject) - def _lazy_wrap() -> Optional[list[NodeType]]: + def _lazy_wrap() -> Optional[List[NodeType]]: nonlocal wrap assert first is not None if wrap is _nothing: wrap = match.find_wrapping(first.type) - return cast(Optional[list[NodeType]], wrap) + return cast(Optional[List[NodeType]], wrap) if pass_ == 1 and ( (match.match_type(first.type) or _lazy_inject()) @@ -557,7 +557,7 @@ def close_fragment( def covered_depths( from__: 
ResolvedPos, to_: ResolvedPos, -) -> list[int]: +) -> List[int]: result = [] min_depth = min(from__.depth, to_.depth) for d in range(min_depth, -1, -1): diff --git a/prosemirror/transform/step.py b/prosemirror/transform/step.py index bf647b3..5c3a160 100644 --- a/prosemirror/transform/step.py +++ b/prosemirror/transform/step.py @@ -1,12 +1,12 @@ import abc -from typing import Any, Literal, Optional, Type, TypeVar, Union, cast, overload +from typing import Any, Dict, Literal, Optional, Type, TypeVar, Union, cast, overload from prosemirror.model import Node, ReplaceError, Schema, Slice from prosemirror.transform.map import Mappable, StepMap from prosemirror.utils import JSONDict # like a registry -STEPS_BY_ID: dict[str, Type["Step"]] = {} +STEPS_BY_ID: Dict[str, Type["Step"]] = {} StepSubclass = TypeVar("StepSubclass", bound="Step") @@ -50,7 +50,7 @@ def from_json(schema: Schema[Any, Any], json_data: Union[JSONDict, str]) -> "Ste return type.from_json(schema, json_data) -def step_json_id(id: str, step_class: type[StepSubclass]) -> type[StepSubclass]: +def step_json_id(id: str, step_class: Type[StepSubclass]) -> Type[StepSubclass]: if id in STEPS_BY_ID: raise ValueError(f"Duplicated JSON ID for step type: {id}") diff --git a/prosemirror/transform/structure.py b/prosemirror/transform/structure.py index fb52b7d..15d10ca 100644 --- a/prosemirror/transform/structure.py +++ b/prosemirror/transform/structure.py @@ -1,4 +1,4 @@ -from typing import Optional, TypedDict, Union +from typing import List, Optional, TypedDict, Union from prosemirror.model import ContentMatch, Node, NodeRange, NodeType, Slice from prosemirror.utils import Attrs @@ -41,7 +41,7 @@ def find_wrapping( node_type: NodeType, attrs: Optional[Attrs] = None, inner_range: Optional[NodeRange] = None, -) -> Optional[list[NodeTypeWithAttrs]]: +) -> Optional[List[NodeTypeWithAttrs]]: if inner_range is None: inner_range = range_ @@ -69,7 +69,7 @@ def with_attrs(type: NodeType) -> NodeTypeWithAttrs: def find_wrapping_outside( range_: NodeRange, type: NodeType -) -> Optional[list[NodeType]]: +) -> Optional[List[NodeType]]: parent = range_.parent start_index = range_.start_index end_index = range_.end_index @@ -80,7 +80,7 @@ def find_wrapping_outside( return around if parent.can_replace_with(start_index, end_index, outer) else None -def find_wrapping_inside(range_: NodeRange, type: NodeType) -> Optional[list[NodeType]]: +def find_wrapping_inside(range_: NodeRange, type: NodeType) -> Optional[List[NodeType]]: parent = range_.parent start_index = range_.start_index end_index = range_.end_index @@ -114,7 +114,7 @@ def can_split( doc: Node, pos: int, depth: Optional[int] = None, - types_after: Optional[list[NodeTypeWithAttrs]] = None, + types_after: Optional[List[NodeTypeWithAttrs]] = None, ) -> bool: if depth is None: depth = 1 diff --git a/prosemirror/transform/transform.py b/prosemirror/transform/transform.py index f965ea6..7dcacd6 100644 --- a/prosemirror/transform/transform.py +++ b/prosemirror/transform/transform.py @@ -1,5 +1,5 @@ import re -from typing import Optional, TypedDict, Union +from typing import List, Optional, TypedDict, Union from prosemirror.model import ( ContentMatch, @@ -57,8 +57,8 @@ class Transform: def __init__(self, doc: Node) -> None: self.doc = doc - self.steps: list[Step] = [] - self.docs: list[Node] = [] + self.steps: List[Step] = [] + self.docs: List[Node] = [] self.mapping = Mapping() @property @@ -144,7 +144,7 @@ class MatchedTypedDict(TypedDict): to: int step: int - matched: list[MatchedTypedDict] = [] + 
matched: List[MatchedTypedDict] = [] step = 0 def iteratee( @@ -270,7 +270,7 @@ def replace_with( self, from_: int, to: int, - content: Union[Fragment, Node, list[Node]], + content: Union[Fragment, Node, List[Node]], ) -> "Transform": return self.replace(from_, to, Slice(Fragment.from_(content), 0, 0)) @@ -280,7 +280,7 @@ def delete(self, from_: int, to: int) -> "Transform": def insert( self, pos: int, - content: Union[Fragment, Node, list[Node]], + content: Union[Fragment, Node, List[Node]], ) -> "Transform": return self.replace_with(pos, pos, content) @@ -473,7 +473,7 @@ def lift(self, range_: NodeRange, target: int) -> "Transform": ) def wrap( - self, range_: NodeRange, wrappers: list[structure.NodeTypeWithAttrs] + self, range_: NodeRange, wrappers: List[structure.NodeTypeWithAttrs] ) -> "Transform": content = Fragment.empty i = len(wrappers) - 1 @@ -548,7 +548,7 @@ def set_node_markup( pos: int, type: Optional[NodeType], attrs: Optional[Attrs], - marks: Optional[list[Mark]] = None, + marks: Optional[List[Mark]] = None, ) -> "Transform": node = self.doc.node_at(pos) if not node: @@ -600,7 +600,7 @@ def split( self, pos: int, depth: Optional[int] = None, - types_after: Optional[list[structure.NodeTypeWithAttrs]] = None, + types_after: Optional[List[structure.NodeTypeWithAttrs]] = None, ) -> "Transform": if depth is None: depth = 1 From d12715e5ad7be7b30a3ab2f9c961796da9cb50f2 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Thu, 16 Nov 2023 10:32:50 -0500 Subject: [PATCH 39/40] Using Callable from Typing for Python 3.8 compatibility reasons. --- prosemirror/test_builder/build.py | 3 +-- prosemirror/transform/map.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/prosemirror/test_builder/build.py b/prosemirror/test_builder/build.py index 1096992..0dc7029 100644 --- a/prosemirror/test_builder/build.py +++ b/prosemirror/test_builder/build.py @@ -1,8 +1,7 @@ # type: ignore import re -from collections.abc import Callable -from typing import Any, Union +from typing import Any, Callable, Union from prosemirror.model import Node, Schema from prosemirror.utils import JSONDict diff --git a/prosemirror/transform/map.py b/prosemirror/transform/map.py index d0778a8..1c8d564 100644 --- a/prosemirror/transform/map.py +++ b/prosemirror/transform/map.py @@ -1,6 +1,5 @@ import abc -from collections.abc import Callable -from typing import ClassVar, List, Literal, Optional, Union, overload +from typing import Callable, ClassVar, List, Literal, Optional, Union, overload lower16 = 0xFFFF factor16 = 2**16 From b061e0e90d1015b7e39a0337440c111c834c5870 Mon Sep 17 00:00:00 2001 From: Ernesto Ferro Date: Thu, 16 Nov 2023 10:34:23 -0500 Subject: [PATCH 40/40] Minor version bump due to the mypy related changes. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0b9e21d..359f78a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "prosemirror" -version = "0.3.7" +version = "0.4.0" description = "Python implementation of core ProseMirror modules for collaborative editing" readme = "README.md" authors = ["Shen Li "]
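
The `List`/`Dict`/`Tuple` conversions and the switch to `typing.Callable` in the patches above exist because PEP 585 builtin generics and `collections.abc.Callable` only became subscriptable at runtime in Python 3.9. Below is a minimal sketch of that behaviour, assuming CPython 3.8; the snippet and its names (`Handler`, `Registry`) are illustrative and are not part of these patches.

    # Illustrative only -- not part of the patch series. Shows why typing.List /
    # typing.Callable are used instead of list[...] / collections.abc.Callable
    # while Python 3.8 is still supported.
    import collections.abc
    import sys
    from typing import Callable, Dict, List

    Handler = Callable[[str], int]    # subscriptable on 3.8 and later
    Registry = Dict[str, List[int]]   # subscriptable on 3.8 and later

    if sys.version_info < (3, 9):
        try:
            dict[str, int]            # PEP 585 syntax, runtime support only from 3.9
        except TypeError as exc:
            print("builtin generic:", exc)
        try:
            collections.abc.Callable[[str], int]
        except TypeError as exc:
            print("abc.Callable:", exc)

On 3.9 and later both spellings work, so the typing-module forms keep the package importable on 3.8, including in spots such as `cast(Optional[List[NodeType]], ...)` calls and module-level aliases where `from __future__ import annotations` would not help, since those expressions are evaluated eagerly.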